Commit 9c2a9c57 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add transform_dict methods to plot_walkers

parent ecf6c1c2
......@@ -502,7 +502,6 @@ class MCMCSearch(core.BaseSearchClass):
sampler = self._run_sampler(sampler, p0, nburn=n, window=window)
if create_plots:
fig, axes = self._plot_walkers(sampler,
symbols=self.theta_symbols,
**kwargs)
fig.tight_layout()
fig.savefig('{}/{}_init_{}_walkers.png'.format(
......@@ -522,8 +521,7 @@ class MCMCSearch(core.BaseSearchClass):
nburn+nprod))
sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
if create_plots:
fig, axes = self._plot_walkers(sampler, symbols=self.theta_symbols,
nprod=nprod, **kwargs)
fig, axes = self._plot_walkers(sampler, nprod=nprod, **kwargs)
fig.tight_layout()
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
)
......@@ -603,7 +601,7 @@ class MCMCSearch(core.BaseSearchClass):
return samples
def _get_labels(self):
def _get_labels(self, newline_units=False):
""" Combine the units, symbols and rescaling to give labels """
labels = []
......@@ -620,7 +618,10 @@ class MCMCSearch(core.BaseSearchClass):
if 'unit' in self.transform_dictionary[key]:
u = self.transform_dictionary[key]['unit']
if label is None:
label = '{} \n [{}]'.format(s, u)
if newline_units:
label = '{} \n [{}]'.format(s, u)
else:
label = '{} [{}]'.format(s, u)
labels.append(label)
return labels
......@@ -694,7 +695,7 @@ class MCMCSearch(core.BaseSearchClass):
fig, axes = fig_and_axes
samples_plt = copy.copy(self.samples)
labels = self._get_labels()
labels = self._get_labels(newline_units=True)
samples_plt = self._scale_samples(samples_plt, self.theta_keys)
......@@ -963,9 +964,11 @@ class MCMCSearch(core.BaseSearchClass):
def _plot_walkers(self, sampler, symbols=None, alpha=0.8, color="k",
temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False,
fig=None, axes=None, xoffset=0, plot_det_stat=False,
context='ggplot', subtractions=None, labelpad=0.05):
context='ggplot', labelpad=5):
""" Plot all the chains from a sampler """
if symbols is None:
symbols = self._get_labels()
if context not in plt.style.available:
raise ValueError((
'The requested context {} is not available; please select a'
......@@ -977,7 +980,7 @@ class MCMCSearch(core.BaseSearchClass):
shape = sampler.chain.shape
if len(shape) == 3:
nwalkers, nsteps, ndim = shape
chain = sampler.chain[:, :, :]
chain = sampler.chain[:, :, :].copy()
if len(shape) == 4:
ntemps, nwalkers, nsteps, ndim = shape
if temp < ntemps:
......@@ -985,13 +988,11 @@ class MCMCSearch(core.BaseSearchClass):
else:
raise ValueError(("Requested temperature {} outside of"
"available range").format(temp))
chain = sampler.chain[temp, :, :, :]
chain = sampler.chain[temp, :, :, :].copy()
if subtractions is None:
subtractions = [0 for i in range(ndim)]
else:
if len(subtractions) != self.ndim:
raise ValueError('subtractions must be of length ndim')
samples = chain.reshape((nwalkers*nsteps, ndim))
samples = self._scale_samples(samples, self.theta_keys)
chain = chain.reshape((nwalkers, nsteps, ndim))
if plot_det_stat:
extra_subplots = 1
......@@ -1017,23 +1018,24 @@ class MCMCSearch(core.BaseSearchClass):
cs = chain[:, :, i].T
if burnin_idx > 0:
axes[i].plot(xoffset+idxs[:last_idx+1],
cs[:last_idx+1]-subtractions[i],
cs[:last_idx+1],
color="C3", alpha=alpha,
lw=lw)
axes[i].axvline(xoffset+last_idx,
color='k', ls='--', lw=0.5)
axes[i].plot(xoffset+idxs[burnin_idx:],
cs[burnin_idx:]-subtractions[i],
cs[burnin_idx:],
color="k", alpha=alpha, lw=lw)
axes[i].set_xlim(0, xoffset+idxs[-1])
if symbols:
if subtractions[i] == 0:
axes[i].set_ylabel(symbols[i], labelpad=labelpad)
else:
axes[i].set_ylabel(
symbols[i]+'$-$'+symbols[i]+'$^\mathrm{s}$',
labelpad=labelpad)
axes[i].set_ylabel(symbols[i], labelpad=labelpad)
#if subtractions[i] == 0:
# axes[i].set_ylabel(symbols[i], labelpad=labelpad)
#else:
# axes[i].set_ylabel(
# symbols[i]+'$-$'+symbols[i]+'$^\mathrm{s}$',
# labelpad=labelpad)
# if hasattr(self, 'convergence_diagnostic'):
# ax = axes[i].twinx()
......@@ -2120,7 +2122,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
if create_plots:
fig, axes = self._plot_walkers(
sampler, symbols=self.theta_symbols, fig=fig, axes=axes,
sampler, fig=fig, axes=axes,
nprod=nprod, xoffset=nsteps_total, **kwargs)
for ax in axes[:self.ndim]:
ax.axvline(nsteps_total, color='k', ls='--', lw=0.25)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment