diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index 86e04b1be73a2f9dcbe31565c662dc0300248a69..912445e33f2e2846710b8edaf70df7490f262144 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -502,7 +502,6 @@ class MCMCSearch(core.BaseSearchClass): sampler = self._run_sampler(sampler, p0, nburn=n, window=window) if create_plots: fig, axes = self._plot_walkers(sampler, - symbols=self.theta_symbols, **kwargs) fig.tight_layout() fig.savefig('{}/{}_init_{}_walkers.png'.format( @@ -522,8 +521,7 @@ class MCMCSearch(core.BaseSearchClass): nburn+nprod)) sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod) if create_plots: - fig, axes = self._plot_walkers(sampler, symbols=self.theta_symbols, - nprod=nprod, **kwargs) + fig, axes = self._plot_walkers(sampler, nprod=nprod, **kwargs) fig.tight_layout() fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label), ) @@ -603,7 +601,7 @@ class MCMCSearch(core.BaseSearchClass): return samples - def _get_labels(self): + def _get_labels(self, newline_units=False): """ Combine the units, symbols and rescaling to give labels """ labels = [] @@ -620,7 +618,10 @@ class MCMCSearch(core.BaseSearchClass): if 'unit' in self.transform_dictionary[key]: u = self.transform_dictionary[key]['unit'] if label is None: - label = '{} \n [{}]'.format(s, u) + if newline_units: + label = '{} \n [{}]'.format(s, u) + else: + label = '{} [{}]'.format(s, u) labels.append(label) return labels @@ -694,7 +695,7 @@ class MCMCSearch(core.BaseSearchClass): fig, axes = fig_and_axes samples_plt = copy.copy(self.samples) - labels = self._get_labels() + labels = self._get_labels(newline_units=True) samples_plt = self._scale_samples(samples_plt, self.theta_keys) @@ -963,9 +964,11 @@ class MCMCSearch(core.BaseSearchClass): def _plot_walkers(self, sampler, symbols=None, alpha=0.8, color="k", temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False, fig=None, axes=None, xoffset=0, plot_det_stat=False, - context='ggplot', subtractions=None, labelpad=0.05): + context='ggplot', labelpad=5): """ Plot all the chains from a sampler """ + if symbols is None: + symbols = self._get_labels() if context not in plt.style.available: raise ValueError(( 'The requested context {} is not available; please select a' @@ -977,7 +980,7 @@ class MCMCSearch(core.BaseSearchClass): shape = sampler.chain.shape if len(shape) == 3: nwalkers, nsteps, ndim = shape - chain = sampler.chain[:, :, :] + chain = sampler.chain[:, :, :].copy() if len(shape) == 4: ntemps, nwalkers, nsteps, ndim = shape if temp < ntemps: @@ -985,13 +988,11 @@ class MCMCSearch(core.BaseSearchClass): else: raise ValueError(("Requested temperature {} outside of" "available range").format(temp)) - chain = sampler.chain[temp, :, :, :] + chain = sampler.chain[temp, :, :, :].copy() - if subtractions is None: - subtractions = [0 for i in range(ndim)] - else: - if len(subtractions) != self.ndim: - raise ValueError('subtractions must be of length ndim') + samples = chain.reshape((nwalkers*nsteps, ndim)) + samples = self._scale_samples(samples, self.theta_keys) + chain = chain.reshape((nwalkers, nsteps, ndim)) if plot_det_stat: extra_subplots = 1 @@ -1017,23 +1018,24 @@ class MCMCSearch(core.BaseSearchClass): cs = chain[:, :, i].T if burnin_idx > 0: axes[i].plot(xoffset+idxs[:last_idx+1], - cs[:last_idx+1]-subtractions[i], + cs[:last_idx+1], color="C3", alpha=alpha, lw=lw) axes[i].axvline(xoffset+last_idx, color='k', ls='--', lw=0.5) axes[i].plot(xoffset+idxs[burnin_idx:], - cs[burnin_idx:]-subtractions[i], + cs[burnin_idx:], color="k", alpha=alpha, lw=lw) axes[i].set_xlim(0, xoffset+idxs[-1]) if symbols: - if subtractions[i] == 0: - axes[i].set_ylabel(symbols[i], labelpad=labelpad) - else: - axes[i].set_ylabel( - symbols[i]+'$-$'+symbols[i]+'$^\mathrm{s}$', - labelpad=labelpad) + axes[i].set_ylabel(symbols[i], labelpad=labelpad) + #if subtractions[i] == 0: + # axes[i].set_ylabel(symbols[i], labelpad=labelpad) + #else: + # axes[i].set_ylabel( + # symbols[i]+'$-$'+symbols[i]+'$^\mathrm{s}$', + # labelpad=labelpad) # if hasattr(self, 'convergence_diagnostic'): # ax = axes[i].twinx() @@ -2120,7 +2122,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch): if create_plots: fig, axes = self._plot_walkers( - sampler, symbols=self.theta_symbols, fig=fig, axes=axes, + sampler, fig=fig, axes=axes, nprod=nprod, xoffset=nsteps_total, **kwargs) for ax in axes[:self.ndim]: ax.axvline(nsteps_total, color='k', ls='--', lw=0.25)