diff --git a/pyfstat.py b/pyfstat.py index 4509aa1c613cbaa5e8824f7a2f1b08d4f192768a..bfec8ba06c38804a8e32bee85b26620a0f3d1b23 100755 --- a/pyfstat.py +++ b/pyfstat.py @@ -643,64 +643,66 @@ class MCMCSearch(BaseSearchClass): self.save_data(sampler, samples, lnprobs, lnlikes) def plot_corner(self, figsize=(7, 7), tglitch_ratio=False, - add_prior=False, nstds=None, label_offset=0.4, **kwargs): - - fig, axes = plt.subplots(self.ndim, self.ndim, - figsize=figsize) - - samples_plt = copy.copy(self.samples) - theta_symbols_plt = copy.copy(self.theta_symbols) - theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}') for s - in theta_symbols_plt] - - if tglitch_ratio: - for j, k in enumerate(self.theta_keys): - if k == 'tglitch': - s = samples_plt[:, j] - samples_plt[:, j] = (s - self.tstart)/( - self.tend - self.tstart) - theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$' - - if type(nstds) is int and 'range' not in kwargs: - _range = [] - for j, s in enumerate(samples_plt.T): - median = np.median(s) - std = np.std(s) - _range.append((median - nstds*std, median + nstds*std)) - else: - _range = None - - fig_triangle = corner.corner(samples_plt, - labels=theta_symbols_plt, - fig=fig, - bins=50, - max_n_ticks=4, - plot_contours=True, - plot_datapoints=True, - label_kwargs={'fontsize': 8}, - data_kwargs={'alpha': 0.1, - 'ms': 0.5}, - range=_range, - **kwargs) - - axes_list = fig_triangle.get_axes() - axes = np.array(axes_list).reshape(self.ndim, self.ndim) - plt.draw() - for ax in axes[:, 0]: - ax.yaxis.set_label_coords(-label_offset, 0.5) - for ax in axes[-1, :]: - ax.xaxis.set_label_coords(0.5, -label_offset) - for ax in axes_list: - ax.set_rasterized(True) - ax.set_rasterization_zorder(-10) - plt.tight_layout(h_pad=0.0, w_pad=0.0) - fig.subplots_adjust(hspace=0.05, wspace=0.05) - - if add_prior: - self.add_prior_to_corner(axes, samples_plt) - - fig_triangle.savefig('{}/{}_corner.png'.format( - self.outdir, self.label)) + add_prior=False, nstds=None, label_offset=0.4, + dpi=300, rc_context={}, **kwargs): + + with plt.rc_context(rc_context): + fig, axes = plt.subplots(self.ndim, self.ndim, + figsize=figsize) + + samples_plt = copy.copy(self.samples) + theta_symbols_plt = copy.copy(self.theta_symbols) + theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}') for s + in theta_symbols_plt] + + if tglitch_ratio: + for j, k in enumerate(self.theta_keys): + if k == 'tglitch': + s = samples_plt[:, j] + samples_plt[:, j] = (s - self.tstart)/( + self.tend - self.tstart) + theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$' + + if type(nstds) is int and 'range' not in kwargs: + _range = [] + for j, s in enumerate(samples_plt.T): + median = np.median(s) + std = np.std(s) + _range.append((median - nstds*std, median + nstds*std)) + else: + _range = None + + fig_triangle = corner.corner(samples_plt, + labels=theta_symbols_plt, + fig=fig, + bins=50, + max_n_ticks=4, + plot_contours=True, + plot_datapoints=True, + label_kwargs={'fontsize': 8}, + data_kwargs={'alpha': 0.1, + 'ms': 0.5}, + range=_range, + **kwargs) + + axes_list = fig_triangle.get_axes() + axes = np.array(axes_list).reshape(self.ndim, self.ndim) + plt.draw() + for ax in axes[:, 0]: + ax.yaxis.set_label_coords(-label_offset, 0.5) + for ax in axes[-1, :]: + ax.xaxis.set_label_coords(0.5, -label_offset) + for ax in axes_list: + ax.set_rasterized(True) + ax.set_rasterization_zorder(-10) + plt.tight_layout(h_pad=0.0, w_pad=0.0) + fig.subplots_adjust(hspace=0.05, wspace=0.05) + + if add_prior: + self.add_prior_to_corner(axes, samples_plt) + + fig_triangle.savefig('{}/{}_corner.png'.format( + self.outdir, self.label), dpi=dpi) def add_prior_to_corner(self, axes, samples): for i, key in enumerate(self.theta_keys): @@ -757,6 +759,8 @@ class MCMCSearch(BaseSearchClass): return lambda x: logunif(x, kwargs['lower'], kwargs['upper']) elif kwargs['type'] == 'halfnorm': return lambda x: halfnorm(x, kwargs['loc'], kwargs['scale']) + elif kwargs['type'] == 'neghalfnorm': + return lambda x: halfnorm(-x, kwargs['loc'], kwargs['scale']) elif kwargs['type'] == 'norm': return lambda x: -0.5*((x - kwargs['loc'])**2/kwargs['scale']**2 + np.log(2*np.pi*kwargs['scale']**2)) @@ -773,6 +777,9 @@ class MCMCSearch(BaseSearchClass): if dist_type == "halfnorm": return np.abs(np.random.normal(loc=kwargs['loc'], scale=kwargs['scale'])) + if dist_type == "neghalfnorm": + return -1 * np.abs(np.random.normal(loc=kwargs['loc'], + scale=kwargs['scale'])) if dist_type == "lognorm": return np.random.lognormal( mean=kwargs['loc'], sigma=kwargs['scale'])