Commit 7fae97b6 authored by David Keitel's avatar David Keitel
Browse files

make transient-fstatmap function user-configurable

 -to later allow for CUDA implementation
 -some optional-import acrobatics
  to fail gracefully if a module (especially CUDA)
  is not available
 -as of this commit, nothing beyond lal implemented yet
parent 793e1f4e
...@@ -50,7 +50,8 @@ search2 = pyfstat.TransientGridSearch( ...@@ -50,7 +50,8 @@ search2 = pyfstat.TransientGridSearch(
minStartTime=minStartTime, maxStartTime=maxStartTime, minStartTime=minStartTime, maxStartTime=maxStartTime,
transientWindowType='rect', t0Band=Tspan-2*Tsft, tauBand=Tspan, transientWindowType='rect', t0Band=Tspan-2*Tsft, tauBand=Tspan,
BSGL=False, BSGL=False,
outputTransientFstatMap=True) outputTransientFstatMap=True,
tCWFstatMapVersion='lal')
search2.run() search2.run()
search2.print_max_twoF() search2.print_max_twoF()
......
...@@ -13,6 +13,7 @@ import scipy.optimize ...@@ -13,6 +13,7 @@ import scipy.optimize
import lal import lal
import lalpulsar import lalpulsar
import pyfstat.helper_functions as helper_functions import pyfstat.helper_functions as helper_functions
import pyfstat.tcw_fstat_map_funcs as tcw
# workaround for matplotlib on X-less remote logins # workaround for matplotlib on X-less remote logins
if 'DISPLAY' in os.environ: if 'DISPLAY' in os.environ:
...@@ -335,7 +336,8 @@ class ComputeFstat(BaseSearchClass): ...@@ -335,7 +336,8 @@ class ComputeFstat(BaseSearchClass):
dt0=None, dtau=None, dt0=None, dtau=None,
detectors=None, minCoverFreq=None, maxCoverFreq=None, detectors=None, minCoverFreq=None, maxCoverFreq=None,
injectSources=None, injectSqrtSX=None, assumeSqrtSX=None, injectSources=None, injectSqrtSX=None, assumeSqrtSX=None,
SSBprec=None): SSBprec=None,
tCWFstatMapVersion='lal'):
""" """
Parameters Parameters
---------- ----------
...@@ -383,6 +385,9 @@ class ComputeFstat(BaseSearchClass): ...@@ -383,6 +385,9 @@ class ComputeFstat(BaseSearchClass):
SSBprec : int SSBprec : int
Flag to set the SSB calculation: 0=Newtonian, 1=relativistic, Flag to set the SSB calculation: 0=Newtonian, 1=relativistic,
2=relativisitic optimised, 3=DMoff, 4=NO_SPIN 2=relativisitic optimised, 3=DMoff, 4=NO_SPIN
tCWFstatMapVersion: str
Choose between standard 'lal' implementation,
'pycuda' for gpu, and some others for devel/debug.
""" """
...@@ -653,6 +658,8 @@ class ComputeFstat(BaseSearchClass): ...@@ -653,6 +658,8 @@ class ComputeFstat(BaseSearchClass):
if self.dtau: if self.dtau:
self.windowRange.dtau = self.dtau self.windowRange.dtau = self.dtau
self.tCWFstatMapFeatures = tcw.init_transient_fstat_map_features()
def get_fullycoherent_twoF(self, tstart, tend, F0, F1, F2, Alpha, Delta, def get_fullycoherent_twoF(self, tstart, tend, F0, F1, F2, Alpha, Delta,
asini=None, period=None, ecc=None, tp=None, asini=None, period=None, ecc=None, tp=None,
argp=None): argp=None):
...@@ -695,9 +702,21 @@ class ComputeFstat(BaseSearchClass): ...@@ -695,9 +702,21 @@ class ComputeFstat(BaseSearchClass):
# F-stat computation # F-stat computation
self.windowRange.tau = int(2*self.Tsft) self.windowRange.tau = int(2*self.Tsft)
self.FstatMap = lalpulsar.ComputeTransientFstatMap( #logging.debug('Calling "%s" version of ComputeTransientFstatMap() with windowRange: (type=%d (%s), t0=%f, t0Band=%f, dt0=%f, tau=%f, tauBand=%f, dtau=%f)...' % (self.tCWFstatMapVersion, self.windowRange.type, self.transientWindowType, self.windowRange.t0, self.windowRange.t0Band, self.windowRange.dt0, self.windowRange.tau, self.windowRange.tauBand, self.windowRange.dtau))
self.FstatResults.multiFatoms[0], self.windowRange, False) self.FstatMap = tcw.call_compute_transient_fstat_map( self.tCWFstatMapVersion,
F_mn = self.FstatMap.F_mn.data self.tCWFstatMapFeatures,
self.FstatResults.multiFatoms[0],
self.windowRange
)
if self.tCWFstatMapVersion == 'lal':
F_mn = self.FstatMap.F_mn.data
else:
F_mn = self.FstatMap.F_mn
#logging.debug('maxF: {}'.format(FstatMap.maxF))
#logging.debug('t0_ML: %ds=T0+%fd' % (FstatMap.t0_ML, (FstatMap.t0_ML-tstart)/(3600.*24.)))
#logging.debug('tau_ML: %ds=%fd' % (FstatMap.tau_ML, FstatMap.tau_ML/(3600.*24.)))
#logging.debug('F_mn: {}'.format(F_mn))
twoF = 2*np.max(F_mn) twoF = 2*np.max(F_mn)
if self.BSGL is False: if self.BSGL is False:
...@@ -950,6 +969,7 @@ class SemiCoherentSearch(ComputeFstat): ...@@ -950,6 +969,7 @@ class SemiCoherentSearch(ComputeFstat):
self.transientWindowType = 'rect' self.transientWindowType = 'rect'
self.t0Band = None self.t0Band = None
self.tauBand = None self.tauBand = None
self.tCWFstatMapVersion = 'lal'
self.init_computefstatistic_single_point() self.init_computefstatistic_single_point()
self.init_semicoherent_parameters() self.init_semicoherent_parameters()
...@@ -1089,6 +1109,7 @@ class SemiCoherentGlitchSearch(ComputeFstat): ...@@ -1089,6 +1109,7 @@ class SemiCoherentGlitchSearch(ComputeFstat):
self.transientWindowType = 'rect' self.transientWindowType = 'rect'
self.t0Band = None self.t0Band = None
self.tauBand = None self.tauBand = None
self.tCWFstatMapVersion = 'lal'
self.binary = False self.binary = False
self.init_computefstatistic_single_point() self.init_computefstatistic_single_point()
......
...@@ -355,7 +355,8 @@ class TransientGridSearch(GridSearch): ...@@ -355,7 +355,8 @@ class TransientGridSearch(GridSearch):
transientWindowType=None, t0Band=None, tauBand=None, transientWindowType=None, t0Band=None, tauBand=None,
dt0=None, dtau=None, dt0=None, dtau=None,
outputTransientFstatMap=False, outputTransientFstatMap=False,
outputAtoms=False): outputAtoms=False,
tCWFstatMapVersion='lal'):
""" """
Parameters Parameters
---------- ----------
...@@ -388,6 +389,9 @@ class TransientGridSearch(GridSearch): ...@@ -388,6 +389,9 @@ class TransientGridSearch(GridSearch):
outputTransientFstatMap: bool outputTransientFstatMap: bool
if true, write output files for (t0,tau) Fstat maps if true, write output files for (t0,tau) Fstat maps
(one file for each doppler grid point!) (one file for each doppler grid point!)
tCWFstatMapVersion: str
Choose between standard 'lal' implementation,
'pycuda' for gpu, and some others for devel/debug.
For all other parameters, see `pyfstat.ComputeFStat` for details For all other parameters, see `pyfstat.ComputeFStat` for details
""" """
...@@ -413,7 +417,8 @@ class TransientGridSearch(GridSearch): ...@@ -413,7 +417,8 @@ class TransientGridSearch(GridSearch):
minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
BSGL=self.BSGL, SSBprec=self.SSBprec, BSGL=self.BSGL, SSBprec=self.SSBprec,
injectSources=self.injectSources, injectSources=self.injectSources,
assumeSqrtSX=self.assumeSqrtSX) assumeSqrtSX=self.assumeSqrtSX,
tCWFstatMapVersion=self.tCWFstatMapVersion)
self.search.get_det_stat = self.search.get_fullycoherent_twoF self.search.get_det_stat = self.search.get_fullycoherent_twoF
def run(self, return_data=False): def run(self, return_data=False):
...@@ -435,9 +440,12 @@ class TransientGridSearch(GridSearch): ...@@ -435,9 +440,12 @@ class TransientGridSearch(GridSearch):
if getattr(self, 'transientWindowType', None): if getattr(self, 'transientWindowType', None):
if self.outputTransientFstatMap: if self.outputTransientFstatMap:
tCWfile = os.path.splitext(self.out_file)[0]+'_tCW_%.16f_%.16f_%.16f_%.16g_%.16g.dat' % (vals[2],vals[5],vals[6],vals[3],vals[4]) # freq alpha delta f1dot f2dot tCWfile = os.path.splitext(self.out_file)[0]+'_tCW_%.16f_%.16f_%.16f_%.16g_%.16g.dat' % (vals[2],vals[5],vals[6],vals[3],vals[4]) # freq alpha delta f1dot f2dot
fo = lal.FileOpen(tCWfile, 'w') if self.tCWFstatMapVersion == 'lal':
lalpulsar.write_transientFstatMap_to_fp ( fo, FstatMap, windowRange, None ) fo = lal.FileOpen(tCWfile, 'w')
del fo # instead of lal.FileClose() which is not SWIG-exported lalpulsar.write_transientFstatMap_to_fp ( fo, FstatMap, windowRange, None )
del fo # instead of lal.FileClose() which is not SWIG-exported
else:
np.savetxt(tCWfile, 2.0*FstatMap.F_mn, delimiter=' ')
Fmn = FstatMap.F_mn.data Fmn = FstatMap.F_mn.data
maxidx = np.unravel_index(Fmn.argmax(), Fmn.shape) maxidx = np.unravel_index(Fmn.argmax(), Fmn.shape)
thisCand += [windowRange.t0+maxidx[0]*windowRange.dt0, thisCand += [windowRange.t0+maxidx[0]*windowRange.dt0,
......
...@@ -82,6 +82,9 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -82,6 +82,9 @@ class MCMCSearch(core.BaseSearchClass):
('none' instead of None explicitly calls the transient-window function, ('none' instead of None explicitly calls the transient-window function,
but with the full range, for debugging) but with the full range, for debugging)
Currently only supported for nsegs=1. Currently only supported for nsegs=1.
tCWFstatMapVersion: str
Choose between standard 'lal' implementation,
'pycuda' for gpu, and some others for devel/debug.
Attributes Attributes
---------- ----------
...@@ -115,7 +118,7 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -115,7 +118,7 @@ class MCMCSearch(core.BaseSearchClass):
rhohatmax=1000, binary=False, BSGL=False, rhohatmax=1000, binary=False, BSGL=False,
SSBprec=None, minCoverFreq=None, maxCoverFreq=None, SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
injectSources=None, assumeSqrtSX=None, injectSources=None, assumeSqrtSX=None,
transientWindowType=None): transientWindowType=None, tCWFstatMapVersion='lal'):
if os.path.isdir(outdir) is False: if os.path.isdir(outdir) is False:
os.mkdir(outdir) os.mkdir(outdir)
...@@ -161,7 +164,8 @@ class MCMCSearch(core.BaseSearchClass): ...@@ -161,7 +164,8 @@ class MCMCSearch(core.BaseSearchClass):
transientWindowType=self.transientWindowType, transientWindowType=self.transientWindowType,
minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
binary=self.binary, injectSources=self.injectSources, binary=self.binary, injectSources=self.injectSources,
assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec) assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec,
tCWFstatMapVersion=self.tCWFstatMapVersion)
if self.minStartTime is None: if self.minStartTime is None:
self.minStartTime = self.search.minStartTime self.minStartTime = self.search.minStartTime
if self.maxStartTime is None: if self.maxStartTime is None:
...@@ -2212,7 +2216,8 @@ class MCMCTransientSearch(MCMCSearch): ...@@ -2212,7 +2216,8 @@ class MCMCTransientSearch(MCMCSearch):
transientWindowType=self.transientWindowType, transientWindowType=self.transientWindowType,
minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
BSGL=self.BSGL, binary=self.binary, BSGL=self.BSGL, binary=self.binary,
injectSources=self.injectSources) injectSources=self.injectSources,
tCWFstatMapVersion=self.tCWFstatMapVersion)
if self.minStartTime is None: if self.minStartTime is None:
self.minStartTime = self.search.minStartTime self.minStartTime = self.search.minStartTime
if self.maxStartTime is None: if self.maxStartTime is None:
......
""" Additional helper functions dealing with transient-CW F(t0,tau) maps """
import logging
# optional imports
import importlib as imp
def optional_import ( modulename, shorthand=None ):
'''
Import a module/submodule only if it's available.
using importlib instead of __import__
because the latter doesn't handle sub.modules
'''
if shorthand is None:
shorthand = modulename
shorthandbit = ''
else:
shorthandbit = ' as '+shorthand
try:
globals()[shorthand] = imp.import_module(modulename)
#logging.debug('Successfully imported module %s%s.' % (modulename, shorthandbit))
success = True
except ImportError, e:
if e.message == 'No module named '+modulename:
logging.warning('No module {:s} found.'.format(modulename))
success = False
else:
raise
return success
# dictionary of the actual callable F-stat map functions we support,
# if the corresponding modules are available.
fstatmap_versions = {
'lal': lambda multiFstatAtoms, windowRange:
getattr(lalpulsar,'ComputeTransientFstatMap')
( multiFstatAtoms, windowRange, False ),
#'pycuda': lambda multiFstatAtoms, windowRange:
#pycuda_compute_transient_fstat_map
#( multiFstatAtoms, windowRange )
}
def init_transient_fstat_map_features ( ):
'''
Initialization of available modules (or "features") for F-stat maps.
Returns a dictionary of method names, to match fstatmap_versions
each key's value set to True only if
all required modules are importable on this system.
'''
features = {}
have_lal = optional_import('lal')
have_lalpulsar = optional_import('lalpulsar')
features['lal'] = have_lal and have_lalpulsar
features['pycuda'] = False
logging.debug('Got the following features for transient F-stat maps:')
logging.debug(features)
return features
def call_compute_transient_fstat_map ( version, features, multiFstatAtoms=None, windowRange=None ):
'''Choose which version of the ComputeTransientFstatMap function to call.'''
if version in fstatmap_versions:
if features[version]:
FstatMap = fstatmap_versions[version](multiFstatAtoms, windowRange)
else:
raise Exception('Required module(s) for transient F-stat map method "{}" not available!'.format(version))
else:
raise Exception('Transient F-stat map method "{}" not implemented!'.format(version))
return FstatMap
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment