Skip to content
Snippets Groups Projects
Commit 9c2a9c57 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add transform_dict methods to plot_walkers

parent ecf6c1c2
No related branches found
No related tags found
No related merge requests found
...@@ -502,7 +502,6 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -502,7 +502,6 @@ class MCMCSearch(core.BaseSearchClass):
sampler = self._run_sampler(sampler, p0, nburn=n, window=window) sampler = self._run_sampler(sampler, p0, nburn=n, window=window)
if create_plots: if create_plots:
fig, axes = self._plot_walkers(sampler, fig, axes = self._plot_walkers(sampler,
symbols=self.theta_symbols,
**kwargs) **kwargs)
fig.tight_layout() fig.tight_layout()
fig.savefig('{}/{}_init_{}_walkers.png'.format( fig.savefig('{}/{}_init_{}_walkers.png'.format(
...@@ -522,8 +521,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -522,8 +521,7 @@ class MCMCSearch(core.BaseSearchClass):
nburn+nprod)) nburn+nprod))
sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod) sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
if create_plots: if create_plots:
fig, axes = self._plot_walkers(sampler, symbols=self.theta_symbols, fig, axes = self._plot_walkers(sampler, nprod=nprod, **kwargs)
nprod=nprod, **kwargs)
fig.tight_layout() fig.tight_layout()
fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label), fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
) )
...@@ -603,7 +601,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -603,7 +601,7 @@ class MCMCSearch(core.BaseSearchClass):
return samples return samples
def _get_labels(self): def _get_labels(self, newline_units=False):
""" Combine the units, symbols and rescaling to give labels """ """ Combine the units, symbols and rescaling to give labels """
labels = [] labels = []
...@@ -620,7 +618,10 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -620,7 +618,10 @@ class MCMCSearch(core.BaseSearchClass):
if 'unit' in self.transform_dictionary[key]: if 'unit' in self.transform_dictionary[key]:
u = self.transform_dictionary[key]['unit'] u = self.transform_dictionary[key]['unit']
if label is None: if label is None:
if newline_units:
label = '{} \n [{}]'.format(s, u) label = '{} \n [{}]'.format(s, u)
else:
label = '{} [{}]'.format(s, u)
labels.append(label) labels.append(label)
return labels return labels
...@@ -694,7 +695,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -694,7 +695,7 @@ class MCMCSearch(core.BaseSearchClass):
fig, axes = fig_and_axes fig, axes = fig_and_axes
samples_plt = copy.copy(self.samples) samples_plt = copy.copy(self.samples)
labels = self._get_labels() labels = self._get_labels(newline_units=True)
samples_plt = self._scale_samples(samples_plt, self.theta_keys) samples_plt = self._scale_samples(samples_plt, self.theta_keys)
...@@ -963,9 +964,11 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -963,9 +964,11 @@ class MCMCSearch(core.BaseSearchClass):
def _plot_walkers(self, sampler, symbols=None, alpha=0.8, color="k", def _plot_walkers(self, sampler, symbols=None, alpha=0.8, color="k",
temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False, temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False,
fig=None, axes=None, xoffset=0, plot_det_stat=False, fig=None, axes=None, xoffset=0, plot_det_stat=False,
context='ggplot', subtractions=None, labelpad=0.05): context='ggplot', labelpad=5):
""" Plot all the chains from a sampler """ """ Plot all the chains from a sampler """
if symbols is None:
symbols = self._get_labels()
if context not in plt.style.available: if context not in plt.style.available:
raise ValueError(( raise ValueError((
'The requested context {} is not available; please select a' 'The requested context {} is not available; please select a'
...@@ -977,7 +980,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -977,7 +980,7 @@ class MCMCSearch(core.BaseSearchClass):
shape = sampler.chain.shape shape = sampler.chain.shape
if len(shape) == 3: if len(shape) == 3:
nwalkers, nsteps, ndim = shape nwalkers, nsteps, ndim = shape
chain = sampler.chain[:, :, :] chain = sampler.chain[:, :, :].copy()
if len(shape) == 4: if len(shape) == 4:
ntemps, nwalkers, nsteps, ndim = shape ntemps, nwalkers, nsteps, ndim = shape
if temp < ntemps: if temp < ntemps:
...@@ -985,13 +988,11 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -985,13 +988,11 @@ class MCMCSearch(core.BaseSearchClass):
else: else:
raise ValueError(("Requested temperature {} outside of" raise ValueError(("Requested temperature {} outside of"
"available range").format(temp)) "available range").format(temp))
chain = sampler.chain[temp, :, :, :] chain = sampler.chain[temp, :, :, :].copy()
if subtractions is None: samples = chain.reshape((nwalkers*nsteps, ndim))
subtractions = [0 for i in range(ndim)] samples = self._scale_samples(samples, self.theta_keys)
else: chain = chain.reshape((nwalkers, nsteps, ndim))
if len(subtractions) != self.ndim:
raise ValueError('subtractions must be of length ndim')
if plot_det_stat: if plot_det_stat:
extra_subplots = 1 extra_subplots = 1
...@@ -1017,23 +1018,24 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -1017,23 +1018,24 @@ class MCMCSearch(core.BaseSearchClass):
cs = chain[:, :, i].T cs = chain[:, :, i].T
if burnin_idx > 0: if burnin_idx > 0:
axes[i].plot(xoffset+idxs[:last_idx+1], axes[i].plot(xoffset+idxs[:last_idx+1],
cs[:last_idx+1]-subtractions[i], cs[:last_idx+1],
color="C3", alpha=alpha, color="C3", alpha=alpha,
lw=lw) lw=lw)
axes[i].axvline(xoffset+last_idx, axes[i].axvline(xoffset+last_idx,
color='k', ls='--', lw=0.5) color='k', ls='--', lw=0.5)
axes[i].plot(xoffset+idxs[burnin_idx:], axes[i].plot(xoffset+idxs[burnin_idx:],
cs[burnin_idx:]-subtractions[i], cs[burnin_idx:],
color="k", alpha=alpha, lw=lw) color="k", alpha=alpha, lw=lw)
axes[i].set_xlim(0, xoffset+idxs[-1]) axes[i].set_xlim(0, xoffset+idxs[-1])
if symbols: if symbols:
if subtractions[i] == 0:
axes[i].set_ylabel(symbols[i], labelpad=labelpad) axes[i].set_ylabel(symbols[i], labelpad=labelpad)
else: #if subtractions[i] == 0:
axes[i].set_ylabel( # axes[i].set_ylabel(symbols[i], labelpad=labelpad)
symbols[i]+'$-$'+symbols[i]+'$^\mathrm{s}$', #else:
labelpad=labelpad) # axes[i].set_ylabel(
# symbols[i]+'$-$'+symbols[i]+'$^\mathrm{s}$',
# labelpad=labelpad)
# if hasattr(self, 'convergence_diagnostic'): # if hasattr(self, 'convergence_diagnostic'):
# ax = axes[i].twinx() # ax = axes[i].twinx()
...@@ -2120,7 +2122,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): ...@@ -2120,7 +2122,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
if create_plots: if create_plots:
fig, axes = self._plot_walkers( fig, axes = self._plot_walkers(
sampler, symbols=self.theta_symbols, fig=fig, axes=axes, sampler, fig=fig, axes=axes,
nprod=nprod, xoffset=nsteps_total, **kwargs) nprod=nprod, xoffset=nsteps_total, **kwargs)
for ax in axes[:self.ndim]: for ax in axes[:self.ndim]:
ax.axvline(nsteps_total, color='k', ls='--', lw=0.25) ax.axvline(nsteps_total, color='k', ls='--', lw=0.25)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment