Commit 1aa85fa4 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improvements to the grid search and read params

1) Adds check if the cached data will be used to avoid loading the
search in the grid search
2) Adds better handling of errors in the par read
3) Adds ability to add 2d plots to an existing figure
parent e4a38195
...@@ -85,8 +85,10 @@ def read_par(label, outdir): ...@@ -85,8 +85,10 @@ def read_par(label, outdir):
d = {} d = {}
with open(filename, 'r') as f: with open(filename, 'r') as f:
for line in f: for line in f:
key, val = line.rstrip('\n').split(' = ') if len(line.split('=')) > 1:
d[key] = np.float64(val) key, val = line.rstrip('\n').split(' = ')
key = key.strip()
d[key] = np.float64(val.rstrip('; '))
return d return d
...@@ -1341,8 +1343,8 @@ class GridSearch(BaseSearchClass): ...@@ -1341,8 +1343,8 @@ class GridSearch(BaseSearchClass):
""" """
minStartTime = tstart self.minStartTime = tstart
maxStartTime = tend self.maxStartTime = tend
if sftlabel is None: if sftlabel is None:
self.sftlabel = self.label self.sftlabel = self.label
...@@ -1353,18 +1355,20 @@ class GridSearch(BaseSearchClass): ...@@ -1353,18 +1355,20 @@ class GridSearch(BaseSearchClass):
if sun_ephem is None: if sun_ephem is None:
self.sun_ephem = self.sun_ephem_default self.sun_ephem = self.sun_ephem_default
if os.path.isdir(outdir) is False:
os.mkdir(outdir)
self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label)
self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
def inititate_search_object(self):
logging.info('Setting up search object')
self.search = ComputeFstat( self.search = ComputeFstat(
tref=self.tref, sftlabel=self.sftlabel, tref=self.tref, sftlabel=self.sftlabel,
sftdir=self.sftdir, minCoverFreq=self.minCoverFreq, sftdir=self.sftdir, minCoverFreq=self.minCoverFreq,
maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem, maxCoverFreq=self.maxCoverFreq, earth_ephem=self.earth_ephem,
sun_ephem=self.sun_ephem, detector=self.detector, transient=False, sun_ephem=self.sun_ephem, detector=self.detector, transient=False,
minStartTime=minStartTime, maxStartTime=maxStartTime, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
BSGL=BSGL) BSGL=self.BSGL)
if os.path.isdir(outdir) is False:
os.mkdir(outdir)
self.out_file = '{}/{}_gridFS.txt'.format(self.outdir, self.label)
self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
def get_array_from_tuple(self, x): def get_array_from_tuple(self, x):
if len(x) == 1: if len(x) == 1:
...@@ -1406,6 +1410,8 @@ class GridSearch(BaseSearchClass): ...@@ -1406,6 +1410,8 @@ class GridSearch(BaseSearchClass):
self.data = old_data self.data = old_data
return return
self.inititate_search_object()
logging.info('Total number of grid points is {}'.format( logging.info('Total number of grid points is {}'.format(
len(self.input_data))) len(self.input_data)))
...@@ -1430,26 +1436,30 @@ class GridSearch(BaseSearchClass): ...@@ -1430,26 +1436,30 @@ class GridSearch(BaseSearchClass):
plt.plot(x, z) plt.plot(x, z)
fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label)) fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
def plot_2D(self, xkey, ykey): def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None):
fig, ax = plt.subplots() if ax is None:
fig, ax = plt.subplots()
xidx = self.keys.index(xkey) xidx = self.keys.index(xkey)
yidx = self.keys.index(ykey) yidx = self.keys.index(ykey)
x = np.unique(self.data[:, xidx]) x = np.unique(self.data[:, xidx])
y = np.unique(self.data[:, yidx]) y = np.unique(self.data[:, yidx])
z = self.data[:, -1] z = self.data[:, -1]
X, Y = np.meshgrid(x, y) Y, X = np.meshgrid(y, x)
Z = z.reshape(X.shape) Z = z.reshape(X.shape)
pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis) pax = ax.pcolormesh(X, Y, Z, cmap=plt.cm.viridis, vmin=vmin, vmax=vmax)
fig.colorbar(pax) plt.colorbar(pax, ax=ax)
ax.set_xlim(x[0], x[-1]) ax.set_xlim(x[0], x[-1])
ax.set_ylim(y[0], y[-1]) ax.set_ylim(y[0], y[-1])
ax.set_xlabel(xkey) ax.set_xlabel(xkey)
ax.set_ylabel(ykey) ax.set_ylabel(ykey)
fig.tight_layout() if save:
fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label)) fig.tight_layout()
fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label))
else:
return ax
def get_max_twoF(self): def get_max_twoF(self):
twoF = self.data[:, -1] twoF = self.data[:, -1]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment