diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index 86e04b1be73a2f9dcbe31565c662dc0300248a69..912445e33f2e2846710b8edaf70df7490f262144 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -502,7 +502,6 @@ class MCMCSearch(core.BaseSearchClass):
             sampler = self._run_sampler(sampler, p0, nburn=n, window=window)
             if create_plots:
                 fig, axes = self._plot_walkers(sampler,
-                                               symbols=self.theta_symbols,
                                                **kwargs)
                 fig.tight_layout()
                 fig.savefig('{}/{}_init_{}_walkers.png'.format(
@@ -522,8 +521,7 @@ class MCMCSearch(core.BaseSearchClass):
             nburn+nprod))
         sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
         if create_plots:
-            fig, axes = self._plot_walkers(sampler, symbols=self.theta_symbols,
-                                           nprod=nprod, **kwargs)
+            fig, axes = self._plot_walkers(sampler, nprod=nprod, **kwargs)
             fig.tight_layout()
             fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label),
                         )
@@ -603,7 +601,7 @@ class MCMCSearch(core.BaseSearchClass):
 
         return samples
 
-    def _get_labels(self):
+    def _get_labels(self, newline_units=False):
         """ Combine the units, symbols and rescaling to give labels """
 
         labels = []
@@ -620,7 +618,10 @@ class MCMCSearch(core.BaseSearchClass):
                 if 'unit' in self.transform_dictionary[key]:
                     u = self.transform_dictionary[key]['unit']
             if label is None:
-                label = '{} \n [{}]'.format(s, u)
+                if newline_units:
+                    label = '{} \n [{}]'.format(s, u)
+                else:
+                    label = '{} [{}]'.format(s, u)
             labels.append(label)
         return labels
 
@@ -694,7 +695,7 @@ class MCMCSearch(core.BaseSearchClass):
                 fig, axes = fig_and_axes
 
             samples_plt = copy.copy(self.samples)
-            labels = self._get_labels()
+            labels = self._get_labels(newline_units=True)
 
             samples_plt = self._scale_samples(samples_plt, self.theta_keys)
 
@@ -963,9 +964,11 @@ class MCMCSearch(core.BaseSearchClass):
     def _plot_walkers(self, sampler, symbols=None, alpha=0.8, color="k",
                       temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False,
                       fig=None, axes=None, xoffset=0, plot_det_stat=False,
-                      context='ggplot', subtractions=None, labelpad=0.05):
+                      context='ggplot', labelpad=5):
         """ Plot all the chains from a sampler """
 
+        if symbols is None:
+            symbols = self._get_labels()
         if context not in plt.style.available:
             raise ValueError((
                 'The requested context {} is not available; please select a'
@@ -977,7 +980,7 @@ class MCMCSearch(core.BaseSearchClass):
         shape = sampler.chain.shape
         if len(shape) == 3:
             nwalkers, nsteps, ndim = shape
-            chain = sampler.chain[:, :, :]
+            chain = sampler.chain[:, :, :].copy()
         if len(shape) == 4:
             ntemps, nwalkers, nsteps, ndim = shape
             if temp < ntemps:
@@ -985,13 +988,11 @@ class MCMCSearch(core.BaseSearchClass):
             else:
                 raise ValueError(("Requested temperature {} outside of"
                                   "available range").format(temp))
-            chain = sampler.chain[temp, :, :, :]
+            chain = sampler.chain[temp, :, :, :].copy()
 
-        if subtractions is None:
-            subtractions = [0 for i in range(ndim)]
-        else:
-            if len(subtractions) != self.ndim:
-                raise ValueError('subtractions must be of length ndim')
+        samples = chain.reshape((nwalkers*nsteps, ndim))
+        samples = self._scale_samples(samples, self.theta_keys)
+        chain = chain.reshape((nwalkers, nsteps, ndim))
 
         if plot_det_stat:
             extra_subplots = 1
@@ -1017,23 +1018,24 @@ class MCMCSearch(core.BaseSearchClass):
                     cs = chain[:, :, i].T
                     if burnin_idx > 0:
                         axes[i].plot(xoffset+idxs[:last_idx+1],
-                                     cs[:last_idx+1]-subtractions[i],
+                                     cs[:last_idx+1],
                                      color="C3", alpha=alpha,
                                      lw=lw)
                         axes[i].axvline(xoffset+last_idx,
                                         color='k', ls='--', lw=0.5)
                     axes[i].plot(xoffset+idxs[burnin_idx:],
-                                 cs[burnin_idx:]-subtractions[i],
+                                 cs[burnin_idx:],
                                  color="k", alpha=alpha, lw=lw)
 
                     axes[i].set_xlim(0, xoffset+idxs[-1])
                     if symbols:
-                        if subtractions[i] == 0:
-                            axes[i].set_ylabel(symbols[i], labelpad=labelpad)
-                        else:
-                            axes[i].set_ylabel(
-                                symbols[i]+'$-$'+symbols[i]+'$^\mathrm{s}$',
-                                labelpad=labelpad)
+                        axes[i].set_ylabel(symbols[i], labelpad=labelpad)
+                        #if subtractions[i] == 0:
+                        #    axes[i].set_ylabel(symbols[i], labelpad=labelpad)
+                        #else:
+                        #    axes[i].set_ylabel(
+                        #        symbols[i]+'$-$'+symbols[i]+'$^\mathrm{s}$',
+                        #        labelpad=labelpad)
 
 #                    if hasattr(self, 'convergence_diagnostic'):
 #                        ax = axes[i].twinx()
@@ -2120,7 +2122,7 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
 
             if create_plots:
                 fig, axes = self._plot_walkers(
-                    sampler, symbols=self.theta_symbols, fig=fig, axes=axes,
+                    sampler, fig=fig, axes=axes,
                     nprod=nprod, xoffset=nsteps_total, **kwargs)
                 for ax in axes[:self.ndim]:
                     ax.axvline(nsteps_total, color='k', ls='--', lw=0.25)