diff --git a/examples/transient_examples/short_transient_search_gridded.py b/examples/transient_examples/short_transient_search_gridded.py index dea9c26745e968bc6af825b61aaeb78b72fef22a..5011b8052ef325bc7b3590195b9112a335b1cdaf 100644 --- a/examples/transient_examples/short_transient_search_gridded.py +++ b/examples/transient_examples/short_transient_search_gridded.py @@ -50,7 +50,8 @@ search2 = pyfstat.TransientGridSearch( minStartTime=minStartTime, maxStartTime=maxStartTime, transientWindowType='rect', t0Band=Tspan-2*Tsft, tauBand=Tspan, BSGL=False, - outputTransientFstatMap=True) + outputTransientFstatMap=True, + tCWFstatMapVersion='lal') search2.run() search2.print_max_twoF() diff --git a/pyfstat/core.py b/pyfstat/core.py index 4482121b3c3b09d9fd3d90182389720010d00f44..abf18565e240dcad542e5446a82b0e51d902a0d8 100755 --- a/pyfstat/core.py +++ b/pyfstat/core.py @@ -13,6 +13,7 @@ import scipy.optimize import lal import lalpulsar import pyfstat.helper_functions as helper_functions +import pyfstat.tcw_fstat_map_funcs as tcw # workaround for matplotlib on X-less remote logins if 'DISPLAY' in os.environ: @@ -335,7 +336,8 @@ class ComputeFstat(BaseSearchClass): dt0=None, dtau=None, detectors=None, minCoverFreq=None, maxCoverFreq=None, injectSources=None, injectSqrtSX=None, assumeSqrtSX=None, - SSBprec=None): + SSBprec=None, + tCWFstatMapVersion='lal'): """ Parameters ---------- @@ -383,6 +385,9 @@ class ComputeFstat(BaseSearchClass): SSBprec : int Flag to set the SSB calculation: 0=Newtonian, 1=relativistic, 2=relativisitic optimised, 3=DMoff, 4=NO_SPIN + tCWFstatMapVersion: str + Choose between standard 'lal' implementation, + 'pycuda' for gpu, and some others for devel/debug. """ @@ -653,6 +658,8 @@ class ComputeFstat(BaseSearchClass): if self.dtau: self.windowRange.dtau = self.dtau + self.tCWFstatMapFeatures = tcw.init_transient_fstat_map_features() + def get_fullycoherent_twoF(self, tstart, tend, F0, F1, F2, Alpha, Delta, asini=None, period=None, ecc=None, tp=None, argp=None): @@ -695,9 +702,21 @@ class ComputeFstat(BaseSearchClass): # F-stat computation self.windowRange.tau = int(2*self.Tsft) - self.FstatMap = lalpulsar.ComputeTransientFstatMap( - self.FstatResults.multiFatoms[0], self.windowRange, False) - F_mn = self.FstatMap.F_mn.data + #logging.debug('Calling "%s" version of ComputeTransientFstatMap() with windowRange: (type=%d (%s), t0=%f, t0Band=%f, dt0=%f, tau=%f, tauBand=%f, dtau=%f)...' % (self.tCWFstatMapVersion, self.windowRange.type, self.transientWindowType, self.windowRange.t0, self.windowRange.t0Band, self.windowRange.dt0, self.windowRange.tau, self.windowRange.tauBand, self.windowRange.dtau)) + self.FstatMap = tcw.call_compute_transient_fstat_map( self.tCWFstatMapVersion, + self.tCWFstatMapFeatures, + self.FstatResults.multiFatoms[0], + self.windowRange + ) + if self.tCWFstatMapVersion == 'lal': + F_mn = self.FstatMap.F_mn.data + else: + F_mn = self.FstatMap.F_mn + + #logging.debug('maxF: {}'.format(FstatMap.maxF)) + #logging.debug('t0_ML: %ds=T0+%fd' % (FstatMap.t0_ML, (FstatMap.t0_ML-tstart)/(3600.*24.))) + #logging.debug('tau_ML: %ds=%fd' % (FstatMap.tau_ML, FstatMap.tau_ML/(3600.*24.))) + #logging.debug('F_mn: {}'.format(F_mn)) twoF = 2*np.max(F_mn) if self.BSGL is False: @@ -950,6 +969,7 @@ class SemiCoherentSearch(ComputeFstat): self.transientWindowType = 'rect' self.t0Band = None self.tauBand = None + self.tCWFstatMapVersion = 'lal' self.init_computefstatistic_single_point() self.init_semicoherent_parameters() @@ -1089,6 +1109,7 @@ class SemiCoherentGlitchSearch(ComputeFstat): self.transientWindowType = 'rect' self.t0Band = None self.tauBand = None + self.tCWFstatMapVersion = 'lal' self.binary = False self.init_computefstatistic_single_point() diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py index 42b3c3e9a6e67b63815d59cae2cce9af0870017b..091b2ae2a977f15832ddc23cd0bd8432220ec75a 100644 --- a/pyfstat/grid_based_searches.py +++ b/pyfstat/grid_based_searches.py @@ -355,7 +355,8 @@ class TransientGridSearch(GridSearch): transientWindowType=None, t0Band=None, tauBand=None, dt0=None, dtau=None, outputTransientFstatMap=False, - outputAtoms=False): + outputAtoms=False, + tCWFstatMapVersion='lal'): """ Parameters ---------- @@ -388,6 +389,9 @@ class TransientGridSearch(GridSearch): outputTransientFstatMap: bool if true, write output files for (t0,tau) Fstat maps (one file for each doppler grid point!) + tCWFstatMapVersion: str + Choose between standard 'lal' implementation, + 'pycuda' for gpu, and some others for devel/debug. For all other parameters, see `pyfstat.ComputeFStat` for details """ @@ -413,7 +417,8 @@ class TransientGridSearch(GridSearch): minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, BSGL=self.BSGL, SSBprec=self.SSBprec, injectSources=self.injectSources, - assumeSqrtSX=self.assumeSqrtSX) + assumeSqrtSX=self.assumeSqrtSX, + tCWFstatMapVersion=self.tCWFstatMapVersion) self.search.get_det_stat = self.search.get_fullycoherent_twoF def run(self, return_data=False): @@ -435,9 +440,12 @@ class TransientGridSearch(GridSearch): if getattr(self, 'transientWindowType', None): if self.outputTransientFstatMap: tCWfile = os.path.splitext(self.out_file)[0]+'_tCW_%.16f_%.16f_%.16f_%.16g_%.16g.dat' % (vals[2],vals[5],vals[6],vals[3],vals[4]) # freq alpha delta f1dot f2dot - fo = lal.FileOpen(tCWfile, 'w') - lalpulsar.write_transientFstatMap_to_fp ( fo, FstatMap, windowRange, None ) - del fo # instead of lal.FileClose() which is not SWIG-exported + if self.tCWFstatMapVersion == 'lal': + fo = lal.FileOpen(tCWfile, 'w') + lalpulsar.write_transientFstatMap_to_fp ( fo, FstatMap, windowRange, None ) + del fo # instead of lal.FileClose() which is not SWIG-exported + else: + np.savetxt(tCWfile, 2.0*FstatMap.F_mn, delimiter=' ') Fmn = FstatMap.F_mn.data maxidx = np.unravel_index(Fmn.argmax(), Fmn.shape) thisCand += [windowRange.t0+maxidx[0]*windowRange.dt0, diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py index b09c7d0397c7133862993bced37bed74eed085ae..99b28b66e5a3bc3d26b51c166abbed4fefc2b15c 100644 --- a/pyfstat/mcmc_based_searches.py +++ b/pyfstat/mcmc_based_searches.py @@ -82,6 +82,9 @@ class MCMCSearch(core.BaseSearchClass): ('none' instead of None explicitly calls the transient-window function, but with the full range, for debugging) Currently only supported for nsegs=1. + tCWFstatMapVersion: str + Choose between standard 'lal' implementation, + 'pycuda' for gpu, and some others for devel/debug. Attributes ---------- @@ -115,7 +118,7 @@ class MCMCSearch(core.BaseSearchClass): rhohatmax=1000, binary=False, BSGL=False, SSBprec=None, minCoverFreq=None, maxCoverFreq=None, injectSources=None, assumeSqrtSX=None, - transientWindowType=None): + transientWindowType=None, tCWFstatMapVersion='lal'): if os.path.isdir(outdir) is False: os.mkdir(outdir) @@ -161,7 +164,8 @@ class MCMCSearch(core.BaseSearchClass): transientWindowType=self.transientWindowType, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, binary=self.binary, injectSources=self.injectSources, - assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec) + assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec, + tCWFstatMapVersion=self.tCWFstatMapVersion) if self.minStartTime is None: self.minStartTime = self.search.minStartTime if self.maxStartTime is None: @@ -2212,7 +2216,8 @@ class MCMCTransientSearch(MCMCSearch): transientWindowType=self.transientWindowType, minStartTime=self.minStartTime, maxStartTime=self.maxStartTime, BSGL=self.BSGL, binary=self.binary, - injectSources=self.injectSources) + injectSources=self.injectSources, + tCWFstatMapVersion=self.tCWFstatMapVersion) if self.minStartTime is None: self.minStartTime = self.search.minStartTime if self.maxStartTime is None: diff --git a/pyfstat/tcw_fstat_map_funcs.py b/pyfstat/tcw_fstat_map_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..7294e0a6ced84f8776d1bafb7fad8b5766420ebd --- /dev/null +++ b/pyfstat/tcw_fstat_map_funcs.py @@ -0,0 +1,74 @@ +""" Additional helper functions dealing with transient-CW F(t0,tau) maps """ + +import logging + +# optional imports +import importlib as imp + + +def optional_import ( modulename, shorthand=None ): + ''' + Import a module/submodule only if it's available. + + using importlib instead of __import__ + because the latter doesn't handle sub.modules + ''' + if shorthand is None: + shorthand = modulename + shorthandbit = '' + else: + shorthandbit = ' as '+shorthand + try: + globals()[shorthand] = imp.import_module(modulename) + #logging.debug('Successfully imported module %s%s.' % (modulename, shorthandbit)) + success = True + except ImportError, e: + if e.message == 'No module named '+modulename: + logging.warning('No module {:s} found.'.format(modulename)) + success = False + else: + raise + return success + + +# dictionary of the actual callable F-stat map functions we support, +# if the corresponding modules are available. +fstatmap_versions = { + 'lal': lambda multiFstatAtoms, windowRange: + getattr(lalpulsar,'ComputeTransientFstatMap') + ( multiFstatAtoms, windowRange, False ), + #'pycuda': lambda multiFstatAtoms, windowRange: + #pycuda_compute_transient_fstat_map + #( multiFstatAtoms, windowRange ) + } + + +def init_transient_fstat_map_features ( ): + ''' + Initialization of available modules (or "features") for F-stat maps. + + Returns a dictionary of method names, to match fstatmap_versions + each key's value set to True only if + all required modules are importable on this system. + ''' + features = {} + have_lal = optional_import('lal') + have_lalpulsar = optional_import('lalpulsar') + features['lal'] = have_lal and have_lalpulsar + features['pycuda'] = False + logging.debug('Got the following features for transient F-stat maps:') + logging.debug(features) + return features + + +def call_compute_transient_fstat_map ( version, features, multiFstatAtoms=None, windowRange=None ): + '''Choose which version of the ComputeTransientFstatMap function to call.''' + + if version in fstatmap_versions: + if features[version]: + FstatMap = fstatmap_versions[version](multiFstatAtoms, windowRange) + else: + raise Exception('Required module(s) for transient F-stat map method "{}" not available!'.format(version)) + else: + raise Exception('Transient F-stat map method "{}" not implemented!'.format(version)) + return FstatMap