Commit 38fbf98e authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add context to plot_corner to allow the user to customise the plot

parent b26d1c74
......@@ -643,64 +643,66 @@ class MCMCSearch(BaseSearchClass):
self.save_data(sampler, samples, lnprobs, lnlikes)
def plot_corner(self, figsize=(7, 7), tglitch_ratio=False,
add_prior=False, nstds=None, label_offset=0.4, **kwargs):
fig, axes = plt.subplots(self.ndim, self.ndim,
figsize=figsize)
samples_plt = copy.copy(self.samples)
theta_symbols_plt = copy.copy(self.theta_symbols)
theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}') for s
in theta_symbols_plt]
if tglitch_ratio:
for j, k in enumerate(self.theta_keys):
if k == 'tglitch':
s = samples_plt[:, j]
samples_plt[:, j] = (s - self.tstart)/(
self.tend - self.tstart)
theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$'
if type(nstds) is int and 'range' not in kwargs:
_range = []
for j, s in enumerate(samples_plt.T):
median = np.median(s)
std = np.std(s)
_range.append((median - nstds*std, median + nstds*std))
else:
_range = None
fig_triangle = corner.corner(samples_plt,
labels=theta_symbols_plt,
fig=fig,
bins=50,
max_n_ticks=4,
plot_contours=True,
plot_datapoints=True,
label_kwargs={'fontsize': 8},
data_kwargs={'alpha': 0.1,
'ms': 0.5},
range=_range,
**kwargs)
axes_list = fig_triangle.get_axes()
axes = np.array(axes_list).reshape(self.ndim, self.ndim)
plt.draw()
for ax in axes[:, 0]:
ax.yaxis.set_label_coords(-label_offset, 0.5)
for ax in axes[-1, :]:
ax.xaxis.set_label_coords(0.5, -label_offset)
for ax in axes_list:
ax.set_rasterized(True)
ax.set_rasterization_zorder(-10)
plt.tight_layout(h_pad=0.0, w_pad=0.0)
fig.subplots_adjust(hspace=0.05, wspace=0.05)
if add_prior:
self.add_prior_to_corner(axes, samples_plt)
fig_triangle.savefig('{}/{}_corner.png'.format(
self.outdir, self.label))
add_prior=False, nstds=None, label_offset=0.4,
dpi=300, rc_context={}, **kwargs):
with plt.rc_context(rc_context):
fig, axes = plt.subplots(self.ndim, self.ndim,
figsize=figsize)
samples_plt = copy.copy(self.samples)
theta_symbols_plt = copy.copy(self.theta_symbols)
theta_symbols_plt = [s.replace('_{glitch}', r'_\textrm{glitch}') for s
in theta_symbols_plt]
if tglitch_ratio:
for j, k in enumerate(self.theta_keys):
if k == 'tglitch':
s = samples_plt[:, j]
samples_plt[:, j] = (s - self.tstart)/(
self.tend - self.tstart)
theta_symbols_plt[j] = r'$R_{\textrm{glitch}}$'
if type(nstds) is int and 'range' not in kwargs:
_range = []
for j, s in enumerate(samples_plt.T):
median = np.median(s)
std = np.std(s)
_range.append((median - nstds*std, median + nstds*std))
else:
_range = None
fig_triangle = corner.corner(samples_plt,
labels=theta_symbols_plt,
fig=fig,
bins=50,
max_n_ticks=4,
plot_contours=True,
plot_datapoints=True,
label_kwargs={'fontsize': 8},
data_kwargs={'alpha': 0.1,
'ms': 0.5},
range=_range,
**kwargs)
axes_list = fig_triangle.get_axes()
axes = np.array(axes_list).reshape(self.ndim, self.ndim)
plt.draw()
for ax in axes[:, 0]:
ax.yaxis.set_label_coords(-label_offset, 0.5)
for ax in axes[-1, :]:
ax.xaxis.set_label_coords(0.5, -label_offset)
for ax in axes_list:
ax.set_rasterized(True)
ax.set_rasterization_zorder(-10)
plt.tight_layout(h_pad=0.0, w_pad=0.0)
fig.subplots_adjust(hspace=0.05, wspace=0.05)
if add_prior:
self.add_prior_to_corner(axes, samples_plt)
fig_triangle.savefig('{}/{}_corner.png'.format(
self.outdir, self.label), dpi=dpi)
def add_prior_to_corner(self, axes, samples):
for i, key in enumerate(self.theta_keys):
......@@ -757,6 +759,8 @@ class MCMCSearch(BaseSearchClass):
return lambda x: logunif(x, kwargs['lower'], kwargs['upper'])
elif kwargs['type'] == 'halfnorm':
return lambda x: halfnorm(x, kwargs['loc'], kwargs['scale'])
elif kwargs['type'] == 'neghalfnorm':
return lambda x: halfnorm(-x, kwargs['loc'], kwargs['scale'])
elif kwargs['type'] == 'norm':
return lambda x: -0.5*((x - kwargs['loc'])**2/kwargs['scale']**2
+ np.log(2*np.pi*kwargs['scale']**2))
......@@ -773,6 +777,9 @@ class MCMCSearch(BaseSearchClass):
if dist_type == "halfnorm":
return np.abs(np.random.normal(loc=kwargs['loc'],
scale=kwargs['scale']))
if dist_type == "neghalfnorm":
return -1 * np.abs(np.random.normal(loc=kwargs['loc'],
scale=kwargs['scale']))
if dist_type == "lognorm":
return np.random.lognormal(
mean=kwargs['loc'], sigma=kwargs['scale'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment