diff --git a/examples/MCMC_examples/fully_coherent_search_using_MCMC.py b/examples/MCMC_examples/fully_coherent_search_using_MCMC.py
index 22a725539afa60e5a59015dfefc0f02b090887f6..aeb30533b10a43d52fa8ba55508676b71de1c727 100644
--- a/examples/MCMC_examples/fully_coherent_search_using_MCMC.py
+++ b/examples/MCMC_examples/fully_coherent_search_using_MCMC.py
@@ -4,7 +4,7 @@ import numpy as np
 # Properties of the GW data
 sqrtSX = 1e-23
 tstart = 1000000000
-duration = 100*86400
+duration = 100 * 86400
 tend = tstart + duration
 
 # Properties of the signal
@@ -13,39 +13,46 @@ F1 = -1e-10
 F2 = 0
 Alpha = np.radians(83.6292)
 Delta = np.radians(22.0144)
-tref = .5*(tstart+tend)
+tref = 0.5 * (tstart + tend)
 
 depth = 10
 h0 = sqrtSX / depth
-label = 'fully_coherent_search_using_MCMC'
-outdir = 'data'
+label = "fully_coherent_search_using_MCMC"
+outdir = "data"
 
 data = pyfstat.Writer(
-    label=label, outdir=outdir, tref=tref,
-    tstart=tstart, F0=F0, F1=F1, F2=F2, duration=duration, Alpha=Alpha,
-    Delta=Delta, h0=h0, sqrtSX=sqrtSX)
+    label=label,
+    outdir=outdir,
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+)
 data.make_data()
 
 # The predicted twoF, given by lalapps_predictFstat can be accessed by
 twoF = data.predict_fstat()
-print('Predicted twoF value: {}\n'.format(twoF))
+print("Predicted twoF value: {}\n".format(twoF))
 
 DeltaF0 = 1e-7
 DeltaF1 = 1e-13
-VF0 = (np.pi * duration * DeltaF0)**2 / 3.0
-VF1 = (np.pi * duration**2 * DeltaF1)**2 * 4/45.
-print('\nV={:1.2e}, VF0={:1.2e}, VF1={:1.2e}\n'.format(VF0*VF1, VF0, VF1))
+VF0 = (np.pi * duration * DeltaF0) ** 2 / 3.0
+VF1 = (np.pi * duration ** 2 * DeltaF1) ** 2 * 4 / 45.0
+print("\nV={:1.2e}, VF0={:1.2e}, VF1={:1.2e}\n".format(VF0 * VF1, VF0, VF1))
 
-theta_prior = {'F0': {'type': 'unif',
-                      'lower': F0-DeltaF0/2.,
-                      'upper': F0+DeltaF0/2.},
-               'F1': {'type': 'unif',
-                      'lower': F1-DeltaF1/2.,
-                      'upper': F1+DeltaF1/2.},
-               'F2': F2,
-               'Alpha': Alpha,
-               'Delta': Delta
-               }
+theta_prior = {
+    "F0": {"type": "unif", "lower": F0 - DeltaF0 / 2.0, "upper": F0 + DeltaF0 / 2.0},
+    "F1": {"type": "unif", "lower": F1 - DeltaF1 / 2.0, "upper": F1 + DeltaF1 / 2.0},
+    "F2": F2,
+    "Alpha": Alpha,
+    "Delta": Delta,
+}
 
 ntemps = 2
 log10beta_min = -0.5
@@ -53,13 +60,22 @@ nwalkers = 100
 nsteps = [300, 300]
 
 mcmc = pyfstat.MCMCSearch(
-    label=label, outdir=outdir,
-    sftfilepattern='{}/*{}*sft'.format(outdir, label), theta_prior=theta_prior,
-    tref=tref, minStartTime=tstart, maxStartTime=tend, nsteps=nsteps,
-    nwalkers=nwalkers, ntemps=ntemps, log10beta_min=log10beta_min)
+    label=label,
+    outdir=outdir,
+    sftfilepattern="{}/*{}*sft".format(outdir, label),
+    theta_prior=theta_prior,
+    tref=tref,
+    minStartTime=tstart,
+    maxStartTime=tend,
+    nsteps=nsteps,
+    nwalkers=nwalkers,
+    ntemps=ntemps,
+    log10beta_min=log10beta_min,
+)
 mcmc.transform_dictionary = dict(
-    F0=dict(subtractor=F0, symbol='$f-f^\mathrm{s}$'),
-    F1=dict(subtractor=F1, symbol='$\dot{f}-\dot{f}^\mathrm{s}$'))
+    F0=dict(subtractor=F0, symbol="$f-f^\mathrm{s}$"),
+    F1=dict(subtractor=F1, symbol="$\dot{f}-\dot{f}^\mathrm{s}$"),
+)
 mcmc.run()
 mcmc.plot_corner(add_prior=True)
 mcmc.print_summary()
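
Note on the `V`, `VF0`, `VF1` printout in the hunk above: it is the (approximate) metric volume of the prior box in frequency and spindown. A standalone check of those formulas with this script's settings, using plain numpy only; this snippet is not part of the example itself:

```python
import numpy as np

# Values copied from the script above.
duration = 100 * 86400  # 100 days in seconds
DeltaF0 = 1e-7  # frequency prior width [Hz]
DeltaF1 = 1e-13  # spindown prior width [Hz/s]

# Per-dimension metric volumes, as printed by the example.
VF0 = (np.pi * duration * DeltaF0) ** 2 / 3.0
VF1 = (np.pi * duration ** 2 * DeltaF1) ** 2 * 4 / 45.0
print("V={:1.2e}, VF0={:1.2e}, VF1={:1.2e}".format(VF0 * VF1, VF0, VF1))
# -> V=1.20e+02, VF0=2.46e+00, VF1=4.89e+01 (approximately)
```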
diff --git a/examples/MCMC_examples/semi_coherent_search_using_MCMC.py b/examples/MCMC_examples/semi_coherent_search_using_MCMC.py
index 556e3b14b660a2722c61baa848174d821f0a5dd5..d357bf15a477be124223466b3161809f0cade89f 100644
--- a/examples/MCMC_examples/semi_coherent_search_using_MCMC.py
+++ b/examples/MCMC_examples/semi_coherent_search_using_MCMC.py
@@ -4,7 +4,7 @@ import numpy as np
 # Properties of the GW data
 sqrtSX = 1e-23
 tstart = 1000000000
-duration = 100*86400
+duration = 100 * 86400
 tend = tstart + duration
 
 # Properties of the signal
@@ -13,39 +13,46 @@ F1 = -1e-10
 F2 = 0
 Alpha = np.radians(83.6292)
 Delta = np.radians(22.0144)
-tref = .5*(tstart+tend)
+tref = 0.5 * (tstart + tend)
 
 depth = 10
 h0 = sqrtSX / depth
-label = 'semicoherent_search_using_MCMC'
-outdir = 'data'
+label = "semicoherent_search_using_MCMC"
+outdir = "data"
 
 data = pyfstat.Writer(
-    label=label, outdir=outdir, tref=tref,
-    tstart=tstart, F0=F0, F1=F1, F2=F2, duration=duration, Alpha=Alpha,
-    Delta=Delta, h0=h0, sqrtSX=sqrtSX)
+    label=label,
+    outdir=outdir,
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+)
 data.make_data()
 
 # The predicted twoF, given by lalapps_predictFstat can be accessed by
 twoF = data.predict_fstat()
-print('Predicted twoF value: {}\n'.format(twoF))
+print("Predicted twoF value: {}\n".format(twoF))
 
 DeltaF0 = 1e-7
 DeltaF1 = 1e-13
-VF0 = (np.pi * duration * DeltaF0)**2 / 3.0
-VF1 = (np.pi * duration**2 * DeltaF1)**2 * 4/45.
-print('\nV={:1.2e}, VF0={:1.2e}, VF1={:1.2e}\n'.format(VF0*VF1, VF0, VF1))
+VF0 = (np.pi * duration * DeltaF0) ** 2 / 3.0
+VF1 = (np.pi * duration ** 2 * DeltaF1) ** 2 * 4 / 45.0
+print("\nV={:1.2e}, VF0={:1.2e}, VF1={:1.2e}\n".format(VF0 * VF1, VF0, VF1))
 
-theta_prior = {'F0': {'type': 'unif',
-                      'lower': F0-DeltaF0/2.,
-                      'upper': F0+DeltaF0/2.},
-               'F1': {'type': 'unif',
-                      'lower': F1-DeltaF1/2.,
-                      'upper': F1+DeltaF1/2.},
-               'F2': F2,
-               'Alpha': Alpha,
-               'Delta': Delta
-               }
+theta_prior = {
+    "F0": {"type": "unif", "lower": F0 - DeltaF0 / 2.0, "upper": F0 + DeltaF0 / 2.0},
+    "F1": {"type": "unif", "lower": F1 - DeltaF1 / 2.0, "upper": F1 + DeltaF1 / 2.0},
+    "F2": F2,
+    "Alpha": Alpha,
+    "Delta": Delta,
+}
 
 ntemps = 1
 log10beta_min = -1
@@ -53,14 +60,23 @@ nwalkers = 100
 nsteps = [300, 300]
 
 mcmc = pyfstat.MCMCSemiCoherentSearch(
-    label=label, outdir=outdir, nsegs=10,
-    sftfilepattern='{}/*{}*sft'.format(outdir, label),
-    theta_prior=theta_prior, tref=tref, minStartTime=tstart, maxStartTime=tend,
-    nsteps=nsteps, nwalkers=nwalkers, ntemps=ntemps,
-    log10beta_min=log10beta_min)
+    label=label,
+    outdir=outdir,
+    nsegs=10,
+    sftfilepattern="{}/*{}*sft".format(outdir, label),
+    theta_prior=theta_prior,
+    tref=tref,
+    minStartTime=tstart,
+    maxStartTime=tend,
+    nsteps=nsteps,
+    nwalkers=nwalkers,
+    ntemps=ntemps,
+    log10beta_min=log10beta_min,
+)
 mcmc.transform_dictionary = dict(
-    F0=dict(subtractor=F0, symbol='$f-f^\mathrm{s}$'),
-    F1=dict(subtractor=F1, symbol='$\dot{f}-\dot{f}^\mathrm{s}$'))
+    F0=dict(subtractor=F0, symbol="$f-f^\mathrm{s}$"),
+    F1=dict(subtractor=F1, symbol="$\dot{f}-\dot{f}^\mathrm{s}$"),
+)
 mcmc.run()
 mcmc.plot_corner(add_prior=True)
 mcmc.print_summary()
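
Throughout these examples `theta_prior` follows one convention: a plain number pins a parameter to a fixed value, while a dict with a `"type"` key declares a prior to be sampled over. A minimal sketch of that split; the dictionary and inspection code below are illustrative, not PyFstat internals:

```python
theta_prior = {
    "F0": {"type": "unif", "lower": 29.9, "upper": 30.1},  # sampled
    "F1": {"type": "unif", "lower": -1.1e-10, "upper": -0.9e-10},  # sampled
    "F2": 0,  # held fixed
}

sampled = [k for k, v in theta_prior.items() if isinstance(v, dict)]
fixed = [k for k, v in theta_prior.items() if not isinstance(v, dict)]
print(sampled, fixed)  # ['F0', 'F1'] ['F2']
```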
diff --git a/examples/followup_examples/semi_coherent_directed_follow_up.py b/examples/followup_examples/semi_coherent_directed_follow_up.py
index e1db02603ffb398fb5eb0a51df91d2e8be1d41a9..d1d9016daf9bc4abc64627f9f7be710e9fa4ad69 100644
--- a/examples/followup_examples/semi_coherent_directed_follow_up.py
+++ b/examples/followup_examples/semi_coherent_directed_follow_up.py
@@ -11,39 +11,47 @@ Delta = np.radians(22.0144)
 # Properties of the GW data
 sqrtSX = 1e-23
 tstart = 1000000000
-duration = 100*86400
-tend = tstart+duration
-tref = .5*(tstart+tend)
+duration = 100 * 86400
+tend = tstart + duration
+tref = 0.5 * (tstart + tend)
 
 depth = 40
-label = 'semicoherent_directed_follow_up'
-outdir = 'data'
+label = "semicoherent_directed_follow_up"
+outdir = "data"
 
 h0 = sqrtSX / depth
 
 data = pyfstat.Writer(
-    label=label, outdir=outdir, tref=tref, tstart=tstart, F0=F0, F1=F1,
-    F2=F2, duration=duration, Alpha=Alpha, Delta=Delta, h0=h0, sqrtSX=sqrtSX)
+    label=label,
+    outdir=outdir,
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+)
 data.make_data()
 
 # The predicted twoF, given by lalapps_predictFstat can be accessed by
 twoF = data.predict_fstat()
-print('Predicted twoF value: {}\n'.format(twoF))
+print("Predicted twoF value: {}\n".format(twoF))
 
 # Search
 VF0 = VF1 = 1e5
-DeltaF0 = np.sqrt(VF0) * np.sqrt(3)/(np.pi*duration)
-DeltaF1 = np.sqrt(VF1) * np.sqrt(180)/(np.pi*duration**2)
-theta_prior = {'F0': {'type': 'unif',
-                      'lower': F0-DeltaF0/2.,
-                      'upper': F0+DeltaF0/2},
-               'F1': {'type': 'unif',
-                      'lower': F1-DeltaF1/2.,
-                      'upper': F1+DeltaF1/2},
-               'F2': F2,
-               'Alpha': Alpha,
-               'Delta': Delta
-               }
+DeltaF0 = np.sqrt(VF0) * np.sqrt(3) / (np.pi * duration)
+DeltaF1 = np.sqrt(VF1) * np.sqrt(180) / (np.pi * duration ** 2)
+theta_prior = {
+    "F0": {"type": "unif", "lower": F0 - DeltaF0 / 2.0, "upper": F0 + DeltaF0 / 2},
+    "F1": {"type": "unif", "lower": F1 - DeltaF1 / 2.0, "upper": F1 + DeltaF1 / 2},
+    "F2": F2,
+    "Alpha": Alpha,
+    "Delta": Delta,
+}
 
 ntemps = 3
 log10beta_min = -0.5
@@ -51,23 +59,35 @@ nwalkers = 100
 nsteps = [100, 100]
 
 mcmc = pyfstat.MCMCFollowUpSearch(
-    label=label, outdir=outdir,
-    sftfilepattern='{}/*{}*sft'.format(outdir, label),
-    theta_prior=theta_prior, tref=tref, minStartTime=tstart, maxStartTime=tend,
-    nwalkers=nwalkers, nsteps=nsteps, ntemps=ntemps,
-    log10beta_min=log10beta_min)
+    label=label,
+    outdir=outdir,
+    sftfilepattern="{}/*{}*sft".format(outdir, label),
+    theta_prior=theta_prior,
+    tref=tref,
+    minStartTime=tstart,
+    maxStartTime=tend,
+    nwalkers=nwalkers,
+    nsteps=nsteps,
+    ntemps=ntemps,
+    log10beta_min=log10beta_min,
+)
 
 NstarMax = 1000
 Nsegs0 = 100
 fig, axes = plt.subplots(nrows=2, figsize=(3.4, 3.5))
 fig, axes = mcmc.run(
-    NstarMax=NstarMax, Nsegs0=Nsegs0, labelpad=0.01,
-    plot_det_stat=False, return_fig=True, fig=fig,
-    axes=axes)
+    NstarMax=NstarMax,
+    Nsegs0=Nsegs0,
+    labelpad=0.01,
+    plot_det_stat=False,
+    return_fig=True,
+    fig=fig,
+    axes=axes,
+)
 for ax in axes:
     ax.grid()
     ax.set_xticks(np.arange(0, 600, 100))
     ax.set_xticklabels([str(s) for s in np.arange(0, 700, 100)])
-axes[-1].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.1)
+axes[-1].set_xlabel(r"$\textrm{Number of steps}$", labelpad=0.1)
 fig.tight_layout()
-fig.savefig('{}/{}_walkers.png'.format(mcmc.outdir, mcmc.label), dpi=400)
+fig.savefig("{}/{}_walkers.png".format(mcmc.outdir, mcmc.label), dpi=400)
diff --git a/examples/glitch_examples/make_simulated_data.py b/examples/glitch_examples/make_simulated_data.py
index e52207b6285f2c2c2ddeaa0e85b6a2caeed45fc6..831b792624a8a3d1e6128d6b936f43ab12a51fad 100644
--- a/examples/glitch_examples/make_simulated_data.py
+++ b/examples/glitch_examples/make_simulated_data.py
@@ -1,7 +1,7 @@
 from pyfstat import Writer, GlitchWriter
 import numpy as np
 
-outdir = 'data'
+outdir = "data"
 # First, we generate data with a reasonably strong smooth signal
 
 # Define parameters of the Crab pulsar as an example
@@ -17,37 +17,75 @@ h0 = 5e-24
 # Properties of the GW data
 sqrtSX = 1e-22
 tstart = 1000000000
-duration = 50*86400
-tend = tstart+duration
-tref = tstart + 0.5*duration
+duration = 50 * 86400
+tend = tstart + duration
+tref = tstart + 0.5 * duration
 
 data = Writer(
-    label='0_glitch', outdir=outdir, tref=tref, tstart=tstart, F0=F0, F1=F1,
-    F2=F2, duration=duration, Alpha=Alpha, Delta=Delta, h0=h0, sqrtSX=sqrtSX)
+    label="0_glitch",
+    outdir=outdir,
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+)
 data.make_data()
 
 # Next, taking the same signal parameters, we include a glitch half way through
-dtglitch = duration/2.0
+dtglitch = duration / 2.0
 delta_F0 = 5e-6
 delta_F1 = 0
 
 glitch_data = GlitchWriter(
-    label='1_glitch', outdir=outdir, tref=tref, tstart=tstart, F0=F0, F1=F1,
-    F2=F2, duration=duration, Alpha=Alpha, Delta=Delta, h0=h0, sqrtSX=sqrtSX,
-    dtglitch=dtglitch, delta_F0=delta_F0, delta_F1=delta_F1)
+    label="1_glitch",
+    outdir=outdir,
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+    dtglitch=dtglitch,
+    delta_F0=delta_F0,
+    delta_F1=delta_F1,
+)
 glitch_data.make_data()
 
 # Making data with two glitches
 
-dtglitch_2 = [duration/4.0, 4*duration/5.0]
+dtglitch_2 = [duration / 4.0, 4 * duration / 5.0]
 delta_phi_2 = [0, 0]
 delta_F0_2 = [4e-6, 3e-7]
 delta_F1_2 = [0, 0]
 delta_F2_2 = [0, 0]
 
 two_glitch_data = GlitchWriter(
-    label='2_glitch', outdir=outdir, tref=tref, tstart=tstart, F0=F0, F1=F1,
-    F2=F2, duration=duration, Alpha=Alpha, Delta=Delta, h0=h0, sqrtSX=sqrtSX,
-    dtglitch=dtglitch_2, delta_phi=delta_phi_2, delta_F0=delta_F0_2,
-    delta_F1=delta_F1_2, delta_F2=delta_F2_2)
+    label="2_glitch",
+    outdir=outdir,
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+    dtglitch=dtglitch_2,
+    delta_phi=delta_phi_2,
+    delta_F0=delta_F0_2,
+    delta_F1=delta_F1_2,
+    delta_F2=delta_F2_2,
+)
 two_glitch_data.make_data()
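
The glitch parameters in this file describe a piecewise-constant frequency offset: each `delta_F0` entry adds to the signal frequency from its glitch time onwards. A sketch of the resulting step function for the two-glitch case; this is an illustration only, not `GlitchWriter`'s internal model, and it ignores the `delta_F1`/`delta_F2` terms, which are zero here anyway:

```python
tstart = 1000000000
duration = 50 * 86400
dtglitch_2 = [duration / 4.0, 4 * duration / 5.0]
delta_F0_2 = [4e-6, 3e-7]


def frequency_offset(t):
    """Cumulative frequency jump [Hz] accumulated by GPS time t."""
    return sum(
        dF0 for dt, dF0 in zip(dtglitch_2, delta_F0_2) if t >= tstart + dt
    )


for frac in (0.1, 0.5, 0.9):
    print(frac, frequency_offset(tstart + frac * duration))
# -> 0.1 0 / 0.5 4e-06 / 0.9 4.3e-06
```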
diff --git a/examples/glitch_examples/semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch.py b/examples/glitch_examples/semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch.py
index e0613d6db641f2854a307ba911f196fb9d29da87..361f43bce9171b6ad9dfde5fd05383109b9776d7 100644
--- a/examples/glitch_examples/semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch.py
+++ b/examples/glitch_examples/semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch.py
@@ -3,34 +3,42 @@ import matplotlib.pyplot as plt
 import pyfstat
 import gridcorner
 import time
-from make_simulated_data import tstart, duration, tref, F0, F1, F2, Alpha, Delta, delta_F0, dtglitch, outdir
+from make_simulated_data import (
+    tstart,
+    duration,
+    tref,
+    F0,
+    F1,
+    F2,
+    Alpha,
+    Delta,
+    delta_F0,
+    dtglitch,
+    outdir,
+)
 
-plt.style.use('./paper.mplstyle')
+plt.style.use("./paper.mplstyle")
 
-label = 'semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch'
+label = "semicoherent_glitch_robust_directed_MCMC_search_on_1_glitch"
 
 Nstar = 1000
-F0_width = np.sqrt(Nstar)*np.sqrt(12)/(np.pi*duration)
-F1_width = np.sqrt(Nstar)*np.sqrt(180)/(np.pi*duration**2)
+F0_width = np.sqrt(Nstar) * np.sqrt(12) / (np.pi * duration)
+F1_width = np.sqrt(Nstar) * np.sqrt(180) / (np.pi * duration ** 2)
 
 theta_prior = {
-    'F0': {'type': 'unif',
-           'lower': F0-F0_width/2.,
-           'upper': F0+F0_width/2.},
-    'F1': {'type': 'unif',
-           'lower': F1-F1_width/2.,
-           'upper': F1+F1_width/2.},
-    'F2': F2,
-    'delta_F0': {'type': 'unif',
-                 'lower': 0,
-                 'upper': 1e-5},
-    'delta_F1': 0,
-    'tglitch': {'type': 'unif',
-                'lower': tstart+0.1*duration,
-                'upper': tstart+0.9*duration},
-    'Alpha': Alpha,
-    'Delta': Delta,
-    }
+    "F0": {"type": "unif", "lower": F0 - F0_width / 2.0, "upper": F0 + F0_width / 2.0},
+    "F1": {"type": "unif", "lower": F1 - F1_width / 2.0, "upper": F1 + F1_width / 2.0},
+    "F2": F2,
+    "delta_F0": {"type": "unif", "lower": 0, "upper": 1e-5},
+    "delta_F1": 0,
+    "tglitch": {
+        "type": "unif",
+        "lower": tstart + 0.1 * duration,
+        "upper": tstart + 0.9 * duration,
+    },
+    "Alpha": Alpha,
+    "Delta": Delta,
+}
 
 ntemps = 3
 log10beta_min = -0.5
@@ -38,33 +46,49 @@ nwalkers = 100
 nsteps = [250, 250]
 
 mcmc = pyfstat.MCMCGlitchSearch(
-    label=label, sftfilepattern='data/*1_glitch*sft', theta_prior=theta_prior,
-    tref=tref, minStartTime=tstart, maxStartTime=tstart+duration,
-    nsteps=nsteps, nwalkers=nwalkers, ntemps=ntemps,
-    log10beta_min=log10beta_min, nglitch=1)
-mcmc.transform_dictionary['F0'] = dict(
-    subtractor=F0, multiplier=1e6, symbol='$f-f_\mathrm{s}$')
-mcmc.unit_dictionary['F0'] = '$\mu$Hz'
-mcmc.transform_dictionary['F1'] = dict(
-    subtractor=F1, multiplier=1e12, symbol='$\dot{f}-\dot{f}_\mathrm{s}$')
-mcmc.unit_dictionary['F1'] = '$p$Hz/s'
-mcmc.transform_dictionary['delta_F0'] = dict(
-    multiplier=1e6, subtractor=delta_F0,
-    symbol='$\delta f-\delta f_\mathrm{s}$')
-mcmc.unit_dictionary['delta_F0'] = '$\mu$Hz/s'
-mcmc.transform_dictionary['tglitch']['subtractor'] = tstart + dtglitch
-mcmc.transform_dictionary['tglitch']['label'] ='$t^\mathrm{g}-t^\mathrm{g}_\mathrm{s}$\n[d]'
+    label=label,
+    sftfilepattern="data/*1_glitch*sft",
+    theta_prior=theta_prior,
+    tref=tref,
+    minStartTime=tstart,
+    maxStartTime=tstart + duration,
+    nsteps=nsteps,
+    nwalkers=nwalkers,
+    ntemps=ntemps,
+    log10beta_min=log10beta_min,
+    nglitch=1,
+)
+mcmc.transform_dictionary["F0"] = dict(
+    subtractor=F0, multiplier=1e6, symbol="$f-f_\mathrm{s}$"
+)
+mcmc.unit_dictionary["F0"] = "$\mu$Hz"
+mcmc.transform_dictionary["F1"] = dict(
+    subtractor=F1, multiplier=1e12, symbol="$\dot{f}-\dot{f}_\mathrm{s}$"
+)
+mcmc.unit_dictionary["F1"] = "$p$Hz/s"
+mcmc.transform_dictionary["delta_F0"] = dict(
+    multiplier=1e6, subtractor=delta_F0, symbol="$\delta f-\delta f_\mathrm{s}$"
+)
+mcmc.unit_dictionary["delta_F0"] = "$\mu$Hz/s"
+mcmc.transform_dictionary["tglitch"]["subtractor"] = tstart + dtglitch
+mcmc.transform_dictionary["tglitch"][
+    "label"
+] = "$t^\mathrm{g}-t^\mathrm{g}_\mathrm{s}$\n[d]"
 
 t1 = time.time()
 mcmc.run()
 dT = time.time() - t1
 fig_and_axes = gridcorner._get_fig_and_axes(4, 2, 0.05)
-mcmc.plot_corner(label_offset=0.25, truths=[0, 0, 0, 0],
-                 fig_and_axes=fig_and_axes, quantiles=(0.16, 0.84),
-                 hist_kwargs=dict(lw=1.5, zorder=-1),
-                 truth_color='C3')
+mcmc.plot_corner(
+    label_offset=0.25,
+    truths=[0, 0, 0, 0],
+    fig_and_axes=fig_and_axes,
+    quantiles=(0.16, 0.84),
+    hist_kwargs=dict(lw=1.5, zorder=-1),
+    truth_color="C3",
+)
 
 mcmc.print_summary()
 
-print(('Prior widths =', F0_width, F1_width))
+print(("Prior widths =", F0_width, F1_width))
 print(("Actual run time = {}".format(dT)))
diff --git a/examples/glitch_examples/semicoherent_glitch_robust_directed_grid_search_on_1_glitch.py b/examples/glitch_examples/semicoherent_glitch_robust_directed_grid_search_on_1_glitch.py
index f51b73f2dcc267ac8a21aa70b95c187d8d4c6e8a..0a09203a4d1d1462d4f525299d5e88c35e8d689a 100644
--- a/examples/glitch_examples/semicoherent_glitch_robust_directed_grid_search_on_1_glitch.py
+++ b/examples/glitch_examples/semicoherent_glitch_robust_directed_grid_search_on_1_glitch.py
@@ -1,7 +1,19 @@
 import pyfstat
 import numpy as np
 import matplotlib.pyplot as plt
-from make_simulated_data import tstart, duration, tref, F0, F1, F2, Alpha, Delta, delta_F0, outdir, dtglitch
+from make_simulated_data import (
+    tstart,
+    duration,
+    tref,
+    F0,
+    F1,
+    F2,
+    Alpha,
+    Delta,
+    delta_F0,
+    outdir,
+    dtglitch,
+)
 import time
 
 try:
@@ -9,34 +21,46 @@ try:
 except ImportError:
     raise ImportError(
         "Python module 'gridcorner' not found, please install from "
-        "https://gitlab.aei.uni-hannover.de/GregAshton/gridcorner")
+        "https://gitlab.aei.uni-hannover.de/GregAshton/gridcorner"
+    )
 
-label = 'semicoherent_glitch_robust_directed_grid_search_on_1_glitch'
+label = "semicoherent_glitch_robust_directed_grid_search_on_1_glitch"
 
-plt.style.use('./paper.mplstyle')
+plt.style.use("./paper.mplstyle")
 
 Nstar = 1000
-F0_width = np.sqrt(Nstar)*np.sqrt(12)/(np.pi*duration)
-F1_width = np.sqrt(Nstar)*np.sqrt(180)/(np.pi*duration**2)
+F0_width = np.sqrt(Nstar) * np.sqrt(12) / (np.pi * duration)
+F1_width = np.sqrt(Nstar) * np.sqrt(180) / (np.pi * duration ** 2)
 N = 20
-F0s = [F0-F0_width/2., F0+F0_width/2., F0_width/N]
-F1s = [F1-F1_width/2., F1+F1_width/2., F1_width/N]
+F0s = [F0 - F0_width / 2.0, F0 + F0_width / 2.0, F0_width / N]
+F1s = [F1 - F1_width / 2.0, F1 + F1_width / 2.0, F1_width / N]
 F2s = [F2]
 Alphas = [Alpha]
 Deltas = [Delta]
 
 max_delta_F0 = 1e-5
-tglitchs = [tstart+0.1*duration, tstart+0.9*duration, 0.8*float(duration)/N]
-delta_F0s = [0, max_delta_F0, max_delta_F0/N]
+tglitchs = [tstart + 0.1 * duration, tstart + 0.9 * duration, 0.8 * float(duration) / N]
+delta_F0s = [0, max_delta_F0, max_delta_F0 / N]
 delta_F1s = [0]
 
 
 t1 = time.time()
 search = pyfstat.GridGlitchSearch(
-    label, outdir, 'data/*1_glitch*sft', F0s=F0s, F1s=F1s, F2s=F2s,
-    Alphas=Alphas, Deltas=Deltas, tref=tref, minStartTime=tstart,
-    maxStartTime=tstart+duration, tglitchs=tglitchs, delta_F0s=delta_F0s,
-    delta_F1s=delta_F1s)
+    label,
+    outdir,
+    "data/*1_glitch*sft",
+    F0s=F0s,
+    F1s=F1s,
+    F2s=F2s,
+    Alphas=Alphas,
+    Deltas=Deltas,
+    tref=tref,
+    minStartTime=tstart,
+    maxStartTime=tstart + duration,
+    tglitchs=tglitchs,
+    delta_F0s=delta_F0s,
+    delta_F1s=delta_F1s,
+)
 search.run()
 dT = time.time() - t1
 
@@ -44,24 +68,32 @@ F0_vals = np.unique(search.data[:, 0]) - F0
 F1_vals = np.unique(search.data[:, 1]) - F1
 delta_F0s_vals = np.unique(search.data[:, 5]) - delta_F0
 tglitch_vals = np.unique(search.data[:, 7])
-tglitch_vals_days = (tglitch_vals-tstart) / 86400. - dtglitch/86400.
+tglitch_vals_days = (tglitch_vals - tstart) / 86400.0 - dtglitch / 86400.0
 
-twoF = search.data[:, -1].reshape((len(F0_vals), len(F1_vals),
-                                   len(delta_F0s_vals), len(tglitch_vals)))
-xyz = [F0_vals*1e6, F1_vals*1e12, delta_F0s_vals*1e6, tglitch_vals_days]
-labels = ['$f - f_\mathrm{s}$\n[$\mu$Hz]',
-          '$\dot{f} - \dot{f}_\mathrm{s}$\n[$p$Hz/s]',
-          '$\delta f-\delta f_\mathrm{s}$\n[$\mu$Hz]',
-          '$t^\mathrm{g} - t^\mathrm{g}_\mathrm{s}$\n[d]',
-          '$\widehat{2\mathcal{F}}$']
+twoF = search.data[:, -1].reshape(
+    (len(F0_vals), len(F1_vals), len(delta_F0s_vals), len(tglitch_vals))
+)
+xyz = [F0_vals * 1e6, F1_vals * 1e12, delta_F0s_vals * 1e6, tglitch_vals_days]
+labels = [
+    "$f - f_\mathrm{s}$\n[$\mu$Hz]",
+    "$\dot{f} - \dot{f}_\mathrm{s}$\n[$p$Hz/s]",
+    "$\delta f-\delta f_\mathrm{s}$\n[$\mu$Hz]",
+    "$t^\mathrm{g} - t^\mathrm{g}_\mathrm{s}$\n[d]",
+    "$\widehat{2\mathcal{F}}$",
+]
 fig, axes = gridcorner(
-    twoF, xyz, projection='log_mean', labels=labels,
-    showDvals=False, lines=[0, 0, 0, 0], label_offset=0.25,
-    max_n_ticks=4)
-fig.savefig('{}/{}_projection_matrix.png'.format(outdir, label),
-            bbox_inches='tight')
+    twoF,
+    xyz,
+    projection="log_mean",
+    labels=labels,
+    showDvals=False,
+    lines=[0, 0, 0, 0],
+    label_offset=0.25,
+    max_n_ticks=4,
+)
+fig.savefig("{}/{}_projection_matrix.png".format(outdir, label), bbox_inches="tight")
 
 
-print(('Prior widths =', F0_width, F1_width))
+print(("Prior widths =", F0_width, F1_width))
 print(("Actual run time = {}".format(dT)))
 print(("Actual number of grid points = {}".format(search.data.shape[0])))
diff --git a/examples/glitch_examples/standard_directed_MCMC_search_on_1_glitch.py b/examples/glitch_examples/standard_directed_MCMC_search_on_1_glitch.py
index 7856a59557bd9b30acbbe4af58558e59c0e5f9f6..853ccb7c05d51cbf9d84e708e9f5baf8597b1292 100644
--- a/examples/glitch_examples/standard_directed_MCMC_search_on_1_glitch.py
+++ b/examples/glitch_examples/standard_directed_MCMC_search_on_1_glitch.py
@@ -3,25 +3,21 @@ import matplotlib.pyplot as plt
 import pyfstat
 from make_simulated_data import tstart, duration, tref, F0, F1, F2, Alpha, Delta, outdir
 
-plt.style.use('paper')
+plt.style.use("paper")
 
-label = 'standard_directed_MCMC_search_on_1_glitch'
+label = "standard_directed_MCMC_search_on_1_glitch"
 
 Nstar = 10000
-F0_width = np.sqrt(Nstar)*np.sqrt(12)/(np.pi*duration)
-F1_width = np.sqrt(Nstar)*np.sqrt(180)/(np.pi*duration**2)
+F0_width = np.sqrt(Nstar) * np.sqrt(12) / (np.pi * duration)
+F1_width = np.sqrt(Nstar) * np.sqrt(180) / (np.pi * duration ** 2)
 
 theta_prior = {
-    'F0': {'type': 'unif',
-           'lower': F0-F0_width/2.,
-           'upper': F0+F0_width/2.},
-    'F1': {'type': 'unif',
-           'lower': F1-F1_width/2.,
-           'upper': F1+F1_width/2.},
-    'F2': F2,
-    'Alpha': Alpha,
-    'Delta': Delta,
-    }
+    "F0": {"type": "unif", "lower": F0 - F0_width / 2.0, "upper": F0 + F0_width / 2.0},
+    "F1": {"type": "unif", "lower": F1 - F1_width / 2.0, "upper": F1 + F1_width / 2.0},
+    "F2": F2,
+    "Alpha": Alpha,
+    "Delta": Delta,
+}
 
 ntemps = 2
 log10beta_min = -0.5
@@ -29,15 +25,22 @@ nwalkers = 100
 nsteps = [500, 2000]
 
 mcmc = pyfstat.MCMCSearch(
-    label=label, sftfilepattern='data/*1_glitch*sft', theta_prior=theta_prior,
-    tref=tref, minStartTime=tstart, maxStartTime=tstart+duration,
-    nsteps=nsteps, nwalkers=nwalkers, ntemps=ntemps,
-    log10beta_min=log10beta_min)
-
-mcmc.transform_dictionary['F0'] = dict(
-    subtractor=F0, symbol='$f-f^\mathrm{s}$')
-mcmc.transform_dictionary['F1'] = dict(
-    subtractor=F1, symbol='$\dot{f}-\dot{f}^\mathrm{s}$')
+    label=label,
+    sftfilepattern="data/*1_glitch*sft",
+    theta_prior=theta_prior,
+    tref=tref,
+    minStartTime=tstart,
+    maxStartTime=tstart + duration,
+    nsteps=nsteps,
+    nwalkers=nwalkers,
+    ntemps=ntemps,
+    log10beta_min=log10beta_min,
+)
+
+mcmc.transform_dictionary["F0"] = dict(subtractor=F0, symbol="$f-f^\mathrm{s}$")
+mcmc.transform_dictionary["F1"] = dict(
+    subtractor=F1, symbol="$\dot{f}-\dot{f}^\mathrm{s}$"
+)
 
 mcmc.run()
 mcmc.plot_corner()
diff --git a/examples/grid_examples/grid_F0F1F2.py b/examples/grid_examples/grid_F0F1F2.py
index 2de31067e42c890912bfa179fa98ade1490105da..17c97ccdf28d9f4b1f84da46f9009b320c517178 100644
--- a/examples/grid_examples/grid_F0F1F2.py
+++ b/examples/grid_examples/grid_F0F1F2.py
@@ -7,7 +7,8 @@ try:
 except ImportError:
     raise ImportError(
         "Python module 'gridcorner' not found, please install from "
-        "https://gitlab.aei.uni-hannover.de/GregAshton/gridcorner")
+        "https://gitlab.aei.uni-hannover.de/GregAshton/gridcorner"
+    )
 
 F0 = 30.0
 F1 = 1e-10
@@ -18,38 +19,58 @@ Delta = 1.5
 # Properties of the GW data
 sqrtSX = 1e-23
 tstart = 1000000000
-duration = 10*86400
-tend = tstart+duration
-tref = .5*(tstart+tend)
+duration = 10 * 86400
+tend = tstart + duration
+tref = 0.5 * (tstart + tend)
 
 depth = 20
-label = 'grid_F0F1F2'
-outdir = 'data'
+label = "grid_F0F1F2"
+outdir = "data"
 
 h0 = sqrtSX / depth
 
 data = pyfstat.Writer(
-    label=label, outdir=outdir, tref=tref,
-    tstart=tstart, F0=F0, F1=F1, F2=F2, duration=duration, Alpha=Alpha,
-    Delta=Delta, h0=h0, sqrtSX=sqrtSX)
+    label=label,
+    outdir=outdir,
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+)
 data.make_data()
 
 m = 0.01
-dF0 = np.sqrt(12*m)/(np.pi*duration)
-dF1 = np.sqrt(180*m)/(np.pi*duration**2)
+dF0 = np.sqrt(12 * m) / (np.pi * duration)
+dF1 = np.sqrt(180 * m) / (np.pi * duration ** 2)
 dF2 = 1e-17
 N = 100
-DeltaF0 = N*dF0
-DeltaF1 = N*dF1
-DeltaF2 = N*dF2
-F0s = [F0-DeltaF0/2., F0+DeltaF0/2., dF0]
-F1s = [F1-DeltaF1/2., F1+DeltaF1/2., dF1]
-F2s = [F2-DeltaF2/2., F2+DeltaF2/2., dF2]
+DeltaF0 = N * dF0
+DeltaF1 = N * dF1
+DeltaF2 = N * dF2
+F0s = [F0 - DeltaF0 / 2.0, F0 + DeltaF0 / 2.0, dF0]
+F1s = [F1 - DeltaF1 / 2.0, F1 + DeltaF1 / 2.0, dF1]
+F2s = [F2 - DeltaF2 / 2.0, F2 + DeltaF2 / 2.0, dF2]
 Alphas = [Alpha]
 Deltas = [Delta]
 search = pyfstat.GridSearch(
-    'grid_F0F1F2', 'data', data.sftfilepath, F0s, F1s,
-    F2s, Alphas, Deltas, tref, tstart, tend)
+    "grid_F0F1F2",
+    "data",
+    data.sftfilepath,
+    F0s,
+    F1s,
+    F2s,
+    Alphas,
+    Deltas,
+    tref,
+    tstart,
+    tend,
+)
 search.run()
 
 F0_vals = np.unique(search.data[:, 2]) - F0
@@ -57,8 +78,13 @@ F1_vals = np.unique(search.data[:, 3]) - F1
 F2_vals = np.unique(search.data[:, 4]) - F2
 twoF = search.data[:, -1].reshape((len(F0_vals), len(F1_vals), len(F2_vals)))
 xyz = [F0_vals, F1_vals, F2_vals]
-labels = ['$f - f_0$', '$\dot{f} - \dot{f}_0$', '$\ddot{f} - \ddot{f}_0$',
-          '$\widetilde{2\mathcal{F}}$']
+labels = [
+    "$f - f_0$",
+    "$\dot{f} - \dot{f}_0$",
+    "$\ddot{f} - \ddot{f}_0$",
+    "$\widetilde{2\mathcal{F}}$",
+]
 fig, axes = gridcorner(
-    twoF, xyz, projection='log_mean', labels=labels, whspace=0.1, factor=1.8)
-fig.savefig('{}/{}_projection_matrix.png'.format(outdir, label))
+    twoF, xyz, projection="log_mean", labels=labels, whspace=0.1, factor=1.8
+)
+fig.savefig("{}/{}_projection_matrix.png".format(outdir, label))
diff --git a/examples/grid_examples/grided_frequency_search.py b/examples/grid_examples/grided_frequency_search.py
index e8f5d55f1a362c79bb631d2ea2b5dd428d3550f7..306d443b6b2b4efb15f6d0f76c186a8419b721ff 100644
--- a/examples/grid_examples/grided_frequency_search.py
+++ b/examples/grid_examples/grided_frequency_search.py
@@ -11,52 +11,72 @@ Delta = 1.5
 # Properties of the GW data
 sqrtSX = 1e-23
 tstart = 1000000000
-duration = 100*86400
-tend = tstart+duration
-tref = .5*(tstart+tend)
+duration = 100 * 86400
+tend = tstart + duration
+tref = 0.5 * (tstart + tend)
 
 depth = 70
-data_label = 'grided_frequency_depth_{:1.0f}'.format(depth)
+data_label = "grided_frequency_depth_{:1.0f}".format(depth)
 
 h0 = sqrtSX / depth
 
 data = pyfstat.Writer(
-    label=data_label, outdir='data', tref=tref,
-    tstart=tstart, F0=F0, F1=F1, F2=F2, duration=duration, Alpha=Alpha,
-    Delta=Delta, h0=h0, sqrtSX=sqrtSX)
+    label=data_label,
+    outdir="data",
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+)
 data.make_data()
 
 m = 0.001
-dF0 = np.sqrt(12*m)/(np.pi*duration)
-DeltaF0 = 800*dF0
-F0s = [F0-DeltaF0/2., F0+DeltaF0/2., dF0]
+dF0 = np.sqrt(12 * m) / (np.pi * duration)
+DeltaF0 = 800 * dF0
+F0s = [F0 - DeltaF0 / 2.0, F0 + DeltaF0 / 2.0, dF0]
 F1s = [F1]
 F2s = [F2]
 Alphas = [Alpha]
 Deltas = [Delta]
 search = pyfstat.GridSearch(
-    'grided_frequency_search', 'data', 'data/*'+data_label+'*sft', F0s, F1s,
-    F2s, Alphas, Deltas, tref, tstart, tend)
+    "grided_frequency_search",
+    "data",
+    "data/*" + data_label + "*sft",
+    F0s,
+    F1s,
+    F2s,
+    Alphas,
+    Deltas,
+    tref,
+    tstart,
+    tend,
+)
 search.run()
 
 fig, ax = plt.subplots()
-xidx = search.keys.index('F0')
+xidx = search.keys.index("F0")
 frequencies = np.unique(search.data[:, xidx])
 twoF = search.data[:, -1]
 
-#mismatch = np.sign(x-F0)*(duration * np.pi * (x - F0))**2 / 12.0
-ax.plot(frequencies, twoF, 'k', lw=1)
+# mismatch = np.sign(x-F0)*(duration * np.pi * (x - F0))**2 / 12.0
+ax.plot(frequencies, twoF, "k", lw=1)
 DeltaF = frequencies - F0
-sinc = np.sin(np.pi*DeltaF*duration)/(np.pi*DeltaF*duration)
-A = np.abs((np.max(twoF)-4)*sinc**2 + 4)
-ax.plot(frequencies, A, '-r', lw=1)
-ax.set_ylabel('$\widetilde{2\mathcal{F}}$')
-ax.set_xlabel('Frequency')
+sinc = np.sin(np.pi * DeltaF * duration) / (np.pi * DeltaF * duration)
+A = np.abs((np.max(twoF) - 4) * sinc ** 2 + 4)
+ax.plot(frequencies, A, "-r", lw=1)
+ax.set_ylabel("$\widetilde{2\mathcal{F}}$")
+ax.set_xlabel("Frequency")
 ax.set_xlim(F0s[0], F0s[1])
-dF0 = np.sqrt(12*1)/(np.pi*duration)
-xticks = [F0-10*dF0, F0, F0+10*dF0]
+dF0 = np.sqrt(12 * 1) / (np.pi * duration)
+xticks = [F0 - 10 * dF0, F0, F0 + 10 * dF0]
 ax.set_xticks(xticks)
-xticklabels = ['$f_0 {-} 10\Delta f$', '$f_0$', '$f_0 {+} 10\Delta f$']
+xticklabels = ["$f_0 {-} 10\Delta f$", "$f_0$", "$f_0 {+} 10\Delta f$"]
 ax.set_xticklabels(xticklabels)
 plt.tight_layout()
-fig.savefig('{}/{}_1D.png'.format(search.outdir, search.label), dpi=300)
+fig.savefig("{}/{}_1D.png".format(search.outdir, search.label), dpi=300)
diff --git a/examples/other_examples/sliding_window.py b/examples/other_examples/sliding_window.py
index 96ea6745b4ca5d8249b7ecedc4ad9955bfffbd12..9a2e4b0af8d6a79686ce9a5dd52d5685c202031d 100644
--- a/examples/other_examples/sliding_window.py
+++ b/examples/other_examples/sliding_window.py
@@ -4,7 +4,7 @@ import numpy as np
 # Properties of the GW data
 sqrtSX = 1e-23
 tstart = 1000000000
-duration = 100*86400
+duration = 100 * 86400
 tend = tstart + duration
 
 # Properties of the signal
@@ -13,23 +13,43 @@ F1 = -1e-10
 F2 = 0
 Alpha = np.radians(83.6292)
 Delta = np.radians(22.0144)
-tref = .5*(tstart+tend)
+tref = 0.5 * (tstart + tend)
 
 depth = 60
 h0 = sqrtSX / depth
-data_label = 'sliding_window'
+data_label = "sliding_window"
 
 data = pyfstat.Writer(
-    label=data_label, outdir='data', tref=tref,
-    tstart=tstart, F0=F0, F1=F1, F2=F2, duration=duration, Alpha=Alpha,
-    Delta=Delta, h0=h0, sqrtSX=sqrtSX)
+    label=data_label,
+    outdir="data",
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+)
 data.make_data()
 
 DeltaF0 = 1e-5
 search = pyfstat.FrequencySlidingWindow(
-        label='sliding_window', outdir='data', sftfilepattern='data/*sliding_window*sft',
-        F0s=[F0-DeltaF0, F0+DeltaF0, DeltaF0/100.], F1=F1, F2=0,
-        Alpha=Alpha, Delta=Delta, tref=tref, minStartTime=tstart,
-        maxStartTime=tend, window_size=25*86400, window_delta=1*86400)
+    label="sliding_window",
+    outdir="data",
+    sftfilepattern="data/*sliding_window*sft",
+    F0s=[F0 - DeltaF0, F0 + DeltaF0, DeltaF0 / 100.0],
+    F1=F1,
+    F2=0,
+    Alpha=Alpha,
+    Delta=Delta,
+    tref=tref,
+    minStartTime=tstart,
+    maxStartTime=tend,
+    window_size=25 * 86400,
+    window_delta=1 * 86400,
+)
 search.run()
 search.plot_sliding_window()
diff --git a/examples/other_examples/twoF_cumulative.py b/examples/other_examples/twoF_cumulative.py
index 3b14e49b5f3993bbb0b17a26f78206c0206eae7d..1065d3e087be6c0bc59eff297461c3f85d3bad24 100644
--- a/examples/other_examples/twoF_cumulative.py
+++ b/examples/other_examples/twoF_cumulative.py
@@ -4,7 +4,7 @@ import numpy as np
 # Properties of the GW data
 sqrtSX = 1e-23
 tstart = 1000000000
-duration = 100*86400
+duration = 100 * 86400
 tend = tstart + duration
 
 # Properties of the signal
@@ -13,38 +13,46 @@ F1 = -1e-10
 F2 = 0
 Alpha = np.radians(83.6292)
 Delta = np.radians(22.0144)
-tref = .5*(tstart+tend)
+tref = 0.5 * (tstart + tend)
 
 depth = 100
 h0 = sqrtSX / depth
-data_label = 'twoF_cumulative'
+data_label = "twoF_cumulative"
 
 data = pyfstat.Writer(
-    label=data_label, outdir='data', tref=tref,
-    tstart=tstart, F0=F0, F1=F1, F2=F2, duration=duration, Alpha=Alpha,
-    Delta=Delta, h0=h0, sqrtSX=sqrtSX, detectors='H1,L1')
+    label=data_label,
+    outdir="data",
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+    detectors="H1,L1",
+)
 data.make_data()
 
 # The predicted twoF, given by lalapps_predictFstat can be accessed by
 twoF = data.predict_fstat()
-print('Predicted twoF value: {}\n'.format(twoF))
+print("Predicted twoF value: {}\n".format(twoF))
 
 DeltaF0 = 1e-7
 DeltaF1 = 1e-13
-VF0 = (np.pi * duration * DeltaF0)**2 / 3.0
-VF1 = (np.pi * duration**2 * DeltaF1)**2 * 4/45.
-print('\nV={:1.2e}, VF0={:1.2e}, VF1={:1.2e}\n'.format(VF0*VF1, VF0, VF1))
+VF0 = (np.pi * duration * DeltaF0) ** 2 / 3.0
+VF1 = (np.pi * duration ** 2 * DeltaF1) ** 2 * 4 / 45.0
+print("\nV={:1.2e}, VF0={:1.2e}, VF1={:1.2e}\n".format(VF0 * VF1, VF0, VF1))
 
-theta_prior = {'F0': {'type': 'unif',
-                      'lower': F0-DeltaF0/2.,
-                      'upper': F0+DeltaF0/2.},
-               'F1': {'type': 'unif',
-                      'lower': F1-DeltaF1/2.,
-                      'upper': F1+DeltaF1/2.},
-               'F2': F2,
-               'Alpha': Alpha,
-               'Delta': Delta
-               }
+theta_prior = {
+    "F0": {"type": "unif", "lower": F0 - DeltaF0 / 2.0, "upper": F0 + DeltaF0 / 2.0},
+    "F1": {"type": "unif", "lower": F1 - DeltaF1 / 2.0, "upper": F1 + DeltaF1 / 2.0},
+    "F2": F2,
+    "Alpha": Alpha,
+    "Delta": Delta,
+}
 
 ntemps = 1
 log10beta_min = -1
@@ -52,10 +60,18 @@ nwalkers = 100
 nsteps = [50, 50]
 
 mcmc = pyfstat.MCMCSearch(
-    label='twoF_cumulative', outdir='data',
-    sftfilepattern='data/*'+data_label+'*sft', theta_prior=theta_prior, tref=tref,
-    minStartTime=tstart, maxStartTime=tend, nsteps=nsteps, nwalkers=nwalkers,
-    ntemps=ntemps, log10beta_min=log10beta_min)
+    label="twoF_cumulative",
+    outdir="data",
+    sftfilepattern="data/*" + data_label + "*sft",
+    theta_prior=theta_prior,
+    tref=tref,
+    minStartTime=tstart,
+    maxStartTime=tend,
+    nsteps=nsteps,
+    nwalkers=nwalkers,
+    ntemps=ntemps,
+    log10beta_min=log10beta_min,
+)
 mcmc.run()
 mcmc.plot_corner(add_prior=True)
 mcmc.print_summary()
diff --git a/examples/other_examples/using_initialisation.py b/examples/other_examples/using_initialisation.py
index c8d8d61398ab4fea08b668315422647e061679e5..105380d45b7bf6e988b65d4882be467380ff4239 100644
--- a/examples/other_examples/using_initialisation.py
+++ b/examples/other_examples/using_initialisation.py
@@ -4,7 +4,7 @@ import numpy as np
 # Properties of the GW data
 sqrtSX = 1e-23
 tstart = 1000000000
-duration = 100*86400
+duration = 100 * 86400
 tend = tstart + duration
 
 # Properties of the signal
@@ -13,39 +13,46 @@ F1 = -1e-10
 F2 = 0
 Alpha = np.radians(83.6292)
 Delta = np.radians(22.0144)
-tref = .5*(tstart+tend)
+tref = 0.5 * (tstart + tend)
 
 depth = 10
 h0 = sqrtSX / depth
-label = 'using_initialisation'
-outdir = 'data'
+label = "using_initialisation"
+outdir = "data"
 
 data = pyfstat.Writer(
-    label=label, outdir=outdir, tref=tref,
-    tstart=tstart, F0=F0, F1=F1, F2=F2, duration=duration, Alpha=Alpha,
-    Delta=Delta, h0=h0, sqrtSX=sqrtSX)
+    label=label,
+    outdir=outdir,
+    tref=tref,
+    tstart=tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+)
 data.make_data()
 
 # The predicted twoF, given by lalapps_predictFstat can be accessed by
 twoF = data.predict_fstat()
-print('Predicted twoF value: {}\n'.format(twoF))
+print("Predicted twoF value: {}\n".format(twoF))
 
 DeltaF0 = 1e-7
 DeltaF1 = 1e-13
-VF0 = (np.pi * duration * DeltaF0)**2 / 3.0
-VF1 = (np.pi * duration**2 * DeltaF1)**2 * 4/45.
-print('\nV={:1.2e}, VF0={:1.2e}, VF1={:1.2e}\n'.format(VF0*VF1, VF0, VF1))
+VF0 = (np.pi * duration * DeltaF0) ** 2 / 3.0
+VF1 = (np.pi * duration ** 2 * DeltaF1) ** 2 * 4 / 45.0
+print("\nV={:1.2e}, VF0={:1.2e}, VF1={:1.2e}\n".format(VF0 * VF1, VF0, VF1))
 
-theta_prior = {'F0': {'type': 'unif',
-                      'lower': F0-DeltaF0/2.,
-                      'upper': F0+DeltaF0/2.},
-               'F1': {'type': 'unif',
-                      'lower': F1-DeltaF1/2.,
-                      'upper': F1+DeltaF1/2.},
-               'F2': F2,
-               'Alpha': Alpha,
-               'Delta': Delta
-               }
+theta_prior = {
+    "F0": {"type": "unif", "lower": F0 - DeltaF0 / 2.0, "upper": F0 + DeltaF0 / 2.0},
+    "F1": {"type": "unif", "lower": F1 - DeltaF1 / 2.0, "upper": F1 + DeltaF1 / 2.0},
+    "F2": F2,
+    "Alpha": Alpha,
+    "Delta": Delta,
+}
 
 ntemps = 1
 log10beta_min = -1
@@ -53,11 +60,18 @@ nwalkers = 100
 nsteps = [100, 100]
 
 mcmc = pyfstat.MCMCSearch(
-    label=label, outdir=outdir,
-    sftfilepattern='{}/*{}*sft'.format(outdir, label),
-    theta_prior=theta_prior, tref=tref, minStartTime=tstart, maxStartTime=tend,
-    nsteps=nsteps, nwalkers=nwalkers, ntemps=ntemps,
-    log10beta_min=log10beta_min)
+    label=label,
+    outdir=outdir,
+    sftfilepattern="{}/*{}*sft".format(outdir, label),
+    theta_prior=theta_prior,
+    tref=tref,
+    minStartTime=tstart,
+    maxStartTime=tend,
+    nsteps=nsteps,
+    nwalkers=nwalkers,
+    ntemps=ntemps,
+    log10beta_min=log10beta_min,
+)
 mcmc.setup_initialisation(100, scatter_val=1e-10)
 mcmc.run()
 mcmc.plot_corner(add_prior=True)
diff --git a/examples/transient_examples/long_transient_search_MCMC.py b/examples/transient_examples/long_transient_search_MCMC.py
index a9255fa63b7b008f70468e86edff0364e11de60a..38731e8e8e9ca1b6b70cb9b93857148e42e6ffd7 100644
--- a/examples/transient_examples/long_transient_search_MCMC.py
+++ b/examples/transient_examples/long_transient_search_MCMC.py
@@ -9,27 +9,26 @@ Alpha = 0.5
 Delta = 1
 
 minStartTime = 1000000000
-maxStartTime = minStartTime + 200*86400
+maxStartTime = minStartTime + 200 * 86400
 Tspan = maxStartTime - minStartTime
 tref = minStartTime
 
 DeltaF0 = 6e-7
 DeltaF1 = 1e-13
 
-theta_prior = {'F0': {'type': 'unif',
-                      'lower': F0-DeltaF0/2.,
-                      'upper': F0+DeltaF0/2.},
-               'F1': {'type': 'unif',
-                      'lower': F1-DeltaF1/2.,
-                      'upper': F1+DeltaF1/2.},
-               'F2': F2,
-               'Alpha': Alpha,
-               'Delta': Delta,
-               'transient_tstart': minStartTime,
-               'transient_duration': {'type': 'halfnorm',
-                                      'loc': 0.001*Tspan,
-                                      'scale': 0.5*Tspan}
-               }
+theta_prior = {
+    "F0": {"type": "unif", "lower": F0 - DeltaF0 / 2.0, "upper": F0 + DeltaF0 / 2.0},
+    "F1": {"type": "unif", "lower": F1 - DeltaF1 / 2.0, "upper": F1 + DeltaF1 / 2.0},
+    "F2": F2,
+    "Alpha": Alpha,
+    "Delta": Delta,
+    "transient_tstart": minStartTime,
+    "transient_duration": {
+        "type": "halfnorm",
+        "loc": 0.001 * Tspan,
+        "scale": 0.5 * Tspan,
+    },
+}
 
 ntemps = 2
 log10beta_min = -1
@@ -37,12 +36,19 @@ nwalkers = 100
 nsteps = [100, 100]
 
 mcmc = pyfstat.MCMCTransientSearch(
-    label='transient_search', outdir='data_l',
-    sftfilepattern='data_l/*simulated_transient_signal*sft',
-    theta_prior=theta_prior, tref=tref, minStartTime=minStartTime,
-    maxStartTime=maxStartTime, nsteps=nsteps, nwalkers=nwalkers, ntemps=ntemps,
+    label="transient_search",
+    outdir="data_l",
+    sftfilepattern="data_l/*simulated_transient_signal*sft",
+    theta_prior=theta_prior,
+    tref=tref,
+    minStartTime=minStartTime,
+    maxStartTime=maxStartTime,
+    nsteps=nsteps,
+    nwalkers=nwalkers,
+    ntemps=ntemps,
     log10beta_min=log10beta_min,
-    transientWindowType='rect')
+    transientWindowType="rect",
+)
 mcmc.run()
 mcmc.plot_corner(label_offset=0.7)
 mcmc.print_summary()
diff --git a/examples/transient_examples/long_transient_search_make_simulated_data.py b/examples/transient_examples/long_transient_search_make_simulated_data.py
index de68fccdb40e16063bd870c0b69829b667dc49c3..d86ef73a0b97301719e089fb9ff6933b0dc7e816 100644
--- a/examples/transient_examples/long_transient_search_make_simulated_data.py
+++ b/examples/transient_examples/long_transient_search_make_simulated_data.py
@@ -9,18 +9,30 @@ Alpha = 0.5
 Delta = 1
 
 minStartTime = 1000000000
-maxStartTime = minStartTime + 200*86400
+maxStartTime = minStartTime + 200 * 86400
 
-transient_tstart = minStartTime + 50*86400
-transient_duration = 100*86400
+transient_tstart = minStartTime + 50 * 86400
+transient_duration = 100 * 86400
 tref = minStartTime
 
 h0 = 1e-23
 sqrtSX = 1e-22
 
 transient = pyfstat.Writer(
-    label='simulated_transient_signal', outdir='data_l', tref=tref,
-    tstart=transient_tstart, F0=F0, F1=F1, F2=F2, duration=transient_duration,
-    Alpha=Alpha, Delta=Delta, h0=h0, sqrtSX=sqrtSX, minStartTime=minStartTime,
-    maxStartTime=maxStartTime, transientWindowType='rect')
+    label="simulated_transient_signal",
+    outdir="data_l",
+    tref=tref,
+    tstart=transient_tstart,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    duration=transient_duration,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    sqrtSX=sqrtSX,
+    minStartTime=minStartTime,
+    maxStartTime=maxStartTime,
+    transientWindowType="rect",
+)
 transient.make_data()
diff --git a/examples/transient_examples/short_transient_search_MCMC.py b/examples/transient_examples/short_transient_search_MCMC.py
index 1b8f504a074f20372c9142d6d220f0650b0e498b..6d7fcab465d1a08895b89698db518d6ba3547df3 100644
--- a/examples/transient_examples/short_transient_search_MCMC.py
+++ b/examples/transient_examples/short_transient_search_MCMC.py
@@ -9,7 +9,7 @@ Alpha = 0.5
 Delta = 1
 
 minStartTime = 1000000000
-maxStartTime = minStartTime + 2*86400
+maxStartTime = minStartTime + 2 * 86400
 Tspan = maxStartTime - minStartTime
 tref = minStartTime
 
@@ -18,22 +18,23 @@ Tsft = 1800
 DeltaF0 = 1e-2
 DeltaF1 = 1e-9
 
-theta_prior = {'F0': {'type': 'unif',
-                      'lower': F0-DeltaF0/2.,
-                      'upper': F0+DeltaF0/2.},
-               'F1': {'type': 'unif',
-                      'lower': F1-DeltaF1/2.,
-                      'upper': F1+DeltaF1/2.},
-               'F2': F2,
-               'Alpha': Alpha,
-               'Delta': Delta,
-               'transient_tstart': {'type': 'unif',
-                                    'lower': minStartTime,
-                                    'upper': maxStartTime-2*Tsft},
-               'transient_duration': {'type': 'unif',
-                                      'lower': 2*Tsft,
-                                      'upper': Tspan-2*Tsft}
-               }
+theta_prior = {
+    "F0": {"type": "unif", "lower": F0 - DeltaF0 / 2.0, "upper": F0 + DeltaF0 / 2.0},
+    "F1": {"type": "unif", "lower": F1 - DeltaF1 / 2.0, "upper": F1 + DeltaF1 / 2.0},
+    "F2": F2,
+    "Alpha": Alpha,
+    "Delta": Delta,
+    "transient_tstart": {
+        "type": "unif",
+        "lower": minStartTime,
+        "upper": maxStartTime - 2 * Tsft,
+    },
+    "transient_duration": {
+        "type": "unif",
+        "lower": 2 * Tsft,
+        "upper": Tspan - 2 * Tsft,
+    },
+}
 
 ntemps = 2
 log10beta_min = -1
@@ -41,12 +42,19 @@ nwalkers = 100
 nsteps = [100, 100]
 
 mcmc = pyfstat.MCMCTransientSearch(
-    label='transient_search', outdir='data_s',
-    sftfilepattern='data_s/*simulated_transient_signal*sft',
-    theta_prior=theta_prior, tref=tref, minStartTime=minStartTime,
-    maxStartTime=maxStartTime, nsteps=nsteps, nwalkers=nwalkers, ntemps=ntemps,
+    label="transient_search",
+    outdir="data_s",
+    sftfilepattern="data_s/*simulated_transient_signal*sft",
+    theta_prior=theta_prior,
+    tref=tref,
+    minStartTime=minStartTime,
+    maxStartTime=maxStartTime,
+    nsteps=nsteps,
+    nwalkers=nwalkers,
+    ntemps=ntemps,
     log10beta_min=log10beta_min,
-    transientWindowType='rect')
+    transientWindowType="rect",
+)
 mcmc.run()
 mcmc.plot_corner(label_offset=0.7)
 mcmc.print_summary()
diff --git a/examples/transient_examples/short_transient_search_gridded.py b/examples/transient_examples/short_transient_search_gridded.py
index 5011b8052ef325bc7b3590195b9112a335b1cdaf..6e3c888caed26d6b890b562288606d7ccebe0495 100644
--- a/examples/transient_examples/short_transient_search_gridded.py
+++ b/examples/transient_examples/short_transient_search_gridded.py
@@ -5,7 +5,7 @@ import os
 import numpy as np
 import matplotlib.pyplot as plt
 
-datadir = 'data_s'
+datadir = "data_s"
 
 F0 = 30.0
 F1 = -1e-10
@@ -14,46 +14,62 @@ Alpha = 0.5
 Delta = 1
 
 minStartTime = 1000000000
-maxStartTime = minStartTime + 2*86400
+maxStartTime = minStartTime + 2 * 86400
 Tspan = maxStartTime - minStartTime
 tref = minStartTime
 
 Tsft = 1800
 
 m = 0.001
-dF0 = np.sqrt(12*m)/(np.pi*Tspan)
-DeltaF0 = 100*dF0
-F0s = [F0-DeltaF0/2., F0+DeltaF0/2., dF0]
+dF0 = np.sqrt(12 * m) / (np.pi * Tspan)
+DeltaF0 = 100 * dF0
+F0s = [F0 - DeltaF0 / 2.0, F0 + DeltaF0 / 2.0, dF0]
 F1s = [F1]
 F2s = [F2]
 Alphas = [Alpha]
 Deltas = [Delta]
 
-print('Standard CW search:')
+print("Standard CW search:")
 search1 = pyfstat.GridSearch(
-    label='CW', outdir=datadir,
-    sftfilepattern=os.path.join(datadir,'*simulated_transient_signal*sft'),
-    F0s=F0s, F1s=F1s, F2s=F2s, Alphas=Alphas, Deltas=Deltas, tref=tref,
-    minStartTime=minStartTime, maxStartTime=maxStartTime,
-    BSGL=False)
+    label="CW",
+    outdir=datadir,
+    sftfilepattern=os.path.join(datadir, "*simulated_transient_signal*sft"),
+    F0s=F0s,
+    F1s=F1s,
+    F2s=F2s,
+    Alphas=Alphas,
+    Deltas=Deltas,
+    tref=tref,
+    minStartTime=minStartTime,
+    maxStartTime=maxStartTime,
+    BSGL=False,
+)
 search1.run()
 search1.print_max_twoF()
 
-search1.plot_1D(xkey='F0',
-               xlabel='freq [Hz]', ylabel='$2\mathcal{F}$')
+search1.plot_1D(xkey="F0", xlabel="freq [Hz]", ylabel="$2\mathcal{F}$")
 
-print('with t0,tau bands:')
+print("with t0,tau bands:")
 search2 = pyfstat.TransientGridSearch(
-    label='tCW', outdir=datadir,
-    sftfilepattern=os.path.join(datadir,'*simulated_transient_signal*sft'),
-    F0s=F0s, F1s=F1s, F2s=F2s, Alphas=Alphas, Deltas=Deltas, tref=tref,
-    minStartTime=minStartTime, maxStartTime=maxStartTime,
-    transientWindowType='rect', t0Band=Tspan-2*Tsft, tauBand=Tspan,
+    label="tCW",
+    outdir=datadir,
+    sftfilepattern=os.path.join(datadir, "*simulated_transient_signal*sft"),
+    F0s=F0s,
+    F1s=F1s,
+    F2s=F2s,
+    Alphas=Alphas,
+    Deltas=Deltas,
+    tref=tref,
+    minStartTime=minStartTime,
+    maxStartTime=maxStartTime,
+    transientWindowType="rect",
+    t0Band=Tspan - 2 * Tsft,
+    tauBand=Tspan,
     BSGL=False,
     outputTransientFstatMap=True,
-    tCWFstatMapVersion='lal')
+    tCWFstatMapVersion="lal",
+)
 search2.run()
 search2.print_max_twoF()
 
-search2.plot_1D(xkey='F0',
-               xlabel='freq [Hz]', ylabel='$2\mathcal{F}$')
+search2.plot_1D(xkey="F0", xlabel="freq [Hz]", ylabel="$2\mathcal{F}$")
diff --git a/examples/transient_examples/short_transient_search_make_simulated_data.py b/examples/transient_examples/short_transient_search_make_simulated_data.py
index 320b4e4eb07dd1fce589156f0e1cfa96ea0b78a6..19845b2b54c995b4e4a7be38ea2560c0e29a2cee 100644
--- a/examples/transient_examples/short_transient_search_make_simulated_data.py
+++ b/examples/transient_examples/short_transient_search_make_simulated_data.py
@@ -9,23 +9,35 @@ Alpha = 0.5
 Delta = 1
 
 minStartTime = 1000000000
-maxStartTime = minStartTime + 2*86400
+maxStartTime = minStartTime + 2 * 86400
 
-transient_tstart = minStartTime + 0.5*86400
-transient_duration = 1*86400
+transient_tstart = minStartTime + 0.5 * 86400
+transient_duration = 1 * 86400
 tref = minStartTime
 
 h0 = 1e-23
 sqrtSX = 1e-22
-detectors = 'H1,L1'
+detectors = "H1,L1"
 
 Tsft = 1800
 
 transient = pyfstat.Writer(
-    label='simulated_transient_signal', outdir='data_s',
-    tref=tref, tstart=transient_tstart, duration=transient_duration,
-    F0=F0, F1=F1, F2=F2, Alpha=Alpha, Delta=Delta, h0=h0,
-    detectors=detectors,sqrtSX=sqrtSX,
-    minStartTime=minStartTime, maxStartTime=maxStartTime,
-    transientWindowType='rect', Tsft=Tsft)
+    label="simulated_transient_signal",
+    outdir="data_s",
+    tref=tref,
+    tstart=transient_tstart,
+    duration=transient_duration,
+    F0=F0,
+    F1=F1,
+    F2=F2,
+    Alpha=Alpha,
+    Delta=Delta,
+    h0=h0,
+    detectors=detectors,
+    sqrtSX=sqrtSX,
+    minStartTime=minStartTime,
+    maxStartTime=maxStartTime,
+    transientWindowType="rect",
+    Tsft=Tsft,
+)
 transient.make_data()
diff --git a/pyfstat/__init__.py b/pyfstat/__init__.py
index 9c589d6c2a1c75bdb106f69fc5b069593455fb10..e81fdb0ed13875a1d73e9e1bf38413692d128e3c 100644
--- a/pyfstat/__init__.py
+++ b/pyfstat/__init__.py
@@ -1,6 +1,28 @@
-
-
-from .core import BaseSearchClass, ComputeFstat, SemiCoherentSearch, SemiCoherentGlitchSearch
-from .make_sfts import Writer, GlitchWriter, FrequencyModulatedArtifactWriter, FrequencyAmplitudeModulatedArtifactWriter
-from .mcmc_based_searches import MCMCSearch, MCMCGlitchSearch, MCMCSemiCoherentSearch, MCMCFollowUpSearch, MCMCTransientSearch
-from .grid_based_searches import GridSearch, GridUniformPriorSearch, GridGlitchSearch, FrequencySlidingWindow, DMoff_NO_SPIN, SliceGridSearch, TransientGridSearch
+from .core import (
+    BaseSearchClass,
+    ComputeFstat,
+    SemiCoherentSearch,
+    SemiCoherentGlitchSearch,
+)
+from .make_sfts import (
+    Writer,
+    GlitchWriter,
+    FrequencyModulatedArtifactWriter,
+    FrequencyAmplitudeModulatedArtifactWriter,
+)
+from .mcmc_based_searches import (
+    MCMCSearch,
+    MCMCGlitchSearch,
+    MCMCSemiCoherentSearch,
+    MCMCFollowUpSearch,
+    MCMCTransientSearch,
+)
+from .grid_based_searches import (
+    GridSearch,
+    GridUniformPriorSearch,
+    GridGlitchSearch,
+    FrequencySlidingWindow,
+    DMoff_NO_SPIN,
+    SliceGridSearch,
+    TransientGridSearch,
+)
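
The `__init__.py` hunk changes only the import layout (one parenthesized group per module instead of long single-line imports); the set of exported names is unchanged. A quick smoke test, assuming `pyfstat` is installed:

```python
import pyfstat

# A few of the re-exported names from the import groups above.
for name in ("ComputeFstat", "Writer", "MCMCSearch", "GridSearch"):
    assert hasattr(pyfstat, name), name
print("all re-exports present")
```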
diff --git a/pyfstat/core.py b/pyfstat/core.py
index 2b832392c8e26100d95e2f8ab9b9db7904404103..d2fb2a1e6e52369cb6de5d9cd08ff33c45681557 100755
--- a/pyfstat/core.py
+++ b/pyfstat/core.py
@@ -16,18 +16,21 @@ import pyfstat.helper_functions as helper_functions
 import pyfstat.tcw_fstat_map_funcs as tcw
 
 # workaround for matplotlib on X-less remote logins
-if 'DISPLAY' in os.environ:
+if "DISPLAY" in os.environ:
     import matplotlib.pyplot as plt
 else:
-    logging.info('No $DISPLAY environment variable found, so importing \
-                  matplotlib.pyplot with non-interactive "Agg" backend.')
+    logging.info(
+        "No $DISPLAY environment variable found, so importing "
+        'matplotlib.pyplot with non-interactive "Agg" backend.'
+    )
     import matplotlib
-    matplotlib.use('Agg')
+
+    matplotlib.use("Agg")
     import matplotlib.pyplot as plt
 
 helper_functions.set_up_matplotlib_defaults()
 args, tqdm = helper_functions.set_up_command_line_arguments()
-detector_colors = {'h1': 'C0', 'l1': 'C1'}
+detector_colors = {"h1": "C0", "l1": "C1"}
 
 
 class Bunch(object):
@@ -49,12 +52,20 @@ class Bunch(object):
     True
 
     """
+
     def __init__(self, dictionary):
         self.__dict__.update(dictionary)
 
 
-def read_par(filename=None, label=None, outdir=None, suffix='par',
-             return_type='dict', comments=['%', '#'], raise_error=False):
+def read_par(
+    filename=None,
+    label=None,
+    outdir=None,
+    suffix="par",
+    return_type="dict",
+    comments=["%", "#"],
+    raise_error=False,
+):
     """ Read in a .par or .loudest file, returns a dict or Bunch of the data
 
     Parameters
@@ -83,18 +94,18 @@ def read_par(filename=None, label=None, outdir=None, suffix='par',
 
     """
     if filename is None:
-        filename = '{}/{}.{}'.format(outdir, label, suffix)
+        filename = "{}/{}.{}".format(outdir, label, suffix)
     if os.path.isfile(filename) is False:
         raise ValueError("No file {} found".format(filename))
     d = {}
-    with open(filename, 'r') as f:
+    with open(filename, "r") as f:
         d = _get_dictionary_from_lines(f, comments, raise_error)
-    if return_type in ['bunch', 'Bunch']:
+    if return_type in ["bunch", "Bunch"]:
         return Bunch(d)
-    elif return_type in ['dict', 'dictionary']:
+    elif return_type in ["dict", "dictionary"]:
         return d
     else:
-        raise ValueError('return_type {} not understood'.format(return_type))
+        raise ValueError("return_type {} not understood".format(return_type))
 
 
 def _get_dictionary_from_lines(lines, comments, raise_error):
@@ -116,28 +127,40 @@ def _get_dictionary_from_lines(lines, comments, raise_error):
     """
     d = {}
     for line in lines:
-        if line[0] not in comments and len(line.split('=')) == 2:
+        if line[0] not in comments and len(line.split("=")) == 2:
             try:
-                key, val = line.rstrip('\n').split('=')
+                key, val = line.rstrip("\n").split("=")
                 key = key.strip()
                 val = val.strip()
                 if (val[0] in ["'", '"']) and (val[-1] in ["'", '"']):
                     d[key] = val.lstrip('"').lstrip("'").rstrip('"').rstrip("'")
                 else:
                     try:
-                        d[key] = np.float64(eval(val.rstrip('; ')))
+                        d[key] = np.float64(eval(val.rstrip("; ")))
                     except NameError:
-                        d[key] = val.rstrip('; ')
+                        d[key] = val.rstrip("; ")
             except SyntaxError:
                 if raise_error:
-                    raise IOError('Line {} not understood'.format(line))
+                    raise IOError("Line {} not understood".format(line))
                 pass
     return d
 
 
-def predict_fstat(h0, cosi, psi, Alpha, Delta, Freq, sftfilepattern,
-                  minStartTime, maxStartTime, IFOs=None, assumeSqrtSX=None,
-                  tempory_filename='fs.tmp', **kwargs):
+def predict_fstat(
+    h0,
+    cosi,
+    psi,
+    Alpha,
+    Delta,
+    Freq,
+    sftfilepattern,
+    minStartTime,
+    maxStartTime,
+    IFOs=None,
+    assumeSqrtSX=None,
+    tempory_filename="fs.tmp",
+    **kwargs
+):
     """ Wrapper to lalapps_PredictFstat
 
     Parameters
@@ -171,7 +194,7 @@ def predict_fstat(h0, cosi, psi, Alpha, Delta, Freq, sftfilepattern,
     cl_pfs.append("--DataFiles='{}'".format(sftfilepattern))
     if assumeSqrtSX:
         cl_pfs.append("--assumeSqrtSX={}".format(assumeSqrtSX))
-    #if IFOs:
+    # if IFOs:
     #    cl_pfs.append("--IFOs={}".format(IFOs))
 
     cl_pfs.append("--minStartTime={}".format(int(minStartTime)))
@@ -182,7 +205,7 @@ def predict_fstat(h0, cosi, psi, Alpha, Delta, Freq, sftfilepattern,
     helper_functions.run_commandline(cl_pfs)
     d = read_par(filename=tempory_filename)
     os.remove(tempory_filename)
-    return float(d['twoF_expected']), float(d['twoF_sigma'])
+    return float(d["twoF_expected"]), float(d["twoF_sigma"])
 
 
 class BaseSearchClass(object):
@@ -190,12 +213,14 @@ class BaseSearchClass(object):
 
     def _add_log_file(self):
         """ Log output to a file, requires class to have outdir and label """
-        logfilename = '{}/{}.log'.format(self.outdir, self.label)
+        logfilename = "{}/{}.log".format(self.outdir, self.label)
         fh = logging.FileHandler(logfilename)
         fh.setLevel(logging.INFO)
-        fh.setFormatter(logging.Formatter(
-            '%(asctime)s %(levelname)-8s: %(message)s',
-            datefmt='%y-%m-%d %H:%M'))
+        fh.setFormatter(
+            logging.Formatter(
+                "%(asctime)s %(levelname)-8s: %(message)s", datefmt="%y-%m-%d %H:%M"
+            )
+        )
         logging.getLogger().addHandler(fh)
 
     def _shift_matrix(self, n, dT):
@@ -224,9 +249,9 @@ class BaseSearchClass(object):
                     m[i, j] = 0.0
                 else:
                     if i == 0:
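+                        # row 0 carries the 2*pi factor converting the
+                        # frequency Taylor coefficients into a phase shift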
-                        m[i, j] = 2*np.pi*float(dT)**(j-i) / factorial(j-i)
+                        m[i, j] = 2 * np.pi * float(dT) ** (j - i) / factorial(j - i)
                     else:
-                        m[i, j] = float(dT)**(j-i) / factorial(j-i)
+                        m[i, j] = float(dT) ** (j - i) / factorial(j - i)
         return m
 
     def _shift_coefficients(self, theta, dT):
@@ -278,30 +303,38 @@ class BaseSearchClass(object):
         for i, dt in enumerate(delta_thetas):
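+            # propagate the spin-down coefficients across each glitch: shift
+            # them to the glitch time, apply the jump dt (subtracted before
+            # theta0_idx, added after), then shift back to the reference time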
             if i < theta0_idx:
                 pre_theta_at_ith_glitch = self._shift_coefficients(
-                    thetas[0], tbounds[i+1] - self.tref)
+                    thetas[0], tbounds[i + 1] - self.tref
+                )
                 post_theta_at_ith_glitch = pre_theta_at_ith_glitch - dt
-                thetas.insert(0, self._shift_coefficients(
-                    post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
+                thetas.insert(
+                    0,
+                    self._shift_coefficients(
+                        post_theta_at_ith_glitch, self.tref - tbounds[i + 1]
+                    ),
+                )
 
             elif i >= theta0_idx:
                 pre_theta_at_ith_glitch = self._shift_coefficients(
-                    thetas[i], tbounds[i+1] - self.tref)
+                    thetas[i], tbounds[i + 1] - self.tref
+                )
                 post_theta_at_ith_glitch = pre_theta_at_ith_glitch + dt
-                thetas.append(self._shift_coefficients(
-                    post_theta_at_ith_glitch, self.tref - tbounds[i+1]))
+                thetas.append(
+                    self._shift_coefficients(
+                        post_theta_at_ith_glitch, self.tref - tbounds[i + 1]
+                    )
+                )
         self.thetas_at_tref = thetas
         return thetas
 
     def _get_list_of_matching_sfts(self):
         """ Returns a list of sfts matching the attribute sftfilepattern """
-        sftfilepatternlist = np.atleast_1d(self.sftfilepattern.split(';'))
+        sftfilepatternlist = np.atleast_1d(self.sftfilepattern.split(";"))
         matches = [glob.glob(p) for p in sftfilepatternlist]
         matches = [item for sublist in matches for item in sublist]
         if len(matches) > 0:
             return matches
         else:
-            raise IOError('No sfts found matching {}'.format(
-                self.sftfilepattern))
+            raise IOError("No sfts found matching {}".format(self.sftfilepattern))
 
     def set_ephemeris_files(self, earth_ephem=None, sun_ephem=None):
         """ Set the ephemeris files to use for the Earth and Sun
@@ -316,8 +349,7 @@ class BaseSearchClass(object):
 
         """
 
-        earth_ephem_default, sun_ephem_default = (
-                helper_functions.get_ephemeris_files())
+        earth_ephem_default, sun_ephem_default = helper_functions.get_ephemeris_files()
 
         if earth_ephem is None:
             self.earth_ephem = earth_ephem_default
@@ -329,15 +361,30 @@ class ComputeFstat(BaseSearchClass):
     """ Base class providing interface to `lalpulsar.ComputeFstat` """
 
     @helper_functions.initializer
-    def __init__(self, tref, sftfilepattern=None, minStartTime=None,
-                 maxStartTime=None, binary=False, BSGL=False,
-                 transientWindowType=None, t0Band=None, tauBand=None,
-                 tauMin=None,
-                 dt0=None, dtau=None,
-                 detectors=None, minCoverFreq=None, maxCoverFreq=None,
-                 injectSources=None, injectSqrtSX=None, assumeSqrtSX=None,
-                 SSBprec=None,
-                 tCWFstatMapVersion='lal', cudaDeviceName=None):
+    def __init__(
+        self,
+        tref,
+        sftfilepattern=None,
+        minStartTime=None,
+        maxStartTime=None,
+        binary=False,
+        BSGL=False,
+        transientWindowType=None,
+        t0Band=None,
+        tauBand=None,
+        tauMin=None,
+        dt0=None,
+        dtau=None,
+        detectors=None,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        injectSources=None,
+        injectSqrtSX=None,
+        assumeSqrtSX=None,
+        SSBprec=None,
+        tCWFstatMapVersion="lal",
+        cudaDeviceName=None,
+    ):
         """
         Parameters
         ----------
@@ -409,71 +456,72 @@ class ComputeFstat(BaseSearchClass):
         SFTCatalog: lalpulsar.SFTCatalog
 
         """
-        if hasattr(self, 'SFTCatalog'):
+        if hasattr(self, "SFTCatalog"):
             return
         if self.sftfilepattern is None:
-            for k in ['minStartTime', 'maxStartTime', 'detectors']:
+            for k in ["minStartTime", "maxStartTime", "detectors"]:
                 if getattr(self, k) is None:
-                    raise ValueError('You must provide "{}" to injectSources'
-                                     .format(k))
-            C1 = getattr(self, 'injectSources', None) is None
-            C2 = getattr(self, 'injectSqrtSX', None) is None
+                    raise ValueError(
+                        'You must provide "{}" when sftfilepattern is None'.format(k)
+                    )
+            C1 = getattr(self, "injectSources", None) is None
+            C2 = getattr(self, "injectSqrtSX", None) is None
             if C1 and C2:
-                raise ValueError('You must specify either one of injectSources'
-                                 ' or injectSqrtSX')
+                raise ValueError(
+                    "You must specify either one of injectSources" " or injectSqrtSX"
+                )
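+            # no SFT files: build a fake SFT catalog (Tsft = 1800 s, zero
+            # overlap) from generated timestamps, so that data can be
+            # simulated on the fly via injectSources / injectSqrtSX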
             SFTCatalog = lalpulsar.SFTCatalog()
             Tsft = 1800
             Toverlap = 0
             Tspan = self.maxStartTime - self.minStartTime
-            detNames = lal.CreateStringVector(
-                *[d for d in self.detectors.split(',')])
+            detNames = lal.CreateStringVector(*[d for d in self.detectors.split(",")])
             multiTimestamps = lalpulsar.MakeMultiTimestamps(
-                self.minStartTime, Tspan, Tsft, Toverlap, detNames.length)
+                self.minStartTime, Tspan, Tsft, Toverlap, detNames.length
+            )
             SFTCatalog = lalpulsar.MultiAddToFakeSFTCatalog(
-                SFTCatalog, detNames, multiTimestamps)
+                SFTCatalog, detNames, multiTimestamps
+            )
             return SFTCatalog
 
-        logging.info('Initialising SFTCatalog')
+        logging.info("Initialising SFTCatalog")
         constraints = lalpulsar.SFTConstraints()
         if self.detectors:
-            if ',' in self.detectors:
-                logging.warning('Multiple detector selection not available,'
-                                ' using all available data')
+            if "," in self.detectors:
+                logging.warning(
+                    "Multiple detector selection not available,"
+                    " using all available data"
+                )
             else:
                 constraints.detector = self.detectors
         if self.minStartTime:
             constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime)
         if self.maxStartTime:
             constraints.maxStartTime = lal.LIGOTimeGPS(self.maxStartTime)
-        logging.info('Loading data matching pattern {}'.format(
-                     self.sftfilepattern))
+        logging.info("Loading data matching pattern {}".format(self.sftfilepattern))
         SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepattern, constraints)
 
         SFT_timestamps = [d.header.epoch for d in SFTCatalog.data]
         self.SFT_timestamps = [float(s) for s in SFT_timestamps]
         if len(SFT_timestamps) == 0:
-            raise ValueError('Failed to load any data')
+            raise ValueError("Failed to load any data")
         if args.quite is False and args.no_interactive is False:
             try:
                 from bashplotlib.histogram import plot_hist
-                print('Data timestamps histogram:')
+
+                print("Data timestamps histogram:")
                 plot_hist(SFT_timestamps, height=5, bincount=50)
             except ImportError:
                 pass
 
-        cl_tconv1 = 'lalapps_tconvert {}'.format(int(SFT_timestamps[0]))
-        output = helper_functions.run_commandline(cl_tconv1,
-                                                  log_level=logging.DEBUG)
-        tconvert1 = output.rstrip('\n')
-        cl_tconv2 = 'lalapps_tconvert {}'.format(int(SFT_timestamps[-1]))
-        output = helper_functions.run_commandline(cl_tconv2,
-                                                  log_level=logging.DEBUG)
-        tconvert2 = output.rstrip('\n')
-        logging.info('Data spans from {} ({}) to {} ({})'.format(
-            int(SFT_timestamps[0]),
-            tconvert1,
-            int(SFT_timestamps[-1]),
-            tconvert2))
+        cl_tconv1 = "lalapps_tconvert {}".format(int(SFT_timestamps[0]))
+        output = helper_functions.run_commandline(cl_tconv1, log_level=logging.DEBUG)
+        tconvert1 = output.rstrip("\n")
+        cl_tconv2 = "lalapps_tconvert {}".format(int(SFT_timestamps[-1]))
+        output = helper_functions.run_commandline(cl_tconv2, log_level=logging.DEBUG)
+        tconvert2 = output.rstrip("\n")
+        logging.info(
+            "Data spans from {} ({}) to {} ({})".format(
+                int(SFT_timestamps[0]), tconvert1, int(SFT_timestamps[-1]), tconvert2
+            )
+        )
 
         if self.minStartTime is None:
             self.minStartTime = int(SFT_timestamps[0])
@@ -483,9 +531,12 @@ class ComputeFstat(BaseSearchClass):
         detector_names = list(set([d.header.name for d in SFTCatalog.data]))
         self.detector_names = detector_names
         if len(detector_names) == 0:
-            raise ValueError('No data loaded.')
-        logging.info('Loaded {} data files from detectors {}'.format(
-            len(SFT_timestamps), detector_names))
+            raise ValueError("No data loaded.")
+        logging.info(
+            "Loaded {} data files from detectors {}".format(
+                len(SFT_timestamps), detector_names
+            )
+        )
 
         return SFTCatalog
 
@@ -494,10 +545,10 @@ class ComputeFstat(BaseSearchClass):
 
         SFTCatalog = self._get_SFTCatalog()
 
-        logging.info('Initialising ephems')
+        logging.info("Initialising ephems")
         ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)
 
-        logging.info('Initialising FstatInput')
+        logging.info("Initialising FstatInput")
         dFreq = 0
         if self.transientWindowType:
             self.whatToCompute = lalpulsar.FSTATQ_ATOMS_PER_DET
@@ -507,84 +558,84 @@ class ComputeFstat(BaseSearchClass):
         FstatOAs = lalpulsar.FstatOptionalArgs()
         FstatOAs.randSeed = lalpulsar.FstatOptionalArgsDefaults.randSeed
         if self.SSBprec:
-            logging.info('Using SSBprec={}'.format(self.SSBprec))
+            logging.info("Using SSBprec={}".format(self.SSBprec))
             FstatOAs.SSBprec = self.SSBprec
         else:
             FstatOAs.SSBprec = lalpulsar.FstatOptionalArgsDefaults.SSBprec
         FstatOAs.Dterms = lalpulsar.FstatOptionalArgsDefaults.Dterms
-        FstatOAs.runningMedianWindow = lalpulsar.FstatOptionalArgsDefaults.runningMedianWindow
+        FstatOAs.runningMedianWindow = (
+            lalpulsar.FstatOptionalArgsDefaults.runningMedianWindow
+        )
         FstatOAs.FstatMethod = lalpulsar.FstatOptionalArgsDefaults.FstatMethod
         if self.assumeSqrtSX is None:
             FstatOAs.assumeSqrtSX = lalpulsar.FstatOptionalArgsDefaults.assumeSqrtSX
         else:
             mnf = lalpulsar.MultiNoiseFloor()
             assumeSqrtSX = np.atleast_1d(self.assumeSqrtSX)
-            mnf.sqrtSn[:len(assumeSqrtSX)] = assumeSqrtSX
+            mnf.sqrtSn[: len(assumeSqrtSX)] = assumeSqrtSX
             mnf.length = len(assumeSqrtSX)
             FstatOAs.assumeSqrtSX = mnf
         FstatOAs.prevInput = lalpulsar.FstatOptionalArgsDefaults.prevInput
         FstatOAs.collectTiming = lalpulsar.FstatOptionalArgsDefaults.collectTiming
 
-        if hasattr(self, 'injectSources') and type(self.injectSources) == dict:
-            logging.info('Injecting source with params: {}'.format(
-                self.injectSources))
+        if hasattr(self, "injectSources") and type(self.injectSources) == dict:
+            logging.info("Injecting source with params: {}".format(self.injectSources))
             PPV = lalpulsar.CreatePulsarParamsVector(1)
             PP = PPV.data[0]
-            h0 = self.injectSources['h0']
-            cosi = self.injectSources['cosi']
-            use_aPlus = ('aPlus' in dir(PP.Amp))
+            h0 = self.injectSources["h0"]
+            cosi = self.injectSources["cosi"]
+            use_aPlus = "aPlus" in dir(PP.Amp)
             print("use_aPlus = {}".format(use_aPlus))
             if use_aPlus:  # lalsuite interface changed in aff93c45
-                PP.Amp.aPlus = 0.5 * h0 * (1.0 + cosi**2)
+                PP.Amp.aPlus = 0.5 * h0 * (1.0 + cosi ** 2)
                 PP.Amp.aCross = h0 * cosi
             else:
                 PP.Amp.h0 = h0
                 PP.Amp.cosi = cosi
 
-            PP.Amp.phi0 = self.injectSources['phi0']
-            PP.Amp.psi = self.injectSources['psi']
-            PP.Doppler.Alpha = self.injectSources['Alpha']
-            PP.Doppler.Delta = self.injectSources['Delta']
-            if 'fkdot' in self.injectSources:
-                PP.Doppler.fkdot = np.array(self.injectSources['fkdot'])
+            PP.Amp.phi0 = self.injectSources["phi0"]
+            PP.Amp.psi = self.injectSources["psi"]
+            PP.Doppler.Alpha = self.injectSources["Alpha"]
+            PP.Doppler.Delta = self.injectSources["Delta"]
+            if "fkdot" in self.injectSources:
+                PP.Doppler.fkdot = np.array(self.injectSources["fkdot"])
             else:
                 PP.Doppler.fkdot = np.zeros(lalpulsar.PULSAR_MAX_SPINS)
-                for i, key in enumerate(['F0', 'F1', 'F2']):
+                for i, key in enumerate(["F0", "F1", "F2"]):
                     PP.Doppler.fkdot[i] = self.injectSources[key]
             PP.Doppler.refTime = self.tref
-            if 't0' not in self.injectSources:
+            if "t0" not in self.injectSources:
                 PP.Transient.type = lalpulsar.TRANSIENT_NONE
             FstatOAs.injectSources = PPV
-        elif hasattr(self, 'injectSources') and type(self.injectSources) == str:
-            logging.info('Injecting source from param file: {}'.format(
-                self.injectSources))
+        elif hasattr(self, "injectSources") and type(self.injectSources) == str:
+            logging.info(
+                "Injecting source from param file: {}".format(self.injectSources)
+            )
             PPV = lalpulsar.PulsarParamsFromFile(self.injectSources, self.tref)
             FstatOAs.injectSources = PPV
         else:
             FstatOAs.injectSources = lalpulsar.FstatOptionalArgsDefaults.injectSources
-        if hasattr(self, 'injectSqrtSX') and self.injectSqrtSX is not None:
-            raise ValueError('injectSqrtSX not implemented')
+        if hasattr(self, "injectSqrtSX") and self.injectSqrtSX is not None:
+            raise ValueError("injectSqrtSX not implemented")
         else:
             FstatOAs.InjectSqrtSX = lalpulsar.FstatOptionalArgsDefaults.injectSqrtSX
         if self.minCoverFreq is None or self.maxCoverFreq is None:
             fAs = [d.header.f0 for d in SFTCatalog.data]
-            fBs = [d.header.f0 + (d.numBins-1)*d.header.deltaF
-                   for d in SFTCatalog.data]
+            fBs = [
+                d.header.f0 + (d.numBins - 1) * d.header.deltaF for d in SFTCatalog.data
+            ]
             self.minCoverFreq = np.min(fAs) + 0.5
             self.maxCoverFreq = np.max(fBs) - 0.5
-            logging.info('Min/max cover freqs not provided, using '
-                         '{} and {}, est. from SFTs'.format(
-                             self.minCoverFreq, self.maxCoverFreq))
-
-        self.FstatInput = lalpulsar.CreateFstatInput(SFTCatalog,
-                                                     self.minCoverFreq,
-                                                     self.maxCoverFreq,
-                                                     dFreq,
-                                                     ephems,
-                                                     FstatOAs
-                                                     )
-
-        logging.info('Initialising PulsarDoplerParams')
+            logging.info(
+                "Min/max cover freqs not provided, using "
+                "{} and {}, est. from SFTs".format(self.minCoverFreq, self.maxCoverFreq)
+            )
+
+        self.FstatInput = lalpulsar.CreateFstatInput(
+            SFTCatalog, self.minCoverFreq, self.maxCoverFreq, dFreq, ephems, FstatOAs
+        )
+
+        logging.info("Initialising PulsarDoplerParams")
         PulsarDopplerParams = lalpulsar.PulsarDopplerParams()
         PulsarDopplerParams.refTime = self.tref
         PulsarDopplerParams.Alpha = 1
@@ -592,54 +643,54 @@ class ComputeFstat(BaseSearchClass):
         PulsarDopplerParams.fkdot = np.array([0, 0, 0, 0, 0, 0, 0])
         self.PulsarDopplerParams = PulsarDopplerParams
 
-        logging.info('Initialising FstatResults')
+        logging.info("Initialising FstatResults")
         self.FstatResults = lalpulsar.FstatResults()
 
         if self.BSGL:
             if len(self.detector_names) < 2:
                 raise ValueError("Can't use BSGL with single detectors data")
             else:
-                logging.info('Initialising BSGL')
+                logging.info("Initialising BSGL")
 
             # Tuning parameters - to be reviewed
             numDetectors = 2
-            if hasattr(self, 'nsegs'):
+            if hasattr(self, "nsegs"):
                 p_val_threshold = 1e-6
                 Fstar0s = np.linspace(0, 1000, 10000)
-                p_vals = scipy.special.gammaincc(2*self.nsegs, Fstar0s)
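+                # choose Fstar0 so the noise-only probability of the summed
+                # statistic exceeding 2*Fstar0 is ~p_val_threshold:
+                # gammaincc(2*nsegs, F) is the survival function of a
+                # chi-squared with 4*nsegs degrees of freedom evaluated at 2F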
+                p_vals = scipy.special.gammaincc(2 * self.nsegs, Fstar0s)
                 Fstar0 = Fstar0s[np.argmin(np.abs(p_vals - p_val_threshold))]
                 if Fstar0 == Fstar0s[-1]:
-                    raise ValueError('Max Fstar0 exceeded')
+                    raise ValueError("Max Fstar0 exceeded")
             else:
-                Fstar0 = 15.
-            logging.info('Using Fstar0 of {:1.2f}'.format(Fstar0))
+                Fstar0 = 15.0
+            logging.info("Using Fstar0 of {:1.2f}".format(Fstar0))
             oLGX = np.zeros(10)
-            oLGX[:numDetectors] = 1./numDetectors
-            self.BSGLSetup = lalpulsar.CreateBSGLSetup(numDetectors,
-                                                       Fstar0,
-                                                       oLGX,
-                                                       True,
-                                                       1)
+            oLGX[:numDetectors] = 1.0 / numDetectors
+            self.BSGLSetup = lalpulsar.CreateBSGLSetup(
+                numDetectors, Fstar0, oLGX, True, 1
+            )
             self.twoFX = np.zeros(10)
-            self.whatToCompute = (self.whatToCompute +
-                                  lalpulsar.FSTATQ_2F_PER_DET)
+            self.whatToCompute = self.whatToCompute + lalpulsar.FSTATQ_2F_PER_DET
 
         if self.transientWindowType:
-            logging.info('Initialising transient parameters')
+            logging.info("Initialising transient parameters")
             self.windowRange = lalpulsar.transientWindowRange_t()
-            transientWindowTypes = {'none': lalpulsar.TRANSIENT_NONE,
-                                    'rect': lalpulsar.TRANSIENT_RECTANGULAR,
-                                    'exp':  lalpulsar.TRANSIENT_EXPONENTIAL}
+            transientWindowTypes = {
+                "none": lalpulsar.TRANSIENT_NONE,
+                "rect": lalpulsar.TRANSIENT_RECTANGULAR,
+                "exp": lalpulsar.TRANSIENT_EXPONENTIAL,
+            }
             if self.transientWindowType in transientWindowTypes:
                 self.windowRange.type = transientWindowTypes[self.transientWindowType]
             else:
                 raise ValueError(
-                    'Unknown window-type ({}) passed as input, [{}] allows.'
-                    .format(self.transientWindowType,
-                            ', '.join(transientWindowTypes)))
+                    "Unknown window-type ({}) passed as input, [{}] allows.".format(
+                        self.transientWindowType, ", ".join(transientWindowTypes)
+                    )
+                )
 
             # default spacing
-            self.Tsft = int(1.0/SFTCatalog.data[0].header.deltaF)
+            self.Tsft = int(1.0 / SFTCatalog.data[0].header.deltaF)
             self.windowRange.dt0 = self.Tsft
             self.windowRange.dtau = self.Tsft
 
@@ -648,15 +699,18 @@ class ComputeFstat(BaseSearchClass):
             if self.windowRange.type == lalpulsar.TRANSIENT_NONE:
                 self.windowRange.t0 = int(self.minStartTime)
                 self.windowRange.t0Band = 0
-                self.windowRange.tau = int(self.maxStartTime-self.minStartTime)
+                self.windowRange.tau = int(self.maxStartTime - self.minStartTime)
                 self.windowRange.tauBand = 0
             else:  # user-set bands and spacings
                 if self.t0Band is None:
                     self.windowRange.t0Band = 0
                 else:
                     if not isinstance(self.t0Band, int):
-                        logging.warn('Casting non-integer t0Band={} to int...'
-                                     .format(self.t0Band))
+                        logging.warning(
+                            "Casting non-integer t0Band={} to int...".format(
+                                self.t0Band
+                            )
+                        )
                         self.t0Band = int(self.t0Band)
                     self.windowRange.t0Band = self.t0Band
                     if self.dt0:
@@ -665,29 +719,47 @@ class ComputeFstat(BaseSearchClass):
                     self.windowRange.tauBand = 0
                 else:
                     if not isinstance(self.tauBand, int):
-                        logging.warn('Casting non-integer tauBand={} to int...'
-                                     .format(self.tauBand))
+                        logging.warning(
+                            "Casting non-integer tauBand={} to int...".format(
+                                self.tauBand
+                            )
+                        )
                         self.tauBand = int(self.tauBand)
                     self.windowRange.tauBand = self.tauBand
                     if self.dtau:
                         self.windowRange.dtau = self.dtau
                     if self.tauMin is None:
-                        self.windowRange.tau = int(2*self.Tsft)
+                        self.windowRange.tau = int(2 * self.Tsft)
                     else:
                         if not isinstance(self.tauMin, int):
-                            logging.warn('Casting non-integer tauMin={} to int...'
-                                         .format(self.tauMin))
+                            logging.warning(
+                                "Casting non-integer tauMin={} to int...".format(
+                                    self.tauMin
+                                )
+                            )
                             self.tauMin = int(self.tauMin)
                         self.windowRange.tau = self.tauMin
 
-            logging.info('Initialising transient FstatMap features...')
-            self.tCWFstatMapFeatures, self.gpu_context = (
-                tcw.init_transient_fstat_map_features(
-                    self.tCWFstatMapVersion == 'pycuda', self.cudaDeviceName))
-
-    def get_fullycoherent_twoF(self, tstart, tend, F0, F1, F2, Alpha, Delta,
-                               asini=None, period=None, ecc=None, tp=None,
-                               argp=None):
+            logging.info("Initialising transient FstatMap features...")
+            self.tCWFstatMapFeatures, self.gpu_context = tcw.init_transient_fstat_map_features(
+                self.tCWFstatMapVersion == "pycuda", self.cudaDeviceName
+            )
+
+    def get_fullycoherent_twoF(
+        self,
+        tstart,
+        tend,
+        F0,
+        F1,
+        F2,
+        Alpha,
+        Delta,
+        asini=None,
+        period=None,
+        ecc=None,
+        tp=None,
+        argp=None,
+    ):
         """ Returns twoF or ln(BSGL) fully-coherently at a single point """
         self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
         self.PulsarDopplerParams.Alpha = float(Alpha)
@@ -699,12 +771,13 @@ class ComputeFstat(BaseSearchClass):
             self.PulsarDopplerParams.tp = float(tp)
             self.PulsarDopplerParams.argp = float(argp)
 
-        lalpulsar.ComputeFstat(self.FstatResults,
-                               self.FstatInput,
-                               self.PulsarDopplerParams,
-                               1,
-                               self.whatToCompute
-                               )
+        lalpulsar.ComputeFstat(
+            self.FstatResults,
+            self.FstatInput,
+            self.PulsarDopplerParams,
+            1,
+            self.whatToCompute,
+        )
 
         if not self.transientWindowType:
             if self.BSGL is False:
@@ -713,9 +786,8 @@ class ComputeFstat(BaseSearchClass):
             twoF = np.float(self.FstatResults.twoF[0])
             self.twoFX[0] = self.FstatResults.twoFPerDet(0)
             self.twoFX[1] = self.FstatResults.twoFPerDet(1)
-            log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX,
-                                               self.BSGLSetup)
-            return log10_BSGL/np.log10(np.exp(1))
+            log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX, self.BSGLSetup)
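+            # ComputeBSGL returns log10(BSGL); dividing by log10(e)
+            # converts it to the natural log ln(BSGL)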
+            return log10_BSGL / np.log10(np.exp(1))
 
         self.windowRange.t0 = int(tstart)  # TYPE UINT4
         if self.windowRange.tauBand == 0:
@@ -724,14 +796,17 @@ class ComputeFstat(BaseSearchClass):
             self.windowRange.tau = int(tend - tstart)  # TYPE UINT4
 
         self.FstatMap, self.timingFstatMap = tcw.call_compute_transient_fstat_map(
-            self.tCWFstatMapVersion, self.tCWFstatMapFeatures,
-            self.FstatResults.multiFatoms[0], self.windowRange)
-        if self.tCWFstatMapVersion == 'lal':
+            self.tCWFstatMapVersion,
+            self.tCWFstatMapFeatures,
+            self.FstatResults.multiFatoms[0],
+            self.windowRange,
+        )
+        if self.tCWFstatMapVersion == "lal":
             F_mn = self.FstatMap.F_mn.data
         else:
             F_mn = self.FstatMap.F_mn
 
-        twoF = 2*np.max(F_mn)
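+        # transient detection statistic: maximise 2F over the (t0, tau) grid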
+        twoF = 2 * np.max(F_mn)
         if self.BSGL is False:
             if np.isnan(twoF):
                 return 0
@@ -742,10 +817,12 @@ class ComputeFstat(BaseSearchClass):
         FstatResults_single.lenth = 1
         FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0]
         FS0 = lalpulsar.ComputeTransientFstatMap(
-            FstatResults_single.multiFatoms[0], self.windowRange, False)
+            FstatResults_single.multiFatoms[0], self.windowRange, False
+        )
         FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1]
         FS1 = lalpulsar.ComputeTransientFstatMap(
-            FstatResults_single.multiFatoms[0], self.windowRange, False)
+            FstatResults_single.multiFatoms[0], self.windowRange, False
+        )
 
         # for now, use the Doppler parameter with
         # multi-detector F maximised over t0,tau
@@ -754,17 +831,28 @@ class ComputeFstat(BaseSearchClass):
         # and return the maximum of that?
         idx_maxTwoF = np.argmax(F_mn)
 
-        self.twoFX[0] = 2*FS0.F_mn.data[idx_maxTwoF]
-        self.twoFX[1] = 2*FS1.F_mn.data[idx_maxTwoF]
-        log10_BSGL = lalpulsar.ComputeBSGL(
-                twoF, self.twoFX, self.BSGLSetup)
-
-        return log10_BSGL/np.log10(np.exp(1))
-
-    def calculate_twoF_cumulative(self, F0, F1, F2, Alpha, Delta, asini=None,
-                                  period=None, ecc=None, tp=None, argp=None,
-                                  tstart=None, tend=None, npoints=1000,
-                                  ):
+        self.twoFX[0] = 2 * FS0.F_mn.data[idx_maxTwoF]
+        self.twoFX[1] = 2 * FS1.F_mn.data[idx_maxTwoF]
+        log10_BSGL = lalpulsar.ComputeBSGL(twoF, self.twoFX, self.BSGLSetup)
+
+        return log10_BSGL / np.log10(np.exp(1))
+
+    def calculate_twoF_cumulative(
+        self,
+        F0,
+        F1,
+        F2,
+        Alpha,
+        Delta,
+        asini=None,
+        period=None,
+        ecc=None,
+        tp=None,
+        argp=None,
+        tstart=None,
+        tend=None,
+        npoints=1000,
+    ):
         """ Calculate the cumulative twoF along the obseration span
 
         Parameters
@@ -789,25 +877,36 @@ class ComputeFstat(BaseSearchClass):
         SFTminStartTime = self.SFT_timestamps[0]
         SFTmaxStartTime = self.SFT_timestamps[-1]
         tstart = np.max([SFTminStartTime, tstart])
-        min_tau = np.max([SFTminStartTime - tstart, 0]) + 3600*6
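+        # shortest cumulative stretch: at least 6 hours past the first SFT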
+        min_tau = np.max([SFTminStartTime - tstart, 0]) + 3600 * 6
         max_tau = SFTmaxStartTime - tstart
         taus = np.linspace(min_tau, max_tau, npoints)
         twoFs = []
         if not self.transientWindowType:
             # still call the transient-Fstat-map function, but using the full range
-            self.transientWindowType = 'none'
+            self.transientWindowType = "none"
             self.init_computefstatistic_single_point()
         for tau in taus:
             detstat = self.get_fullycoherent_twoF(
-                tstart=tstart, tend=tstart+tau, F0=F0, F1=F1, F2=F2,
-                Alpha=Alpha, Delta=Delta, asini=asini, period=period, ecc=ecc,
-                tp=tp, argp=argp)
+                tstart=tstart,
+                tend=tstart + tau,
+                F0=F0,
+                F1=F1,
+                F2=F2,
+                Alpha=Alpha,
+                Delta=Delta,
+                asini=asini,
+                period=period,
+                ecc=ecc,
+                tp=tp,
+                argp=argp,
+            )
             twoFs.append(detstat)
 
         return taus, np.array(twoFs)
 
-    def _calculate_predict_fstat_cumulative(self, N, label=None, outdir=None,
-                                            IFO=None, pfs_input=None):
+    def _calculate_predict_fstat_cumulative(
+        self, N, label=None, outdir=None, IFO=None, pfs_input=None
+    ):
         """ Calculates the predicted 2F and standard deviation cumulatively
 
         Parameters
@@ -828,23 +927,42 @@ class ComputeFstat(BaseSearchClass):
         """
 
         if pfs_input is None:
-            if os.path.isfile('{}/{}.loudest'.format(outdir, label)) is False:
-                raise ValueError(
-                    'Need a loudest file to add the predicted Fstat')
-            loudest = read_par(label=label, outdir=outdir, suffix='loudest')
-            pfs_input = {key: loudest[key] for key in
-                         ['h0', 'cosi', 'psi', 'Alpha', 'Delta', 'Freq']}
-        times = np.linspace(self.minStartTime, self.maxStartTime, N+1)[1:]
-        times = np.insert(times, 0, self.minStartTime + 86400/2.)
-        out = [predict_fstat(minStartTime=self.minStartTime, maxStartTime=t,
-                             sftfilepattern=self.sftfilepattern, IFO=IFO,
-                             **pfs_input) for t in times]
+            if os.path.isfile("{}/{}.loudest".format(outdir, label)) is False:
+                raise ValueError("Need a loudest file to add the predicted Fstat")
+            loudest = read_par(label=label, outdir=outdir, suffix="loudest")
+            pfs_input = {
+                key: loudest[key]
+                for key in ["h0", "cosi", "psi", "Alpha", "Delta", "Freq"]
+            }
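+        # N evenly spaced end times across the span, plus an extra point
+        # half a day in so the cumulative curve starts near the beginning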
+        times = np.linspace(self.minStartTime, self.maxStartTime, N + 1)[1:]
+        times = np.insert(times, 0, self.minStartTime + 86400 / 2.0)
+        out = [
+            predict_fstat(
+                minStartTime=self.minStartTime,
+                maxStartTime=t,
+                sftfilepattern=self.sftfilepattern,
+                IFO=IFO,
+                **pfs_input
+            )
+            for t in times
+        ]
         pfs, pfs_sigma = np.array(out).T
         return times, pfs, pfs_sigma
 
-    def plot_twoF_cumulative(self, label, outdir, add_pfs=False, N=15,
-                             injectSources=None, ax=None, c='k', savefig=True,
-                             title=None, plt_label=None, **kwargs):
+    def plot_twoF_cumulative(
+        self,
+        label,
+        outdir,
+        add_pfs=False,
+        N=15,
+        injectSources=None,
+        ax=None,
+        c="k",
+        savefig=True,
+        title=None,
+        plt_label=None,
+        **kwargs
+    ):
         """ Plot the twoF value cumulatively
 
         Parameters
@@ -877,14 +995,18 @@ class ComputeFstat(BaseSearchClass):
             fig, ax = plt.subplots()
         if injectSources:
             pfs_input = dict(
-                h0=injectSources['h0'], cosi=injectSources['cosi'],
-                psi=injectSources['psi'], Alpha=injectSources['Alpha'],
-                Delta=injectSources['Delta'], Freq=injectSources['fkdot'][0])
+                h0=injectSources["h0"],
+                cosi=injectSources["cosi"],
+                psi=injectSources["psi"],
+                Alpha=injectSources["Alpha"],
+                Delta=injectSources["Delta"],
+                Freq=injectSources["fkdot"][0],
+            )
         else:
             pfs_input = None
 
         taus, twoFs = self.calculate_twoF_cumulative(**kwargs)
-        ax.plot(taus/86400., twoFs, label=plt_label, color=c)
+        ax.plot(taus / 86400.0, twoFs, label=plt_label, color=c)
         if len(self.detector_names) > 1:
             detector_names = self.detector_names
             detectors = self.detectors
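+            # overplot each detector's cumulative 2F by re-initialising the
+            # F-stat machinery per detector; the joint setup is restored after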
@@ -892,84 +1014,115 @@ class ComputeFstat(BaseSearchClass):
                 self.detectors = d
                 self.init_computefstatistic_single_point()
                 taus, twoFs = self.calculate_twoF_cumulative(**kwargs)
-                ax.plot(taus/86400., twoFs, label='{}'.format(d),
-                        color=detector_colors[d.lower()])
+                ax.plot(
+                    taus / 86400.0,
+                    twoFs,
+                    label="{}".format(d),
+                    color=detector_colors[d.lower()],
+                )
             self.detectors = detectors
             self.detector_names = detector_names
 
         if add_pfs:
             times, pfs, pfs_sigma = self._calculate_predict_fstat_cumulative(
-                N=N, label=label, outdir=outdir, pfs_input=pfs_input)
+                N=N, label=label, outdir=outdir, pfs_input=pfs_input
+            )
             ax.fill_between(
-                (times-self.minStartTime)/86400., pfs-pfs_sigma, pfs+pfs_sigma,
+                (times - self.minStartTime) / 86400.0,
+                pfs - pfs_sigma,
+                pfs + pfs_sigma,
                 color=c,
-                label=(r'Predicted $\langle 2\mathcal{F} '
-                       r'\rangle\pm $ 1-$\sigma$ band'),
-                zorder=-10, alpha=0.2)
+                label=(
+                    r"Predicted $\langle 2\mathcal{F} " r"\rangle\pm $ 1-$\sigma$ band"
+                ),
+                zorder=-10,
+                alpha=0.2,
+            )
             if len(self.detector_names) > 1:
                 for d in self.detector_names:
                     out = self._calculate_predict_fstat_cumulative(
-                        N=N, label=label, outdir=outdir, IFO=d.upper(),
-                        pfs_input=pfs_input)
+                        N=N,
+                        label=label,
+                        outdir=outdir,
+                        IFO=d.upper(),
+                        pfs_input=pfs_input,
+                    )
                     times, pfs, pfs_sigma = out
                     ax.fill_between(
-                        (times-self.minStartTime)/86400., pfs-pfs_sigma,
-                        pfs+pfs_sigma, color=detector_colors[d.lower()],
+                        (times - self.minStartTime) / 86400.0,
+                        pfs - pfs_sigma,
+                        pfs + pfs_sigma,
+                        color=detector_colors[d.lower()],
                         alpha=0.5,
                         label=(
-                            'Predicted $2\mathcal{{F}}$ 1-$\sigma$ band ({})'
-                            .format(d.upper())),
-                        zorder=-10)
-
-        ax.set_xlabel(r'Days from $t_{{\rm start}}={:.0f}$'.format(
-            kwargs['tstart']))
+                            "Predicted $2\mathcal{{F}}$ 1-$\sigma$ band ({})".format(
+                                d.upper()
+                            )
+                        ),
+                        zorder=-10,
+                    )
+
+        ax.set_xlabel(r"Days from $t_{{\rm start}}={:.0f}$".format(kwargs["tstart"]))
         if self.BSGL:
-            ax.set_ylabel(r'$\log_{10}(\mathrm{BSGL})_{\rm cumulative}$')
+            ax.set_ylabel(r"$\log_{10}(\mathrm{BSGL})_{\rm cumulative}$")
         else:
-            ax.set_ylabel(r'$\widetilde{2\mathcal{F}}_{\rm cumulative}$')
-        ax.set_xlim(0, taus[-1]/86400)
+            ax.set_ylabel(r"$\widetilde{2\mathcal{F}}_{\rm cumulative}$")
+        ax.set_xlim(0, taus[-1] / 86400)
         if plt_label:
             ax.legend(frameon=False, loc=2, fontsize=6)
         if title:
             ax.set_title(title)
         if savefig:
             plt.tight_layout()
-            plt.savefig('{}/{}_twoFcumulative.png'.format(outdir, label))
+            plt.savefig("{}/{}_twoFcumulative.png".format(outdir, label))
             return taus, twoFs
         else:
             return ax
 
-    def get_full_CFSv2_output(self, tstart, tend, F0, F1, F2, Alpha, Delta,
-                              tref):
+    def get_full_CFSv2_output(self, tstart, tend, F0, F1, F2, Alpha, Delta, tref):
         """ Basic wrapper around CFSv2 to get the full (h0..) output """
         cl_CFSv2 = "lalapps_ComputeFstatistic_v2 --minStartTime={} --maxStartTime={} --Freq={} --f1dot={} --f2dot={} --Alpha={} --Delta={} --refTime={} --DataFiles='{}' --outputLoudest='{}' --ephemEarth={} --ephemSun={}"
         LoudestFile = "loudest.temp"
-        helper_functions.run_commandline(cl_CFSv2.format(
-            tstart, tend, F0, F1, F2, Alpha, Delta, tref, self.sftfilepattern,
-            LoudestFile, self.earth_ephem, self.sun_ephem))
-        loudest = read_par(LoudestFile, return_type='dict')
+        helper_functions.run_commandline(
+            cl_CFSv2.format(
+                tstart,
+                tend,
+                F0,
+                F1,
+                F2,
+                Alpha,
+                Delta,
+                tref,
+                self.sftfilepattern,
+                LoudestFile,
+                self.earth_ephem,
+                self.sun_ephem,
+            )
+        )
+        loudest = read_par(LoudestFile, return_type="dict")
         os.remove(LoudestFile)
         return loudest
 
-    def write_atoms_to_file(self, fnamebase=''):
-        multiFatoms = getattr(self.FstatResults, 'multiFatoms', None)
+    def write_atoms_to_file(self, fnamebase=""):
+        multiFatoms = getattr(self.FstatResults, "multiFatoms", None)
         if multiFatoms and multiFatoms[0]:
-            dopplerName = lalpulsar.PulsarDopplerParams2String ( self.PulsarDopplerParams )
-            #fnameAtoms = os.path.join(self.outdir,'Fstatatoms_%s.dat' % dopplerName)
-            fnameAtoms = fnamebase + '_Fstatatoms_%s.dat' % dopplerName
-            fo = lal.FileOpen(fnameAtoms, 'w')
-            lalpulsar.write_MultiFstatAtoms_to_fp ( fo, multiFatoms[0] )
-            del fo # instead of lal.FileClose() which is not SWIG-exported
+            dopplerName = lalpulsar.PulsarDopplerParams2String(self.PulsarDopplerParams)
+            # fnameAtoms = os.path.join(self.outdir,'Fstatatoms_%s.dat' % dopplerName)
+            fnameAtoms = fnamebase + "_Fstatatoms_%s.dat" % dopplerName
+            fo = lal.FileOpen(fnameAtoms, "w")
+            lalpulsar.write_MultiFstatAtoms_to_fp(fo, multiFatoms[0])
+            del fo  # instead of lal.FileClose() which is not SWIG-exported
         else:
-            raise RuntimeError('Cannot print atoms vector to file: no FstatResults.multiFatoms, or it is None!')
-
+            raise RuntimeError(
+                "Cannot print atoms vector to file: no FstatResults.multiFatoms, or it is None!"
+            )
 
     def __del__(self):
         """
         In pyCuda case without autoinit,
         we need to make sure the context is removed at the end
         """
-        if hasattr(self,'gpu_context') and self.gpu_context:
+        if hasattr(self, "gpu_context") and self.gpu_context:
             self.gpu_context.detach()
 
 
@@ -977,11 +1130,24 @@ class SemiCoherentSearch(ComputeFstat):
     """ A semi-coherent search """
 
     @helper_functions.initializer
-    def __init__(self, label, outdir, tref, nsegs=None, sftfilepattern=None,
-                 binary=False, BSGL=False, minStartTime=None,
-                 maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
-                 detectors=None, injectSources=None, assumeSqrtSX=None,
-                 SSBprec=None):
+    def __init__(
+        self,
+        label,
+        outdir,
+        tref,
+        nsegs=None,
+        sftfilepattern=None,
+        binary=False,
+        BSGL=False,
+        minStartTime=None,
+        maxStartTime=None,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        detectors=None,
+        injectSources=None,
+        assumeSqrtSX=None,
+        SSBprec=None,
+    ):
         """
         Parameters
         ----------
@@ -1000,38 +1166,55 @@ class SemiCoherentSearch(ComputeFstat):
 
         self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label)
         self.set_ephemeris_files()
-        self.transientWindowType = 'rect'
-        self.t0Band  = None
+        self.transientWindowType = "rect"
+        self.t0Band = None
         self.tauBand = None
-        self.tCWFstatMapVersion = 'lal'
+        self.tCWFstatMapVersion = "lal"
         self.cudaDeviceName = None
         self.init_computefstatistic_single_point()
         self.init_semicoherent_parameters()
 
     def init_semicoherent_parameters(self):
-        logging.info(('Initialising semicoherent parameters from {} to {} in'
-                      ' {} segments').format(
-            self.minStartTime, self.maxStartTime, self.nsegs))
-        self.transientWindowType = 'rect'
-        self.whatToCompute = lalpulsar.FSTATQ_2F+lalpulsar.FSTATQ_ATOMS_PER_DET
-        self.tboundaries = np.linspace(self.minStartTime, self.maxStartTime,
-                                       self.nsegs+1)
+        logging.info(
+            (
+                "Initialising semicoherent parameters from {} to {} in" " {} segments"
+            ).format(self.minStartTime, self.maxStartTime, self.nsegs)
+        )
+        self.transientWindowType = "rect"
+        self.whatToCompute = lalpulsar.FSTATQ_2F + lalpulsar.FSTATQ_ATOMS_PER_DET
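+        # nsegs + 1 boundaries defining equal-length coherent segments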
+        self.tboundaries = np.linspace(
+            self.minStartTime, self.maxStartTime, self.nsegs + 1
+        )
         self.Tcoh = self.tboundaries[1] - self.tboundaries[0]
 
-        if hasattr(self, 'SFT_timestamps'):
+        if hasattr(self, "SFT_timestamps"):
             if self.tboundaries[0] < self.SFT_timestamps[0]:
                 logging.debug(
-                    'Semi-coherent start time {} before first SFT timestamp {}'
-                    .format(self.tboundaries[0], self.SFT_timestamps[0]))
+                    "Semi-coherent start time {} before first SFT timestamp {}".format(
+                        self.tboundaries[0], self.SFT_timestamps[0]
+                    )
+                )
             if self.tboundaries[-1] > self.SFT_timestamps[-1]:
                 logging.debug(
-                    'Semi-coherent end time {} after last SFT timestamp {}'
-                    .format(self.tboundaries[-1], self.SFT_timestamps[-1]))
+                    "Semi-coherent end time {} after last SFT timestamp {}".format(
+                        self.tboundaries[-1], self.SFT_timestamps[-1]
+                    )
+                )
 
     def get_semicoherent_twoF(
-            self, F0, F1, F2, Alpha, Delta, asini=None,
-            period=None, ecc=None, tp=None, argp=None,
-            record_segments=False):
+        self,
+        F0,
+        F1,
+        F2,
+        Alpha,
+        Delta,
+        asini=None,
+        period=None,
+        ecc=None,
+        tp=None,
+        argp=None,
+        record_segments=False,
+    ):
         """ Returns twoF or ln(BSGL) semi-coherently at a single point """
 
         self.PulsarDopplerParams.fkdot = np.array([F0, F1, F2, 0, 0, 0, 0])
@@ -1044,14 +1227,15 @@ class SemiCoherentSearch(ComputeFstat):
             self.PulsarDopplerParams.tp = float(tp)
             self.PulsarDopplerParams.argp = float(argp)
 
-        lalpulsar.ComputeFstat(self.FstatResults,
-                               self.FstatInput,
-                               self.PulsarDopplerParams,
-                               1,
-                               self.whatToCompute
-                               )
+        lalpulsar.ComputeFstat(
+            self.FstatResults,
+            self.FstatInput,
+            self.PulsarDopplerParams,
+            1,
+            self.whatToCompute,
+        )
 
-        #if not self.transientWindowType:
+        # if not self.transientWindowType:
         #    if self.BSGL is False:
         #        return self.FstatResults.twoF[0]
         #    twoF = np.float(self.FstatResults.twoF[0])
@@ -1078,27 +1262,31 @@ class SemiCoherentSearch(ComputeFstat):
         self.windowRange.t0 = int(tstart)  # TYPE UINT4
 
         FS = lalpulsar.ComputeTransientFstatMap(
-            self.FstatResults.multiFatoms[0], self.windowRange, False)
+            self.FstatResults.multiFatoms[0], self.windowRange, False
+        )
 
         if self.BSGL is False:
-            d_detStat = 2*FS.F_mn.data[0][0]
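+            # the window range holds a single (t0, tau) cell here, so the
+            # segment's 2F is the sole entry of the transient F-stat map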
+            d_detStat = 2 * FS.F_mn.data[0][0]
         else:
             FstatResults_single = copy.copy(self.FstatResults)
             FstatResults_single.lenth = 1
             FstatResults_single.data = self.FstatResults.multiFatoms[0].data[0]
             FS0 = lalpulsar.ComputeTransientFstatMap(
-                FstatResults_single.multiFatoms[0], self.windowRange, False)
+                FstatResults_single.multiFatoms[0], self.windowRange, False
+            )
             FstatResults_single.data = self.FstatResults.multiFatoms[0].data[1]
             FS1 = lalpulsar.ComputeTransientFstatMap(
-                FstatResults_single.multiFatoms[0], self.windowRange, False)
+                FstatResults_single.multiFatoms[0], self.windowRange, False
+            )
 
-            self.twoFX[0] = 2*FS0.F_mn.data[0][0]
-            self.twoFX[1] = 2*FS1.F_mn.data[0][0]
+            self.twoFX[0] = 2 * FS0.F_mn.data[0][0]
+            self.twoFX[1] = 2 * FS1.F_mn.data[0][0]
             log10_BSGL = lalpulsar.ComputeBSGL(
-                    2*FS.F_mn.data[0][0], self.twoFX, self.BSGLSetup)
-            d_detStat = log10_BSGL/np.log10(np.exp(1))
+                2 * FS.F_mn.data[0][0], self.twoFX, self.BSGLSetup
+            )
+            d_detStat = log10_BSGL / np.log10(np.exp(1))
         if np.isnan(d_detStat):
-            logging.debug('NaNs in semi-coherent twoF treated as zero')
+            logging.debug("NaNs in semi-coherent twoF treated as zero")
             d_detStat = 0
 
         return d_detStat
@@ -1114,10 +1302,24 @@ class SemiCoherentGlitchSearch(ComputeFstat):
     """
 
     @helper_functions.initializer
-    def __init__(self, label, outdir, tref, minStartTime, maxStartTime,
-                 nglitch=1, sftfilepattern=None, theta0_idx=0, BSGL=False,
-                 minCoverFreq=None, maxCoverFreq=None, assumeSqrtSX=None,
-                 detectors=None, SSBprec=None, injectSources=None):
+    def __init__(
+        self,
+        label,
+        outdir,
+        tref,
+        minStartTime,
+        maxStartTime,
+        nglitch=1,
+        sftfilepattern=None,
+        theta0_idx=0,
+        BSGL=False,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        assumeSqrtSX=None,
+        detectors=None,
+        SSBprec=None,
+        injectSources=None,
+    ):
         """
         Parameters
         ----------
@@ -1141,38 +1343,45 @@ class SemiCoherentGlitchSearch(ComputeFstat):
 
         self.fs_file_name = "{}/{}_FS.dat".format(self.outdir, self.label)
         self.set_ephemeris_files()
-        self.transientWindowType = 'rect'
-        self.t0Band  = None
+        self.transientWindowType = "rect"
+        self.t0Band = None
         self.tauBand = None
-        self.tCWFstatMapVersion = 'lal'
+        self.tCWFstatMapVersion = "lal"
         self.cudaDeviceName = None
-        self.binary  = False
+        self.binary = False
         self.init_computefstatistic_single_point()
 
     def get_semicoherent_nglitch_twoF(self, F0, F1, F2, Alpha, Delta, *args):
         """ Returns the semi-coherent glitch summed twoF """
 
         args = list(args)
-        tboundaries = ([self.minStartTime] + args[-self.nglitch:]
-                       + [self.maxStartTime])
-        delta_F0s = args[-3*self.nglitch:-2*self.nglitch]
-        delta_F1s = args[-2*self.nglitch:-self.nglitch]
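+        # the trailing *args pack, in order: nglitch delta_F0 values,
+        # nglitch delta_F1 values, then the nglitch glitch times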
+        tboundaries = [self.minStartTime] + args[-self.nglitch :] + [self.maxStartTime]
+        delta_F0s = args[-3 * self.nglitch : -2 * self.nglitch]
+        delta_F1s = args[-2 * self.nglitch : -self.nglitch]
         delta_F2 = np.zeros(len(delta_F0s))
         delta_phi = np.zeros(len(delta_F0s))
         theta = [0, F0, F1, F2]
         delta_thetas = np.atleast_2d(
-                np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T)
+            np.array([delta_phi, delta_F0s, delta_F1s, delta_F2]).T
+        )
 
-        thetas = self._calculate_thetas(theta, delta_thetas, tboundaries,
-                                        theta0_idx=self.theta0_idx)
+        thetas = self._calculate_thetas(
+            theta, delta_thetas, tboundaries, theta0_idx=self.theta0_idx
+        )
 
         twoFSum = 0
         for i, theta_i_at_tref in enumerate(thetas):
-            ts, te = tboundaries[i], tboundaries[i+1]
+            ts, te = tboundaries[i], tboundaries[i + 1]
             if te - ts > 1800:
                 twoFVal = self.get_fullycoherent_twoF(
-                    ts, te, theta_i_at_tref[1], theta_i_at_tref[2],
-                    theta_i_at_tref[3], Alpha, Delta)
+                    ts,
+                    te,
+                    theta_i_at_tref[1],
+                    theta_i_at_tref[2],
+                    theta_i_at_tref[3],
+                    Alpha,
+                    Delta,
+                )
                 twoFSum += twoFVal
 
         if np.isfinite(twoFSum):
@@ -1180,8 +1389,9 @@ class SemiCoherentGlitchSearch(ComputeFstat):
         else:
             return -np.inf
 
-    def compute_glitch_fstat_single(self, F0, F1, F2, Alpha, Delta, delta_F0,
-                                    delta_F1, tglitch):
+    def compute_glitch_fstat_single(
+        self, F0, F1, F2, Alpha, Delta, delta_F0, delta_F1, tglitch
+    ):
         """ Returns the semi-coherent glitch summed twoF for nglitch=1
 
         Note: OBSOLETE, used only for testing
@@ -1194,18 +1404,24 @@ class SemiCoherentGlitchSearch(ComputeFstat):
         theta_at_glitch = self._shift_coefficients(theta, tglitch - tref)
         theta_post_glitch_at_glitch = theta_at_glitch + delta_theta
         theta_post_glitch = self._shift_coefficients(
-            theta_post_glitch_at_glitch, tref - tglitch)
+            theta_post_glitch_at_glitch, tref - tglitch
+        )
 
         twoFsegA = self.get_fullycoherent_twoF(
-            self.minStartTime, tglitch, theta[0], theta[1], theta[2], Alpha,
-            Delta)
+            self.minStartTime, tglitch, theta[0], theta[1], theta[2], Alpha, Delta
+        )
 
         if tglitch == self.maxStartTime:
             return twoFsegA
 
         twoFsegB = self.get_fullycoherent_twoF(
-            tglitch, self.maxStartTime, theta_post_glitch[0],
-            theta_post_glitch[1], theta_post_glitch[2], Alpha,
-            Delta)
+            tglitch,
+            self.maxStartTime,
+            theta_post_glitch[0],
+            theta_post_glitch[1],
+            theta_post_glitch[2],
+            Alpha,
+            Delta,
+        )
 
         return twoFsegA + twoFsegB
diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py
index 0d61e3f41d09b07182a755d7da83f45e45f395ef..46d0a9ed7b5b4483af88cd73393af7889196d12a 100644
--- a/pyfstat/grid_based_searches.py
+++ b/pyfstat/grid_based_searches.py
@@ -15,28 +15,70 @@ import matplotlib.pyplot as plt
 from scipy.special import logsumexp
 
 import pyfstat.helper_functions as helper_functions
-from pyfstat.core import (BaseSearchClass, ComputeFstat,
-                          SemiCoherentGlitchSearch, SemiCoherentSearch, tqdm,
-                          args, read_par)
+from pyfstat.core import (
+    BaseSearchClass,
+    ComputeFstat,
+    SemiCoherentGlitchSearch,
+    SemiCoherentSearch,
+    tqdm,
+    args,
+    read_par,
+)
 import lalpulsar
 import lal
 
 
 class GridSearch(BaseSearchClass):
     """ Gridded search using ComputeFstat """
-    tex_labels = {'F0': '$f$', 'F1': '$\dot{f}$', 'F2': '$\ddot{f}$',
-                  'Alpha': r'$\alpha$', 'Delta': r'$\delta$'}
-    tex_labels0 = {'F0': '$-f_0$', 'F1': '$-\dot{f}_0$', 'F2': '$-\ddot{f}_0$',
-                   'Alpha': r'$-\alpha_0$', 'Delta': r'$-\delta_0$'}
-    search_labels = ['minStartTime', 'maxStartTime', 'F0s', 'F1s', 'F2s',
-                     'Alphas', 'Deltas']
+
+    tex_labels = {
+        "F0": "$f$",
+        "F1": "$\dot{f}$",
+        "F2": "$\ddot{f}$",
+        "Alpha": r"$\alpha$",
+        "Delta": r"$\delta$",
+    }
+    tex_labels0 = {
+        "F0": "$-f_0$",
+        "F1": "$-\dot{f}_0$",
+        "F2": "$-\ddot{f}_0$",
+        "Alpha": r"$-\alpha_0$",
+        "Delta": r"$-\delta_0$",
+    }
+    search_labels = [
+        "minStartTime",
+        "maxStartTime",
+        "F0s",
+        "F1s",
+        "F2s",
+        "Alphas",
+        "Deltas",
+    ]
 
     @helper_functions.initializer
-    def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
-                 Deltas, tref=None, minStartTime=None, maxStartTime=None,
-                 nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
-                 detectors=None, SSBprec=None, injectSources=None,
-                 input_arrays=False, assumeSqrtSX=None):
+    def __init__(
+        self,
+        label,
+        outdir,
+        sftfilepattern,
+        F0s,
+        F1s,
+        F2s,
+        Alphas,
+        Deltas,
+        tref=None,
+        minStartTime=None,
+        maxStartTime=None,
+        nsegs=1,
+        BSGL=False,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        detectors=None,
+        SSBprec=None,
+        injectSources=None,
+        input_arrays=False,
+        assumeSqrtSX=None,
+    ):
         """
         Parameters
         ----------
@@ -64,34 +106,47 @@ class GridSearch(BaseSearchClass):
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
         self.set_out_file()
-        self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
-        self.search_keys = [x+'s' for x in self.keys[2:]]
+        self.keys = ["_", "_", "F0", "F1", "F2", "Alpha", "Delta"]
+        self.search_keys = [x + "s" for x in self.keys[2:]]
         for k in self.search_keys:
             setattr(self, k, np.atleast_1d(getattr(self, k)))
 
     def inititate_search_object(self):
-        logging.info('Setting up search object')
+        logging.info("Setting up search object")
         if self.nsegs == 1:
             self.search = ComputeFstat(
-                tref=self.tref, sftfilepattern=self.sftfilepattern,
-                minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
+                tref=self.tref,
+                sftfilepattern=self.sftfilepattern,
+                minCoverFreq=self.minCoverFreq,
+                maxCoverFreq=self.maxCoverFreq,
                 detectors=self.detectors,
-                minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
-                BSGL=self.BSGL, SSBprec=self.SSBprec,
+                minStartTime=self.minStartTime,
+                maxStartTime=self.maxStartTime,
+                BSGL=self.BSGL,
+                SSBprec=self.SSBprec,
                 injectSources=self.injectSources,
-                assumeSqrtSX=self.assumeSqrtSX)
+                assumeSqrtSX=self.assumeSqrtSX,
+            )
             self.search.get_det_stat = self.search.get_fullycoherent_twoF
         else:
             self.search = SemiCoherentSearch(
-                label=self.label, outdir=self.outdir, tref=self.tref,
-                nsegs=self.nsegs, sftfilepattern=self.sftfilepattern,
-                BSGL=self.BSGL, minStartTime=self.minStartTime,
-                maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
-                maxCoverFreq=self.maxCoverFreq, detectors=self.detectors,
-                injectSources=self.injectSources)
+                label=self.label,
+                outdir=self.outdir,
+                tref=self.tref,
+                nsegs=self.nsegs,
+                sftfilepattern=self.sftfilepattern,
+                BSGL=self.BSGL,
+                minStartTime=self.minStartTime,
+                maxStartTime=self.maxStartTime,
+                minCoverFreq=self.minCoverFreq,
+                maxCoverFreq=self.maxCoverFreq,
+                detectors=self.detectors,
+                injectSources=self.injectSources,
+            )
 
             def cut_out_tstart_tend(*vals):
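+                # vals[0] and vals[1] are (minStartTime, maxStartTime), which the
+                # semi-coherent search fixes internally, so only pass on the
+                # Doppler parameters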
                 return self.search.get_semicoherent_twoF(*vals[2:])
+
             self.search.get_det_stat = cut_out_tstart_tend
 
     def get_array_from_tuple(self, x):
@@ -100,7 +155,7 @@ class GridSearch(BaseSearchClass):
         elif len(x) == 3 and self.input_arrays is False:
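+            # interpret a length-3 tuple as (start, stop, step) for np.arange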
             return np.arange(x[0], x[1], x[2])
         else:
-            logging.info('Using tuple as is')
+            logging.info("Using tuple as is")
             return np.array(x)
 
     def get_input_data_array(self):
@@ -108,14 +163,15 @@ class GridSearch(BaseSearchClass):
         coord_arrays = []
         for sl in self.search_labels:
             coord_arrays.append(
-                self.get_array_from_tuple(np.atleast_1d(getattr(self, sl))))
+                self.get_array_from_tuple(np.atleast_1d(getattr(self, sl)))
+            )
         self.coord_arrays = coord_arrays
         self.total_iterations = np.prod([len(ca) for ca in coord_arrays])
 
         if args.clean is False:
             input_data = []
             for vals in itertools.product(*coord_arrays):
-                    input_data.append(vals)
+                input_data.append(vals)
             self.input_data = np.array(input_data)
 
     def check_old_data_is_okay_to_use(self):
@@ -123,28 +179,37 @@ class GridSearch(BaseSearchClass):
             return False
         if os.path.isfile(self.out_file) is False:
             logging.info(
-                'No old data found in "{:s}", continuing with grid search'
-                .format(self.out_file))
+                'No old data found in "{:s}", continuing with grid search'.format(
+                    self.out_file
+                )
+            )
             return False
         if self.sftfilepattern is not None:
-            oldest_sft = min([os.path.getmtime(f) for f in
-                              self._get_list_of_matching_sfts()])
+            oldest_sft = min(
+                [os.path.getmtime(f) for f in self._get_list_of_matching_sfts()]
+            )
             if os.path.getmtime(self.out_file) < oldest_sft:
-                logging.info('Search output data outdates sft files,'
-                             + ' continuing with grid search')
+                logging.info(
+                    "Search output data is older than the sft files,"
+                    + " continuing with grid search"
+                )
                 return False
 
-        data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=' '))
-        if np.all(data[:, 0: len(self.coord_arrays)] ==
-                  self.input_data[:, 0:len(self.coord_arrays)]):
+        data = np.atleast_2d(np.genfromtxt(self.out_file, delimiter=" "))
+        if np.all(
+            data[:, 0 : len(self.coord_arrays)]
+            == self.input_data[:, 0 : len(self.coord_arrays)]
+        ):
             logging.info(
                 'Old data found in "{:s}" with matching input, no search '
-                'performed'.format(self.out_file))
+                "performed".format(self.out_file)
+            )
             return data
         else:
             logging.info(
                 'Old data found in "{:s}", input differs, continuing with '
-                'grid search'.format(self.out_file))
+                "grid search".format(self.out_file)
+            )
             return False
         return False
 
@@ -161,12 +226,11 @@ class GridSearch(BaseSearchClass):
                 self.data = old_data
                 return
 
-        if hasattr(self, 'search') is False:
+        if hasattr(self, "search") is False:
             self.inititate_search_object()
 
         data = []
-        for vals in tqdm(iterable,
-                         total=getattr(self, 'total_iterations', None)):
+        for vals in tqdm(iterable, total=getattr(self, "total_iterations", None)):
             detstat = self.search.get_det_stat(*vals)
             thisCand = list(vals) + [detstat]
             data.append(thisCand)
@@ -179,28 +243,32 @@ class GridSearch(BaseSearchClass):
             self.data = data
 
     def get_header(self):
-        header = ';'.join(['date:{}'.format(str(datetime.datetime.now())),
-                           'user:{}'.format(getpass.getuser()),
-                           'hostname:{}'.format(socket.gethostname())])
-        header += '\n' + ' '.join(self.keys)
+        header = ";".join(
+            [
+                "date:{}".format(str(datetime.datetime.now())),
+                "user:{}".format(getpass.getuser()),
+                "hostname:{}".format(socket.gethostname()),
+            ]
+        )
+        header += "\n" + " ".join(self.keys)
         return header
 
     def save_array_to_disk(self, data):
-        logging.info('Saving data to {}'.format(self.out_file))
+        logging.info("Saving data to {}".format(self.out_file))
         header = self.get_header()
-        np.savetxt(self.out_file, data, delimiter=' ', header=header)
+        np.savetxt(self.out_file, data, delimiter=" ", header=header)
 
     def convert_F0_to_mismatch(self, F0, F0hat, Tseg):
         DeltaF0 = F0[1] - F0[0]
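+        # metric mismatch per F0 grid spacing: (pi * Tseg * dF0)^2 / 12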
-        m_spacing = (np.pi*Tseg*DeltaF0)**2 / 12.
+        m_spacing = (np.pi * Tseg * DeltaF0) ** 2 / 12.0
         N = len(F0)
-        return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing)
+        return np.arange(-N * m_spacing / 2.0, N * m_spacing / 2.0, m_spacing)
 
     def convert_F1_to_mismatch(self, F1, F1hat, Tseg):
         DeltaF1 = F1[1] - F1[0]
-        m_spacing = (np.pi*Tseg**2*DeltaF1)**2 / 720.
+        m_spacing = (np.pi * Tseg ** 2 * DeltaF1) ** 2 / 720.0
         N = len(F1)
-        return np.arange(-N*m_spacing/2., N*m_spacing/2., m_spacing)
+        return np.arange(-N * m_spacing / 2.0, N * m_spacing / 2.0, m_spacing)
 
     def add_mismatch_to_ax(self, ax, x, y, xkey, ykey, xhat, yhat, Tseg):
         axX = ax.twiny()
@@ -208,16 +276,24 @@ class GridSearch(BaseSearchClass):
         axY = ax.twinx()
         axY.zorder = -10
 
-        if xkey == 'F0':
+        if xkey == "F0":
             m = self.convert_F0_to_mismatch(x, xhat, Tseg)
             axX.set_xlim(m[0], m[-1])
 
-        if ykey == 'F1':
+        if ykey == "F1":
             m = self.convert_F1_to_mismatch(y, yhat, Tseg)
             axY.set_ylim(m[0], m[-1])
 
-    def plot_1D(self, xkey, ax=None, x0=None, xrescale=1, savefig=True,
-                xlabel=None, ylabel='$\widetilde{2\mathcal{F}}$'):
+    def plot_1D(
+        self,
+        xkey,
+        ax=None,
+        x0=None,
+        xrescale=1,
+        savefig=True,
+        xlabel=None,
+        ylabel="$\widetilde{2\mathcal{F}}$",
+    ):
         if ax is None:
             fig, ax = plt.subplots()
         xidx = self.keys.index(xkey)
@@ -228,7 +304,7 @@ class GridSearch(BaseSearchClass):
         z = self.data[:, -1]
         ax.plot(x, z)
         if x0:
-            ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey])
+            ax.set_xlabel(self.tex_labels[xkey] + self.tex_labels0[xkey])
         else:
             ax.set_xlabel(self.tex_labels[xkey])
 
@@ -238,15 +314,34 @@ class GridSearch(BaseSearchClass):
         ax.set_ylabel(ylabel)
         if savefig:
             fig.tight_layout()
-            fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
+            fig.savefig("{}/{}_1D.png".format(self.outdir, self.label))
         else:
             return ax
 
-    def plot_2D(self, xkey, ykey, ax=None, save=True, vmin=None, vmax=None,
-                add_mismatch=None, xN=None, yN=None, flat_keys=[],
-                rel_flat_idxs=[], flatten_method=np.max, title=None,
-                predicted_twoF=None, cm=None, cbarkwargs={}, x0=None, y0=None,
-                colorbar=False, xrescale=1, yrescale=1):
+    def plot_2D(
+        self,
+        xkey,
+        ykey,
+        ax=None,
+        save=True,
+        vmin=None,
+        vmax=None,
+        add_mismatch=None,
+        xN=None,
+        yN=None,
+        flat_keys=[],
+        rel_flat_idxs=[],
+        flatten_method=np.max,
+        title=None,
+        predicted_twoF=None,
+        cm=None,
+        cbarkwargs={},
+        x0=None,
+        y0=None,
+        colorbar=False,
+        xrescale=1,
+        yrescale=1,
+    ):
         """ Plots a 2D grid of 2F values
 
         Parameters
@@ -265,10 +360,10 @@ class GridSearch(BaseSearchClass):
 
         x = np.unique(self.data[:, xidx])
         if x0:
-            x = x-x0
+            x = x - x0
         y = np.unique(self.data[:, yidx])
         if y0:
-            y = y-y0
+            y = y - y0
         flat_vals = [np.unique(self.data[:, j]) for j in flat_idxs]
         z = self.data[:, -1]
 
@@ -288,22 +383,23 @@ class GridSearch(BaseSearchClass):
                 cm = plt.cm.viridis
 
         pax = ax.pcolormesh(
-            X*xrescale, Y*yrescale, Z, cmap=cm, vmin=vmin, vmax=vmax)
+            X * xrescale, Y * yrescale, Z, cmap=cm, vmin=vmin, vmax=vmax
+        )
         if colorbar:
             cb = plt.colorbar(pax, ax=ax, **cbarkwargs)
-            cb.set_label('$2\mathcal{F}$')
+            cb.set_label("$2\mathcal{F}$")
 
         if add_mismatch:
             self.add_mismatch_to_ax(ax, x, y, xkey, ykey, *add_mismatch)
 
-        ax.set_xlim(x[0]*xrescale, x[-1]*xrescale)
-        ax.set_ylim(y[0]*yrescale, y[-1]*yrescale)
+        ax.set_xlim(x[0] * xrescale, x[-1] * xrescale)
+        ax.set_ylim(y[0] * yrescale, y[-1] * yrescale)
         if x0:
-            ax.set_xlabel(self.tex_labels[xkey]+self.tex_labels0[xkey])
+            ax.set_xlabel(self.tex_labels[xkey] + self.tex_labels0[xkey])
         else:
             ax.set_xlabel(self.tex_labels[xkey])
         if y0:
-            ax.set_ylabel(self.tex_labels[ykey]+self.tex_labels0[ykey])
+            ax.set_ylabel(self.tex_labels[ykey] + self.tex_labels0[ykey])
         else:
             ax.set_ylabel(self.tex_labels[ykey])
 
@@ -317,7 +413,7 @@ class GridSearch(BaseSearchClass):
 
         if save:
             fig.tight_layout()
-            fig.savefig('{}/{}_2D.png'.format(self.outdir, self.label))
+            fig.savefig("{}/{}_2D.png".format(self.outdir, self.label))
         else:
             return ax
 
@@ -335,45 +431,75 @@ class GridSearch(BaseSearchClass):
         twoF = self.data[:, -1]
         idx = np.argmax(twoF)
         v = self.data[idx, :]
-        d = OrderedDict(minStartTime=v[0], maxStartTime=v[1], F0=v[2], F1=v[3],
-                        F2=v[4], Alpha=v[5], Delta=v[6], twoF=v[7])
+        d = OrderedDict(
+            minStartTime=v[0],
+            maxStartTime=v[1],
+            F0=v[2],
+            F1=v[3],
+            F2=v[4],
+            Alpha=v[5],
+            Delta=v[6],
+            twoF=v[7],
+        )
         return d
 
     def print_max_twoF(self):
         d = self.get_max_twoF()
-        print('Max twoF values for {}:'.format(self.label))
+        print("Max twoF values for {}:".format(self.label))
         for k, v in d.items():
-            print('  {}={}'.format(k, v))
+            print("  {}={}".format(k, v))
 
     def set_out_file(self, extra_label=None):
         if self.detectors:
-            dets = self.detectors.replace(',', '')
+            dets = self.detectors.replace(",", "")
         else:
-            dets = 'NA'
+            dets = "NA"
         if extra_label:
-            self.out_file = '{}/{}_{}_{}_{}.txt'.format(
-                self.outdir, self.label, dets, type(self).__name__,
-                extra_label)
+            self.out_file = "{}/{}_{}_{}_{}.txt".format(
+                self.outdir, self.label, dets, type(self).__name__, extra_label
+            )
         else:
-            self.out_file = '{}/{}_{}_{}.txt'.format(
-                self.outdir, self.label, dets, type(self).__name__)
+            self.out_file = "{}/{}_{}_{}.txt".format(
+                self.outdir, self.label, dets, type(self).__name__
+            )
 
 
 class TransientGridSearch(GridSearch):
     """ Gridded transient-continous search using ComputeFstat """
 
     @helper_functions.initializer
-    def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
-                 Deltas, tref=None, minStartTime=None, maxStartTime=None,
-                 BSGL=False, minCoverFreq=None, maxCoverFreq=None,
-                 detectors=None, SSBprec=None, injectSources=None,
-                 input_arrays=False, assumeSqrtSX=None,
-                 transientWindowType=None, t0Band=None, tauBand=None,
-                 tauMin = None,
-                 dt0=None, dtau=None,
-                 outputTransientFstatMap=False,
-                 outputAtoms=False,
-                 tCWFstatMapVersion='lal', cudaDeviceName=None):
+    def __init__(
+        self,
+        label,
+        outdir,
+        sftfilepattern,
+        F0s,
+        F1s,
+        F2s,
+        Alphas,
+        Deltas,
+        tref=None,
+        minStartTime=None,
+        maxStartTime=None,
+        BSGL=False,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        detectors=None,
+        SSBprec=None,
+        injectSources=None,
+        input_arrays=False,
+        assumeSqrtSX=None,
+        transientWindowType=None,
+        t0Band=None,
+        tauBand=None,
+        tauMin=None,
+        dt0=None,
+        dtau=None,
+        outputTransientFstatMap=False,
+        outputAtoms=False,
+        tCWFstatMapVersion="lal",
+        cudaDeviceName=None,
+    ):
         """
         Parameters
         ----------
@@ -421,27 +547,34 @@ class TransientGridSearch(GridSearch):
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
         self.set_out_file()
-        self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
-        self.search_keys = [x+'s' for x in self.keys[2:]]
+        self.keys = ["_", "_", "F0", "F1", "F2", "Alpha", "Delta"]
+        self.search_keys = [x + "s" for x in self.keys[2:]]
         for k in self.search_keys:
             setattr(self, k, np.atleast_1d(getattr(self, k)))
 
     def inititate_search_object(self):
-        logging.info('Setting up search object')
+        logging.info("Setting up search object")
         self.search = ComputeFstat(
-            tref=self.tref, sftfilepattern=self.sftfilepattern,
-            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
+            tref=self.tref,
+            sftfilepattern=self.sftfilepattern,
+            minCoverFreq=self.minCoverFreq,
+            maxCoverFreq=self.maxCoverFreq,
             detectors=self.detectors,
             transientWindowType=self.transientWindowType,
-            t0Band=self.t0Band, tauBand=self.tauBand,
+            t0Band=self.t0Band,
+            tauBand=self.tauBand,
             tauMin=self.tauMin,
-            dt0=self.dt0, dtau=self.dtau,
-            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
-            BSGL=self.BSGL, SSBprec=self.SSBprec,
+            dt0=self.dt0,
+            dtau=self.dtau,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+            BSGL=self.BSGL,
+            SSBprec=self.SSBprec,
             injectSources=self.injectSources,
             assumeSqrtSX=self.assumeSqrtSX,
             tCWFstatMapVersion=self.tCWFstatMapVersion,
-            cudaDeviceName=self.cudaDeviceName)
+            cudaDeviceName=self.cudaDeviceName,
+        )
         self.search.get_det_stat = self.search.get_fullycoherent_twoF
 
     def run(self, return_data=False):
@@ -451,49 +584,62 @@ class TransientGridSearch(GridSearch):
             self.data = old_data
             return
 
-        if hasattr(self, 'search') is False:
+        if hasattr(self, "search") is False:
             self.inititate_search_object()
 
         data = []
         if self.outputTransientFstatMap:
-            tCWfilebase = os.path.splitext(self.out_file)[0] + '_tCW_'
-            logging.info('Will save per-Doppler Fstatmap' \
-                         ' results to {}*.dat'.format(tCWfilebase))
-        self.timingFstatMap = 0.
+            tCWfilebase = os.path.splitext(self.out_file)[0] + "_tCW_"
+            logging.info(
+                "Will save per-Doppler Fstatmap"
+                " results to {}*.dat".format(tCWfilebase)
+            )
+        self.timingFstatMap = 0.0
         for vals in tqdm(self.input_data):
             detstat = self.search.get_det_stat(*vals)
-            windowRange = getattr(self.search, 'windowRange', None)
-            FstatMap = getattr(self.search, 'FstatMap', None)
-            self.timingFstatMap += getattr(self.search, 'timingFstatMap', None)
+            windowRange = getattr(self.search, "windowRange", None)
+            FstatMap = getattr(self.search, "FstatMap", None)
+            # default to 0.0 rather than None, so the accumulation below
+            # cannot fail if the search has not set timingFstatMap
+            self.timingFstatMap += getattr(self.search, "timingFstatMap", 0.0)
             thisCand = list(vals) + [detstat]
-            if getattr(self, 'transientWindowType', None):
-                if self.tCWFstatMapVersion == 'lal':
+            if getattr(self, "transientWindowType", None):
+                if self.tCWFstatMapVersion == "lal":
                     F_mn = FstatMap.F_mn.data
                 else:
                     F_mn = FstatMap.F_mn
                 if self.outputTransientFstatMap:
                     # per-Doppler filename convention:
                     # freq alpha delta f1dot f2dot
-                    tCWfile = ( tCWfilebase
-                                + '%.16f_%.16f_%.16f_%.16g_%.16g.dat' %
-                                (vals[2],vals[5],vals[6],vals[3],vals[4]) )
-                    if self.tCWFstatMapVersion == 'lal':
-                        fo = lal.FileOpen(tCWfile, 'w')
-                        lalpulsar.write_transientFstatMap_to_fp (
-                            fo, FstatMap, windowRange, None )
+                    tCWfile = tCWfilebase + "%.16f_%.16f_%.16f_%.16g_%.16g.dat" % (
+                        vals[2],
+                        vals[5],
+                        vals[6],
+                        vals[3],
+                        vals[4],
+                    )
+                    if self.tCWFstatMapVersion == "lal":
+                        fo = lal.FileOpen(tCWfile, "w")
+                        lalpulsar.write_transientFstatMap_to_fp(
+                            fo, FstatMap, windowRange, None
+                        )
                         # instead of lal.FileClose(),
                         # which is not SWIG-exported:
                         del fo
                     else:
-                        self.write_F_mn ( tCWfile, F_mn, windowRange)
+                        self.write_F_mn(tCWfile, F_mn, windowRange)
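+                # record the (t0, tau) of the loudest point on the transient F-stat map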
                 maxidx = np.unravel_index(F_mn.argmax(), F_mn.shape)
-                thisCand += [windowRange.t0+maxidx[0]*windowRange.dt0,
-                             windowRange.tau+maxidx[1]*windowRange.dtau]
+                thisCand += [
+                    windowRange.t0 + maxidx[0] * windowRange.dt0,
+                    windowRange.tau + maxidx[1] * windowRange.dtau,
+                ]
             data.append(thisCand)
             if self.outputAtoms:
                 self.search.write_atoms_to_file(os.path.splitext(self.out_file)[0])
 
-        logging.info('Total time spent computing transient F-stat maps: {:.2f}s'.format(self.timingFstatMap))
+        logging.info(
+            "Total time spent computing transient F-stat maps: {:.2f}s".format(
+                self.timingFstatMap
+            )
+        )
 
         data = np.array(data, dtype=np.float)
         if return_data:
@@ -502,28 +648,50 @@ class TransientGridSearch(GridSearch):
             self.save_array_to_disk(data)
             self.data = data
 
-    def write_F_mn (self, tCWfile, F_mn, windowRange ):
-        with open(tCWfile, 'w') as tfp:
-            tfp.write('# t0 [s]     tau [s]     2F\n')
+    def write_F_mn(self, tCWfile, F_mn, windowRange):
+        with open(tCWfile, "w") as tfp:
+            tfp.write("# t0 [s]     tau [s]     2F\n")
             for m, F_m in enumerate(F_mn):
                 this_t0 = windowRange.t0 + m * windowRange.dt0
                 for n, this_F in enumerate(F_m):
-                    this_tau = windowRange.tau + n * windowRange.dtau;
-                    tfp.write('  %10d %10d %- 11.8g\n' % (this_t0, this_tau, 2.0*this_F))
+                    this_tau = windowRange.tau + n * windowRange.dtau
+                    tfp.write(
+                        "  %10d %10d %- 11.8g\n" % (this_t0, this_tau, 2.0 * this_F)
+                    )
 
     def __del__(self):
-        if hasattr(self,'search'):
+        if hasattr(self, "search"):
             self.search.__del__()
 
 
 class SliceGridSearch(GridSearch):
     """ Slice gridded search using ComputeFstat """
+
     @helper_functions.initializer
-    def __init__(self, label, outdir, sftfilepattern, F0s, F1s, F2s, Alphas,
-                 Deltas, tref=None, minStartTime=None, maxStartTime=None,
-                 nsegs=1, BSGL=False, minCoverFreq=None, maxCoverFreq=None,
-                 detectors=None, SSBprec=None, injectSources=None,
-                 input_arrays=False, assumeSqrtSX=None, Lambda0=None):
+    def __init__(
+        self,
+        label,
+        outdir,
+        sftfilepattern,
+        F0s,
+        F1s,
+        F2s,
+        Alphas,
+        Deltas,
+        tref=None,
+        minStartTime=None,
+        maxStartTime=None,
+        nsegs=1,
+        BSGL=False,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        detectors=None,
+        SSBprec=None,
+        injectSources=None,
+        input_arrays=False,
+        assumeSqrtSX=None,
+        Lambda0=None,
+    ):
         """
         Parameters
         ----------
@@ -547,24 +715,24 @@ class SliceGridSearch(GridSearch):
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
         self.set_out_file()
-        self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
+        self.keys = ["_", "_", "F0", "F1", "F2", "Alpha", "Delta"]
         self.ndim = 0
         self.thetas = [F0s, F1s, Alphas, Deltas]
         self.ndim = 4
 
-        self.search_keys = ['F0', 'F1', 'Alpha', 'Delta']
+        self.search_keys = ["F0", "F1", "Alpha", "Delta"]
         if self.Lambda0 is None:
-            raise ValueError('Lambda0 undefined')
+            raise ValueError("Lambda0 undefined")
         if len(self.Lambda0) != len(self.search_keys):
             raise ValueError(
-                'Lambda0 must be of length {}'.format(len(self.search_keys)))
+                "Lambda0 must be of length {}".format(len(self.search_keys))
+            )
         self.Lambda0 = np.array(Lambda0)
 
-    def run(self, factor=2, max_n_ticks=4, whspace=0.07, save=True,
-            **kwargs):
-        lbdim = 0.5 * factor   # size of left/bottom margin
-        trdim = 0.4 * factor   # size of top/right margin
-        plotdim = factor * self.ndim + factor * (self.ndim - 1.) * whspace
+    def run(self, factor=2, max_n_ticks=4, whspace=0.07, save=True, **kwargs):
+        lbdim = 0.5 * factor  # size of left/bottom margin
+        trdim = 0.4 * factor  # size of top/right margin
+        plotdim = factor * self.ndim + factor * (self.ndim - 1.0) * whspace
         dim = lbdim + plotdim + trdim
 
         fig, axes = plt.subplots(self.ndim, self.ndim, figsize=(dim, dim))
@@ -572,27 +740,36 @@ class SliceGridSearch(GridSearch):
         # Format the figure.
         lb = lbdim / dim
         tr = (lbdim + plotdim) / dim
-        fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
-                            wspace=whspace, hspace=whspace)
+        fig.subplots_adjust(
+            left=lb, bottom=lb, right=tr, top=tr, wspace=whspace, hspace=whspace
+        )
 
         search = GridSearch(
-            self.label, self.outdir, self.sftfilepattern,
-            F0s=self.Lambda0[0], F1s=self.Lambda0[1], F2s=self.F2s[0],
-            Alphas=self.Lambda0[2], Deltas=self.Lambda0[3], tref=self.tref,
-            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime)
+            self.label,
+            self.outdir,
+            self.sftfilepattern,
+            F0s=self.Lambda0[0],
+            F1s=self.Lambda0[1],
+            F2s=self.F2s[0],
+            Alphas=self.Lambda0[2],
+            Deltas=self.Lambda0[3],
+            tref=self.tref,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+        )
 
         for i, ikey in enumerate(self.search_keys):
-            setattr(search, ikey+'s', self.thetas[i])
-            search.label = '{}_{}'.format(self.label, ikey)
+            setattr(search, ikey + "s", self.thetas[i])
+            search.label = "{}_{}".format(self.label, ikey)
             search.set_out_file()
             search.run()
-            axes[i, i] = search.plot_1D(ikey, ax=axes[i, i], savefig=False,
-                                        x0=self.Lambda0[i]
-                                        )
-            setattr(search, ikey+'s', [self.Lambda0[i]])
+            axes[i, i] = search.plot_1D(
+                ikey, ax=axes[i, i], savefig=False, x0=self.Lambda0[i]
+            )
+            setattr(search, ikey + "s", [self.Lambda0[i]])
             axes[i, i].yaxis.tick_right()
             axes[i, i].yaxis.set_label_position("right")
-            axes[i, i].set_xlabel('')
+            axes[i, i].set_xlabel("")
 
             for j, jkey in enumerate(self.search_keys):
                 ax = axes[i, j]
@@ -603,87 +780,143 @@ class SliceGridSearch(GridSearch):
                     ax.set_yticks([])
                     continue
 
-                ax.get_shared_x_axes().join(axes[self.ndim-1, j], ax)
+                ax.get_shared_x_axes().join(axes[self.ndim - 1, j], ax)
                 if i < self.ndim - 1:
                     ax.set_xticklabels([])
                 if j < i:
-                    ax.get_shared_y_axes().join(axes[i, i-1], ax)
+                    ax.get_shared_y_axes().join(axes[i, i - 1], ax)
                     if j > 0:
                         ax.set_yticklabels([])
                 if j == i:
                     continue
 
                 ax.xaxis.set_major_locator(
-                    matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
+                    matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper")
+                )
                 ax.yaxis.set_major_locator(
-                    matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper"))
+                    matplotlib.ticker.MaxNLocator(max_n_ticks, prune="upper")
+                )
 
-                setattr(search, ikey+'s', self.thetas[i])
-                setattr(search, jkey+'s', self.thetas[j])
-                search.label = '{}_{}'.format(self.label, ikey+jkey)
+                setattr(search, ikey + "s", self.thetas[i])
+                setattr(search, jkey + "s", self.thetas[j])
+                search.label = "{}_{}".format(self.label, ikey + jkey)
                 search.set_out_file()
                 search.run()
-                ax = search.plot_2D(jkey, ikey, ax=ax, save=False,
-                                    y0=self.Lambda0[i], x0=self.Lambda0[j],
-                                    **kwargs)
-                setattr(search, ikey+'s', [self.Lambda0[i]])
-                setattr(search, jkey+'s', [self.Lambda0[j]])
-
-                ax.grid(lw=0.2, ls='--', zorder=10)
-                ax.set_xlabel('')
-                ax.set_ylabel('')
+                ax = search.plot_2D(
+                    jkey,
+                    ikey,
+                    ax=ax,
+                    save=False,
+                    y0=self.Lambda0[i],
+                    x0=self.Lambda0[j],
+                    **kwargs
+                )
+                setattr(search, ikey + "s", [self.Lambda0[i]])
+                setattr(search, jkey + "s", [self.Lambda0[j]])
+
+                ax.grid(lw=0.2, ls="--", zorder=10)
+                ax.set_xlabel("")
+                ax.set_ylabel("")
 
         for i, ikey in enumerate(self.search_keys):
-            axes[-1, i].set_xlabel(
-                self.tex_labels[ikey]+self.tex_labels0[ikey])
+            axes[-1, i].set_xlabel(self.tex_labels[ikey] + self.tex_labels0[ikey])
             if i > 0:
-                axes[i, 0].set_ylabel(
-                    self.tex_labels[ikey]+self.tex_labels0[ikey])
+                axes[i, 0].set_ylabel(self.tex_labels[ikey] + self.tex_labels0[ikey])
             axes[i, i].set_ylabel("$2\mathcal{F}$")
 
         if save:
-            fig.savefig(
-                '{}/{}_slice_projection.png'.format(self.outdir, self.label))
+            fig.savefig("{}/{}_slice_projection.png".format(self.outdir, self.label))
         else:
             return fig, axes
 
 
-class GridUniformPriorSearch():
+class GridUniformPriorSearch:
     @helper_functions.initializer
-    def __init__(self, theta_prior, NF0, NF1, label, outdir, sftfilepattern,
-                 tref, minStartTime, maxStartTime, minCoverFreq=None,
-                 maxCoverFreq=None, BSGL=False, detectors=None, nsegs=1,
-                 SSBprec=None, injectSources=None):
-        dF0 = (theta_prior['F0']['upper'] - theta_prior['F0']['lower'])/NF0
-        dF1 = (theta_prior['F1']['upper'] - theta_prior['F1']['lower'])/NF1
-        F0s = [theta_prior['F0']['lower'], theta_prior['F0']['upper'], dF0]
-        F1s = [theta_prior['F1']['lower'], theta_prior['F1']['upper'], dF1]
+    def __init__(
+        self,
+        theta_prior,
+        NF0,
+        NF1,
+        label,
+        outdir,
+        sftfilepattern,
+        tref,
+        minStartTime,
+        maxStartTime,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        BSGL=False,
+        detectors=None,
+        nsegs=1,
+        SSBprec=None,
+        injectSources=None,
+    ):
+        dF0 = (theta_prior["F0"]["upper"] - theta_prior["F0"]["lower"]) / NF0
+        dF1 = (theta_prior["F1"]["upper"] - theta_prior["F1"]["lower"]) / NF1
+        F0s = [theta_prior["F0"]["lower"], theta_prior["F0"]["upper"], dF0]
+        F1s = [theta_prior["F1"]["lower"], theta_prior["F1"]["upper"], dF1]
         self.search = GridSearch(
-            label, outdir, sftfilepattern, F0s=F0s, F1s=F1s, tref=tref,
-            Alphas=[theta_prior['Alpha']], Deltas=[theta_prior['Delta']],
-            minStartTime=minStartTime, maxStartTime=maxStartTime, BSGL=BSGL,
-            detectors=detectors, minCoverFreq=minCoverFreq,
-            injectSources=injectSources, maxCoverFreq=maxCoverFreq,
-            nsegs=nsegs, SSBprec=SSBprec)
+            label,
+            outdir,
+            sftfilepattern,
+            F0s=F0s,
+            F1s=F1s,
+            tref=tref,
+            Alphas=[theta_prior["Alpha"]],
+            Deltas=[theta_prior["Delta"]],
+            minStartTime=minStartTime,
+            maxStartTime=maxStartTime,
+            BSGL=BSGL,
+            detectors=detectors,
+            minCoverFreq=minCoverFreq,
+            injectSources=injectSources,
+            maxCoverFreq=maxCoverFreq,
+            nsegs=nsegs,
+            SSBprec=SSBprec,
+        )
 
     def run(self):
         self.search.run()
 
     def get_2D_plot(self, **kwargs):
-        return self.search.plot_2D('F0', 'F1', **kwargs)
+        return self.search.plot_2D("F0", "F1", **kwargs)
 
 
 class GridGlitchSearch(GridSearch):
     """ Grid search using the SemiCoherentGlitchSearch """
-    search_labels = ['F0s', 'F1s', 'F2s', 'Alphas', 'Deltas', 'delta_F0s',
-                     'delta_F1s', 'tglitchs']
+
+    search_labels = [
+        "F0s",
+        "F1s",
+        "F2s",
+        "Alphas",
+        "Deltas",
+        "delta_F0s",
+        "delta_F1s",
+        "tglitchs",
+    ]
 
     @helper_functions.initializer
-    def __init__(self, label, outdir='data', sftfilepattern=None, F0s=[0],
-                 F1s=[0], F2s=[0], delta_F0s=[0], delta_F1s=[0], tglitchs=None,
-                 Alphas=[0], Deltas=[0], tref=None, minStartTime=None,
-                 maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
-                 detectors=None):
+    def __init__(
+        self,
+        label,
+        outdir="data",
+        sftfilepattern=None,
+        F0s=[0],
+        F1s=[0],
+        F2s=[0],
+        delta_F0s=[0],
+        delta_F1s=[0],
+        tglitchs=None,
+        Alphas=[0],
+        Deltas=[0],
+        tref=None,
+        minStartTime=None,
+        maxStartTime=None,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        detectors=None,
+    ):
         """
         Run a single-glitch grid search
 
@@ -707,29 +940,60 @@ class GridGlitchSearch(GridSearch):
         self.BSGL = False
         self.input_arrays = False
         if tglitchs is None:
-            raise ValueError('You must specify `tglitchs`')
+            raise ValueError("You must specify `tglitchs`")
 
         self.search = SemiCoherentGlitchSearch(
-            label=label, outdir=outdir, sftfilepattern=self.sftfilepattern,
-            tref=tref, minStartTime=minStartTime, maxStartTime=maxStartTime,
-            minCoverFreq=minCoverFreq, maxCoverFreq=maxCoverFreq,
-            BSGL=self.BSGL)
+            label=label,
+            outdir=outdir,
+            sftfilepattern=self.sftfilepattern,
+            tref=tref,
+            minStartTime=minStartTime,
+            maxStartTime=maxStartTime,
+            minCoverFreq=minCoverFreq,
+            maxCoverFreq=maxCoverFreq,
+            BSGL=self.BSGL,
+        )
         self.search.get_det_stat = self.search.get_semicoherent_nglitch_twoF
 
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
         self.set_out_file()
-        self.keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta', 'delta_F0',
-                     'delta_F1', 'tglitch']
+        self.keys = [
+            "F0",
+            "F1",
+            "F2",
+            "Alpha",
+            "Delta",
+            "delta_F0",
+            "delta_F1",
+            "tglitch",
+        ]
 
 
 class SlidingWindow(GridSearch):
     @helper_functions.initializer
-    def __init__(self, label, outdir, sftfilepattern, F0, F1, F2,
-                 Alpha, Delta, tref, minStartTime=None,
-                 maxStartTime=None, window_size=10*86400, window_delta=86400,
-                 BSGL=False, minCoverFreq=None, maxCoverFreq=None,
-                 detectors=None, SSBprec=None, injectSources=None):
+    def __init__(
+        self,
+        label,
+        outdir,
+        sftfilepattern,
+        F0,
+        F1,
+        F2,
+        Alpha,
+        Delta,
+        tref,
+        minStartTime=None,
+        maxStartTime=None,
+        window_size=10 * 86400,
+        window_delta=86400,
+        BSGL=False,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        detectors=None,
+        SSBprec=None,
+        injectSources=None,
+    ):
         """
         Parameters
         ----------
@@ -753,34 +1017,38 @@ class SlidingWindow(GridSearch):
 
         self.tstarts = [self.minStartTime]
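+        # window start times advance in steps of window_delta while a full window fits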
         while self.tstarts[-1] + self.window_size < self.maxStartTime:
-            self.tstarts.append(self.tstarts[-1]+self.window_delta)
-        self.tmids = np.array(self.tstarts) + .5 * self.window_size
+            self.tstarts.append(self.tstarts[-1] + self.window_delta)
+        self.tmids = np.array(self.tstarts) + 0.5 * self.window_size
 
     def inititate_search_object(self):
-        logging.info('Setting up search object')
+        logging.info("Setting up search object")
         self.search = ComputeFstat(
-            tref=self.tref, sftfilepattern=self.sftfilepattern,
-            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
-            detectors=self.detectors, transient=True,
-            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
-            BSGL=self.BSGL, SSBprec=self.SSBprec,
-            injectSources=self.injectSources)
+            tref=self.tref,
+            sftfilepattern=self.sftfilepattern,
+            minCoverFreq=self.minCoverFreq,
+            maxCoverFreq=self.maxCoverFreq,
+            detectors=self.detectors,
+            transient=True,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+            BSGL=self.BSGL,
+            SSBprec=self.SSBprec,
+            injectSources=self.injectSources,
+        )
 
     def check_old_data_is_okay_to_use(self, out_file):
         if os.path.isfile(out_file):
             tmids, vals, errvals = np.loadtxt(out_file).T
-            if len(tmids) == len(self.tmids) and (
-                    tmids[0] == self.tmids[0]):
+            if len(tmids) == len(self.tmids) and (tmids[0] == self.tmids[0]):
                 self.vals = vals
                 self.errvals = errvals
                 return True
         return False
 
-    def run(self, key='h0', errkey='dh0'):
+    def run(self, key="h0", errkey="dh0"):
         self.key = key
         self.errkey = errkey
-        out_file = '{}/{}_{}-sliding-window.txt'.format(
-            self.outdir, self.label, key)
+        out_file = "{}/{}_{}-sliding-window.txt".format(self.outdir, self.label, key)
 
         if self.check_old_data_is_okay_to_use(out_file) is False:
             self.inititate_search_object()
@@ -788,8 +1056,15 @@ class SlidingWindow(GridSearch):
             errvals = []
             for ts in self.tstarts:
                 loudest = self.search.get_full_CFSv2_output(
-                        ts, ts+self.window_size, self.F0, self.F1, self.F2,
-                        self.Alpha, self.Delta, self.tref)
+                    ts,
+                    ts + self.window_size,
+                    self.F0,
+                    self.F1,
+                    self.F2,
+                    self.Alpha,
+                    self.Delta,
+                    self.tref,
+                )
                 vals.append(loudest[key])
                 errvals.append(loudest[errkey])
 
@@ -800,32 +1075,52 @@ class SlidingWindow(GridSearch):
     def plot_sliding_window(self, factor=1, fig=None, ax=None):
         if ax is None:
             fig, ax = plt.subplots()
-        days = (self.tmids-self.minStartTime) / 86400
-        ax.errorbar(days, self.vals*factor, yerr=self.errvals*factor)
+        days = (self.tmids - self.minStartTime) / 86400
+        ax.errorbar(days, self.vals * factor, yerr=self.errvals * factor)
         ax.set_ylabel(self.key)
         ax.set_xlabel(
-            r'Mid-point (days after $t_\mathrm{{start}}$={})'.format(
-                self.minStartTime))
+            r"Mid-point (days after $t_\mathrm{{start}}$={})".format(self.minStartTime)
+        )
         ax.set_title(
-            'Sliding window of {} days in increments of {} days'
-            .format(self.window_size/86400, self.window_delta/86400),
+            "Sliding window of {} days in increments of {} days".format(
+                self.window_size / 86400, self.window_delta / 86400
             )
+        )
 
         if fig:
-            fig.savefig('{}/{}_{}-sliding-window.png'.format(
-                self.outdir, self.label, self.key))
+            fig.savefig(
+                "{}/{}_{}-sliding-window.png".format(self.outdir, self.label, self.key)
+            )
         else:
             return ax
 
 
 class FrequencySlidingWindow(GridSearch):
     """ A sliding-window search over the Frequency """
+
     @helper_functions.initializer
-    def __init__(self, label, outdir, sftfilepattern, F0s, F1, F2,
-                 Alpha, Delta, tref, minStartTime=None,
-                 maxStartTime=None, window_size=10*86400, window_delta=86400,
-                 BSGL=False, minCoverFreq=None, maxCoverFreq=None,
-                 detectors=None, SSBprec=None, injectSources=None):
+    def __init__(
+        self,
+        label,
+        outdir,
+        sftfilepattern,
+        F0s,
+        F1,
+        F2,
+        Alpha,
+        Delta,
+        tref,
+        minStartTime=None,
+        maxStartTime=None,
+        window_size=10 * 86400,
+        window_delta=86400,
+        BSGL=False,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        detectors=None,
+        SSBprec=None,
+        injectSources=None,
+    ):
         """
         Parameters
         ----------
@@ -844,7 +1139,7 @@ class FrequencySlidingWindow(GridSearch):
         For all other parameters, see `pyfstat.ComputeFStat` for details
         """
 
-        self.transientWindowType = 'rect'
+        self.transientWindowType = "rect"
         self.nsegs = 1
         self.t0Band = None
         self.tauBand = None
@@ -858,28 +1153,32 @@ class FrequencySlidingWindow(GridSearch):
         self.Alphas = [Alpha]
         self.Deltas = [Delta]
         self.input_arrays = False
-        self.keys = ['_', '_', 'F0', 'F1', 'F2', 'Alpha', 'Delta']
+        self.keys = ["_", "_", "F0", "F1", "F2", "Alpha", "Delta"]
 
     def inititate_search_object(self):
-        logging.info('Setting up search object')
+        logging.info("Setting up search object")
         self.search = ComputeFstat(
-            tref=self.tref, sftfilepattern=self.sftfilepattern,
-            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
-            detectors=self.detectors, transientWindowType=self.transientWindowType,
-            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
-            BSGL=self.BSGL, SSBprec=self.SSBprec,
-            injectSources=self.injectSources)
-        self.search.get_det_stat = (
-            self.search.get_fullycoherent_twoF)
+            tref=self.tref,
+            sftfilepattern=self.sftfilepattern,
+            minCoverFreq=self.minCoverFreq,
+            maxCoverFreq=self.maxCoverFreq,
+            detectors=self.detectors,
+            transientWindowType=self.transientWindowType,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+            BSGL=self.BSGL,
+            SSBprec=self.SSBprec,
+            injectSources=self.injectSources,
+        )
+        self.search.get_det_stat = self.search.get_fullycoherent_twoF
 
     def get_input_data_array(self):
         coord_arrays = []
         tstarts = [self.minStartTime]
         while tstarts[-1] + self.window_size < self.maxStartTime:
-            tstarts.append(tstarts[-1]+self.window_delta)
+            tstarts.append(tstarts[-1] + self.window_delta)
         coord_arrays = [tstarts]
-        for tup in (self.F0s, self.F1s, self.F2s,
-                    self.Alphas, self.Deltas):
+        for tup in (self.F0s, self.F1s, self.F2s, self.Alphas, self.Deltas):
             coord_arrays.append(self.get_array_from_tuple(tup))
 
         input_data = []
@@ -888,14 +1187,22 @@ class FrequencySlidingWindow(GridSearch):
 
         input_data = np.array(input_data)
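+        # insert each window's end time (tstart + window_size) as the second column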
         input_data = np.insert(
-            input_data, 1, input_data[:, 0] + self.window_size, axis=1)
+            input_data, 1, input_data[:, 0] + self.window_size, axis=1
+        )
 
         self.coord_arrays = coord_arrays
         self.input_data = np.array(input_data)
 
-    def plot_sliding_window(self, F0=None, ax=None, savefig=True,
-                            colorbar=True, timestamps=False,
-                            F0rescale=1, **kwargs):
+    def plot_sliding_window(
+        self,
+        F0=None,
+        ax=None,
+        savefig=True,
+        colorbar=True,
+        timestamps=False,
+        F0rescale=1,
+        **kwargs
+    ):
         data = self.data
         if ax is None:
             ax = plt.subplot()
@@ -904,51 +1211,71 @@ class FrequencySlidingWindow(GridSearch):
         frequencies = np.unique(data[:, 2])
         twoF = data[:, -1]
         tmids = (tstarts + tends) / 2.0
-        dts = (tmids - self.minStartTime) / 86400.
+        dts = (tmids - self.minStartTime) / 86400.0
         if F0:
             frequencies = frequencies - F0
-            ax.set_ylabel('Frequency - $f_0$ [Hz] \n $f_0={:0.2f}$'.format(F0))
+            ax.set_ylabel("Frequency - $f_0$ [Hz] \n $f_0={:0.2f}$".format(F0))
         else:
-            ax.set_ylabel('Frequency [Hz]')
+            ax.set_ylabel("Frequency [Hz]")
         twoF = twoF.reshape((len(tmids), len(frequencies)))
         Y, X = np.meshgrid(frequencies, dts)
-        pax = ax.pcolormesh(X, Y*F0rescale, twoF, **kwargs)
+        pax = ax.pcolormesh(X, Y * F0rescale, twoF, **kwargs)
         if colorbar:
             cb = plt.colorbar(pax, ax=ax)
-            cb.set_label('$2\mathcal{F}$')
+            cb.set_label("$2\mathcal{F}$")
         ax.set_xlabel(
-            r'Mid-point (days after $t_\mathrm{{start}}$={})'.format(
-                self.minStartTime))
+            r"Mid-point (days after $t_\mathrm{{start}}$={})".format(self.minStartTime)
+        )
         ax.set_title(
-            'Sliding window length = {} days in increments of {} days'
-            .format(self.window_size/86400, self.window_delta/86400),
+            "Sliding window length = {} days in increments of {} days".format(
+                self.window_size / 86400, self.window_delta / 86400
             )
+        )
         if timestamps:
             axT = ax.twiny()
-            axT.set_xlim(tmids[0]*1e-9, tmids[-1]*1e-9)
-            axT.set_xlabel('Mid-point timestamp [GPS $10^{9}$ s]')
+            axT.set_xlim(tmids[0] * 1e-9, tmids[-1] * 1e-9)
+            axT.set_xlabel("Mid-point timestamp [GPS $10^{9}$ s]")
             ax.set_title(ax.get_title(), y=1.18)
         if savefig:
             plt.tight_layout()
-            plt.savefig(
-                '{}/{}_sliding_window.png'.format(self.outdir, self.label))
+            plt.savefig("{}/{}_sliding_window.png".format(self.outdir, self.label))
         else:
             return ax
 
 
 class EarthTest(GridSearch):
     """ """
-    tex_labels = {'deltaRadius': '$\Delta R$ [m]',
-                  'phaseOffset': 'phase-offset [rad]',
-                  'deltaPspin': '$\Delta P_\mathrm{spin}$ [s]'}
+
+    tex_labels = {
+        "deltaRadius": "$\Delta R$ [m]",
+        "phaseOffset": "phase-offset [rad]",
+        "deltaPspin": "$\Delta P_\mathrm{spin}$ [s]",
+    }
 
     @helper_functions.initializer
-    def __init__(self, label, outdir, sftfilepattern, deltaRadius,
-                 phaseOffset, deltaPspin, F0, F1, F2, Alpha,
-                 Delta, tref=None, minStartTime=None, maxStartTime=None,
-                 BSGL=False, minCoverFreq=None, maxCoverFreq=None,
-                 detectors=None, injectSources=None,
-                 assumeSqrtSX=None):
+    def __init__(
+        self,
+        label,
+        outdir,
+        sftfilepattern,
+        deltaRadius,
+        phaseOffset,
+        deltaPspin,
+        F0,
+        F1,
+        F2,
+        Alpha,
+        Delta,
+        tref=None,
+        minStartTime=None,
+        maxStartTime=None,
+        BSGL=False,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        detectors=None,
+        injectSources=None,
+        assumeSqrtSX=None,
+    ):
         """
         Parameters
         ----------
@@ -979,18 +1306,21 @@ class EarthTest(GridSearch):
         self.duration = maxStartTime - minStartTime
         self.deltaRadius = np.atleast_1d(deltaRadius)
         self.phaseOffset = np.atleast_1d(phaseOffset)
-        self.phaseOffset = self.phaseOffset + 1e-12  # Hack to stop cached data being used
+        self.phaseOffset = (
+            self.phaseOffset + 1e-12
+        )  # Hack to stop cached data being used
         self.deltaPspin = np.atleast_1d(deltaPspin)
         self.set_out_file()
         self.SSBprec = lalpulsar.SSBPREC_RELATIVISTIC
-        self.keys = ['deltaRadius', 'phaseOffset', 'deltaPspin']
+        self.keys = ["deltaRadius", "phaseOffset", "deltaPspin"]
 
         self.prior_widths = [
-            np.max(self.deltaRadius)-np.min(self.deltaRadius),
-            np.max(self.phaseOffset)-np.min(self.phaseOffset),
-            np.max(self.deltaPspin)-np.min(self.deltaPspin)]
+            np.max(self.deltaRadius) - np.min(self.deltaRadius),
+            np.max(self.phaseOffset) - np.min(self.phaseOffset),
+            np.max(self.deltaPspin) - np.min(self.deltaPspin),
+        ]
 
-        if hasattr(self, 'search') is False:
+        if hasattr(self, "search") is False:
             self.inititate_search_object()
 
     def get_input_data_array(self):
@@ -998,19 +1328,27 @@ class EarthTest(GridSearch):
         coord_arrays = [self.deltaRadius, self.phaseOffset, self.deltaPspin]
         input_data = []
         for vals in itertools.product(*coord_arrays):
-                input_data.append(vals)
+            input_data.append(vals)
         self.input_data = np.array(input_data)
         self.coord_arrays = coord_arrays
 
     def run_special(self):
-        vals = [self.minStartTime, self.maxStartTime, self.F0, self.F1,
-                self.F2, self.Alpha, self.Delta]
-        self.special_data = {'zero': [0, 0, 0]}
+        vals = [
+            self.minStartTime,
+            self.maxStartTime,
+            self.F0,
+            self.F1,
+            self.F2,
+            self.Alpha,
+            self.Delta,
+        ]
+        self.special_data = {"zero": [0, 0, 0]}
         for key, (dR, dphi, dP) in self.special_data.items():
-            rescaleRadius = (1 + dR / lal.REARTH_SI)
-            rescalePeriod = (1 + dP / lal.DAYSID_SI)
+            rescaleRadius = 1 + dR / lal.REARTH_SI
+            rescalePeriod = 1 + dP / lal.DAYSID_SI
             lalpulsar.BarycenterModifyEarthRotation(
-                rescaleRadius, dphi, rescalePeriod, self.tref)
+                rescaleRadius, dphi, rescalePeriod, self.tref
+            )
             FS = self.search.get_det_stat(*vals)
             self.special_data[key] = list([dR, dphi, dP]) + [FS]
 
@@ -1023,19 +1361,27 @@ class EarthTest(GridSearch):
             return
 
         data = []
-        vals = [self.minStartTime, self.maxStartTime, self.F0, self.F1,
-                self.F2, self.Alpha, self.Delta]
+        vals = [
+            self.minStartTime,
+            self.maxStartTime,
+            self.F0,
+            self.F1,
+            self.F2,
+            self.Alpha,
+            self.Delta,
+        ]
         for (dR, dphi, dP) in tqdm(self.input_data):
-            rescaleRadius = (1 + dR / lal.REARTH_SI)
-            rescalePeriod = (1 + dP / lal.DAYSID_SI)
+            rescaleRadius = 1 + dR / lal.REARTH_SI
+            rescalePeriod = 1 + dP / lal.DAYSID_SI
             lalpulsar.BarycenterModifyEarthRotation(
-                rescaleRadius, dphi, rescalePeriod, self.tref)
+                rescaleRadius, dphi, rescalePeriod, self.tref
+            )
             FS = self.search.get_det_stat(*vals)
             data.append(list([dR, dphi, dP]) + [FS])
 
         data = np.array(data, dtype=np.float)
-        logging.info('Saving data to {}'.format(self.out_file))
-        np.savetxt(self.out_file, data, delimiter=' ')
+        logging.info("Saving data to {}".format(self.out_file))
+        np.savetxt(self.out_file, data, delimiter=" ")
         self.data = data
 
     def marginalised_bayes_factor(self, prior_widths=None):
@@ -1049,79 +1395,110 @@ class EarthTest(GridSearch):
         for i, x in enumerate(params[::-1]):
             if len(x) > 1:
                 dx = x[1] - x[0]
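+                # marginalise this axis in log space: logsumexp(F) + log(dx) - log(prior width)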
-                F = logsumexp(F, axis=-1)+np.log(dx)-np.log(prior_widths[-1-i])
+                F = logsumexp(F, axis=-1) + np.log(dx) - np.log(prior_widths[-1 - i])
             else:
                 F = np.squeeze(F, axis=-1)
         marginalised_F = np.atleast_1d(F)[0]
-        F_at_zero = self.special_data['zero'][-1]/2.0
+        F_at_zero = self.special_data["zero"][-1] / 2.0
 
         max_idx = np.argmax(self.data[:, -1])
-        max_F = self.data[max_idx, -1]/2.0
+        max_F = self.data[max_idx, -1] / 2.0
         max_F_params = self.data[max_idx, :-1]
-        logging.info('F at zero = {:.1f}, marginalised_F = {:.1f},'
-                     ' max_F = {:.1f} ({})'.format(
-                         F_at_zero, marginalised_F, max_F, max_F_params))
+        logging.info(
+            "F at zero = {:.1f}, marginalised_F = {:.1f},"
+            " max_F = {:.1f} ({})".format(
+                F_at_zero, marginalised_F, max_F, max_F_params
+            )
+        )
         return F_at_zero - marginalised_F, (F_at_zero - max_F) / F_at_zero
 
-    def plot_corner(self, prior_widths=None, fig=None, axes=None,
-                    projection='log_mean'):
+    def plot_corner(
+        self, prior_widths=None, fig=None, axes=None, projection="log_mean"
+    ):
         Bsa, FmaxMismatch = self.marginalised_bayes_factor(prior_widths)
 
         data = self.data[:, -1].reshape(
-            (len(self.deltaRadius), len(self.phaseOffset),
-             len(self.deltaPspin)))
-        xyz = [self.deltaRadius/lal.REARTH_SI, self.phaseOffset/(np.pi),
-               self.deltaPspin/60.]
-        labels = [r'$\frac{\Delta R}{R_\mathrm{Earth}}$',
-                  r'$\frac{\Delta \phi}{\pi}$',
-                  r'$\Delta P_\mathrm{spin}$ [min]',
-                  r'$2\mathcal{F}$']
+            (len(self.deltaRadius), len(self.phaseOffset), len(self.deltaPspin))
+        )
+        xyz = [
+            self.deltaRadius / lal.REARTH_SI,
+            self.phaseOffset / (np.pi),
+            self.deltaPspin / 60.0,
+        ]
+        labels = [
+            r"$\frac{\Delta R}{R_\mathrm{Earth}}$",
+            r"$\frac{\Delta \phi}{\pi}$",
+            r"$\Delta P_\mathrm{spin}$ [min]",
+            r"$2\mathcal{F}$",
+        ]
 
         try:
             from gridcorner import gridcorner
         except ImportError:
             raise ImportError(
                 "Python module 'gridcorner' not found, please install from "
-                "https://gitlab.aei.uni-hannover.de/GregAshton/gridcorner")
+                "https://gitlab.aei.uni-hannover.de/GregAshton/gridcorner"
+            )
 
-        fig, axes = gridcorner(data, xyz, projection=projection, factor=1.6,
-                               labels=labels)
-        axes[-1][-1].axvline((lal.DAYJUL_SI - lal.DAYSID_SI)/60.0, color='C3')
+        fig, axes = gridcorner(
+            data, xyz, projection=projection, factor=1.6, labels=labels
+        )
+        axes[-1][-1].axvline((lal.DAYJUL_SI - lal.DAYSID_SI) / 60.0, color="C3")
         plt.suptitle(
-            'T={:.1f} days, $f$={:.2f} Hz, $\log\mathcal{{B}}_{{S/A}}$={:.1f},'
-            r' $\frac{{\mathcal{{F}}_0-\mathcal{{F}}_\mathrm{{max}}}}'
-            r'{{\mathcal{{F}}_0}}={:.1e}$'
-            .format(self.duration/86400, self.F0, Bsa, FmaxMismatch), y=0.99,
-            size=14)
-        fig.savefig('{}/{}_projection_matrix.png'.format(
-            self.outdir, self.label))
+            "T={:.1f} days, $f$={:.2f} Hz, $\log\mathcal{{B}}_{{S/A}}$={:.1f},"
+            r" $\frac{{\mathcal{{F}}_0-\mathcal{{F}}_\mathrm{{max}}}}"
+            r"{{\mathcal{{F}}_0}}={:.1e}$".format(
+                self.duration / 86400, self.F0, Bsa, FmaxMismatch
+            ),
+            y=0.99,
+            size=14,
+        )
+        fig.savefig("{}/{}_projection_matrix.png".format(self.outdir, self.label))
 
     def plot(self, key, prior_widths=None):
         Bsa, FmaxMismatch = self.marginalised_bayes_factor(prior_widths)
 
-        rescales_defaults = {'deltaRadius': 1/lal.REARTH_SI,
-                             'phaseOffset': 1/np.pi,
-                             'deltaPspin': 1}
-        labels = {'deltaRadius': r'$\frac{\Delta R}{R_\mathrm{Earth}}$',
-                  'phaseOffset': r'$\frac{\Delta \phi}{\pi}$',
-                  'deltaPspin': r'$\Delta P_\mathrm{spin}$ [s]'
-                  }
-
-        fig, ax = self.plot_1D(key, xrescale=rescales_defaults[key],
-                               xlabel=labels[key], savefig=False)
+        rescales_defaults = {
+            "deltaRadius": 1 / lal.REARTH_SI,
+            "phaseOffset": 1 / np.pi,
+            "deltaPspin": 1,
+        }
+        labels = {
+            "deltaRadius": r"$\frac{\Delta R}{R_\mathrm{Earth}}$",
+            "phaseOffset": r"$\frac{\Delta \phi}{\pi}$",
+            "deltaPspin": r"$\Delta P_\mathrm{spin}$ [s]",
+        }
+
+        fig, ax = self.plot_1D(
+            key, xrescale=rescales_defaults[key], xlabel=labels[key], savefig=False
+        )
         ax.set_title(
-            'T={} days, $f$={} Hz, $\log\mathcal{{B}}_{{S/A}}$={:.1f}'
-            .format(self.duration/86400, self.F0, Bsa))
+            "T={} days, $f$={} Hz, $\log\mathcal{{B}}_{{S/A}}$={:.1f}".format(
+                self.duration / 86400, self.F0, Bsa
+            )
+        )
         fig.tight_layout()
-        fig.savefig('{}/{}_1D.png'.format(self.outdir, self.label))
+        fig.savefig("{}/{}_1D.png".format(self.outdir, self.label))
 
 
 class DMoff_NO_SPIN(GridSearch):
     """ DMoff test using SSBPREC_NO_SPIN """
+
     @helper_functions.initializer
-    def __init__(self, par, label, outdir, sftfilepattern, minStartTime=None,
-                 maxStartTime=None, minCoverFreq=None, maxCoverFreq=None,
-                 detectors=None, injectSources=None, assumeSqrtSX=None):
+    def __init__(
+        self,
+        par,
+        label,
+        outdir,
+        sftfilepattern,
+        minStartTime=None,
+        maxStartTime=None,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        detectors=None,
+        injectSources=None,
+        assumeSqrtSX=None,
+    ):
         """
         Parameters
         ----------
@@ -1147,22 +1524,21 @@ class DMoff_NO_SPIN(GridSearch):
         elif type(par) == str and os.path.isfile(par):
             self.par = read_par(filename=par)
         else:
-            raise ValueError('The .par file does not exist')
+            raise ValueError("The .par file does not exist")
 
         self.nsegs = 1
         self.BSGL = False
 
-        self.tref = self.par['tref']
-        self.F1s = [self.par.get('F1', 0)]
-        self.F2s = [self.par.get('F2', 0)]
-        self.Alphas = [self.par['Alpha']]
-        self.Deltas = [self.par['Delta']]
+        self.tref = self.par["tref"]
+        self.F1s = [self.par.get("F1", 0)]
+        self.F2s = [self.par.get("F2", 0)]
+        self.Alphas = [self.par["Alpha"]]
+        self.Deltas = [self.par["Delta"]]
         self.Re = 6.371e6
         self.c = 2.998e8
-        a0 = self.Re/self.c  # *np.cos(self.par['Delta'])
-        self.m0 = np.max([4, int(np.ceil(2*np.pi*self.par['F0']*a0))])
-        logging.info(
-            'Setting up DMoff_NO_SPIN search with m0 = {}'.format(self.m0))
+        a0 = self.Re / self.c  # *np.cos(self.par['Delta'])
+        self.m0 = np.max([4, int(np.ceil(2 * np.pi * self.par["F0"] * a0))])
+        logging.info("Setting up DMoff_NO_SPIN search with m0 = {}".format(self.m0))
 
     def get_results(self):
         """ Compute the three summed detection statistics
@@ -1173,21 +1549,23 @@ class DMoff_NO_SPIN(GridSearch):
 
         """
         self.SSBprec = lalpulsar.SSBPREC_RELATIVISTIC
-        self.set_out_file('SSBPREC_RELATIVISTIC')
-        self.F0s = [self.par['F0']+j/lal.DAYSID_SI for j in range(-4, 5)]
+        self.set_out_file("SSBPREC_RELATIVISTIC")
+        self.F0s = [self.par["F0"] + j / lal.DAYSID_SI for j in range(-4, 5)]
         self.run()
         twoF_SUM = np.sum(self.data[:, -1])
 
         self.SSBprec = lalpulsar.SSBPREC_NO_SPIN
-        self.set_out_file('SSBPREC_NO_SPIN')
-        self.F0s = [self.par['F0']+j/lal.DAYSID_SI
-                    for j in range(-self.m0, self.m0+1)]
+        self.set_out_file("SSBPREC_NO_SPIN")
+        self.F0s = [
+            self.par["F0"] + j / lal.DAYSID_SI for j in range(-self.m0, self.m0 + 1)
+        ]
         self.run()
         twoFstar_SUM = np.sum(self.data[:, -1])
 
-        self.set_out_file('SSBPREC_NO_SPIN_TERRESTRIAL')
-        self.F0s = [self.par['F0']+j/lal.DAYJUL_SI
-                    for j in range(-self.m0, self.m0+1)]
+        self.set_out_file("SSBPREC_NO_SPIN_TERRESTRIAL")
+        self.F0s = [
+            self.par["F0"] + j / lal.DAYJUL_SI for j in range(-self.m0, self.m0 + 1)
+        ]
         self.run()
         twoFstar_SUM_terrestrial = np.sum(self.data[:, -1])
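
get_results therefore evaluates three frequency combs around the target frequency: the +/-4 sidereal-day harmonics with the relativistic SSB model, then +/-m0 harmonics at both sidereal (1/DAYSID_SI) and terrestrial (1/DAYJUL_SI) day spacings with the spin correction disabled. A sketch of how the comb offsets differ, assuming lal is importable and using illustrative F0 and m0:

    import lal

    F0, m0 = 30.0, 4
    sidereal = [F0 + j / lal.DAYSID_SI for j in range(-m0, m0 + 1)]
    terrestrial = [F0 + j / lal.DAYJUL_SI for j in range(-m0, m0 + 1)]
    # adjacent teeth are ~1.16e-5 Hz apart; the two combs drift apart by
    # ~3.2e-8 Hz per harmonic
    print(sidereal[0] - terrestrial[0])
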
 
diff --git a/pyfstat/helper_functions.py b/pyfstat/helper_functions.py
index 9bbc2bfb58f20727d2b2895fc6baf03639168bf5..8943b3d97db0068fee7f6d81fc520ccfdf4e3c91 100644
--- a/pyfstat/helper_functions.py
+++ b/pyfstat/helper_functions.py
@@ -16,64 +16,97 @@ import lal
 import lalpulsar
 
 # workaround for matplotlib on X-less remote logins
-if 'DISPLAY' in os.environ:
+if "DISPLAY" in os.environ:
     import matplotlib.pyplot as plt
 else:
-    logging.info('No $DISPLAY environment variable found, so importing \
-                  matplotlib.pyplot with non-interactive "Agg" backend.')
+    logging.info(
+        "No $DISPLAY environment variable found, so importing "
+        'matplotlib.pyplot with non-interactive "Agg" backend.'
+    )
     import matplotlib
-    matplotlib.use('Agg')
+
+    matplotlib.use("Agg")
     import matplotlib.pyplot as plt
 
+
 def set_up_optional_tqdm():
     try:
         from tqdm import tqdm
     except ImportError:
+
         def tqdm(x, *args, **kwargs):
             return x
+
     return tqdm
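
This is a graceful-degradation pattern: if tqdm is missing, the fallback is an identity function with a compatible call signature, so call sites never have to check for the dependency. Usage is identical either way:

    tqdm = set_up_optional_tqdm()
    for i in tqdm(range(10), desc="demo"):  # progress bar if tqdm exists, plain loop if not
        pass
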
 
 
 def set_up_matplotlib_defaults():
-    plt.switch_backend('Agg')
-    plt.rcParams['text.usetex'] = True
-    plt.rcParams['axes.formatter.useoffset'] = False
+    plt.switch_backend("Agg")
+    plt.rcParams["text.usetex"] = True
+    plt.rcParams["axes.formatter.useoffset"] = False
 
 
 def set_up_command_line_arguments():
     parser = argparse.ArgumentParser()
-    parser.add_argument("-v", "--verbose", action="store_true",
-                        help="Increase output verbosity [logging.DEBUG]")
-    parser.add_argument("-q", "--quite", action="store_true",
-                        help="Decrease output verbosity [logging.WARNING]")
-    parser.add_argument("--no-interactive", help="Don't use interactive",
-                        action="store_true")
-    parser.add_argument("-c", "--clean", action="store_true",
-                        help="Force clean data, never use cached data")
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="store_true",
+        help="Increase output verbosity [logging.DEBUG]",
+    )
+    parser.add_argument(
+        "-q",
+        "--quite",
+        action="store_true",
+        help="Decrease output verbosity [logging.WARNING]",
+    )
+    parser.add_argument(
+        "--no-interactive",
+        action="store_true",
+        help="Don't use interactive output such as tqdm progress bars",
+    )
+    parser.add_argument(
+        "-c",
+        "--clean",
+        action="store_true",
+        help="Force clean data, never use cached data",
+    )
     fu_parser = parser.add_argument_group(
-        'follow-up options', 'Options related to MCMCFollowUpSearch')
-    fu_parser.add_argument('-s', "--setup-only", action="store_true",
-                           help="Only generate the setup file, don't run")
+        "follow-up options", "Options related to MCMCFollowUpSearch"
+    )
+    fu_parser.add_argument(
+        "-s",
+        "--setup-only",
+        action="store_true",
+        help="Only generate the setup file, don't run",
+    )
     fu_parser.add_argument(
-        "--no-template-counting", action="store_true",
-        help="No counting of templates, useful if the setup is predefined")
+        "--no-template-counting",
+        action="store_true",
+        help="No counting of templates, useful if the setup is predefined",
+    )
     parser.add_argument(
-        '-N', type=int, default=3, metavar='N',
-        help="Number of threads to use when running in parallel")
-    parser.add_argument('unittest_args', nargs='*')
+        "-N",
+        type=int,
+        default=3,
+        metavar="N",
+        help="Number of threads to use when running in parallel",
+    )
+    parser.add_argument("unittest_args", nargs="*")
     args, unknown = parser.parse_known_args()
     sys.argv[1:] = args.unittest_args
 
     if args.quiet or args.no_interactive:
+
         def tqdm(x, *args, **kwargs):
             return x
+
     else:
         tqdm = set_up_optional_tqdm()
 
     logger = logging.getLogger()
     stream_handler = logging.StreamHandler()
-    stream_handler.setFormatter(logging.Formatter(
-        '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M'))
+    stream_handler.setFormatter(
+        logging.Formatter("%(asctime)s %(levelname)-8s: %(message)s", datefmt="%H:%M")
+    )
 
     if args.quiet:
         logger.setLevel(logging.WARNING)
@@ -91,39 +124,45 @@ def set_up_command_line_arguments():
 
 def get_ephemeris_files():
     """ Returns the earth_ephem and sun_ephem """
-    config_file = os.path.expanduser('~')+'/.pyfstat.conf'
-    env_var = 'LALPULSAR_DATADIR'
-    please = 'Please provide the ephemerides paths when initialising searches.'
+    config_file = os.path.expanduser("~") + "/.pyfstat.conf"
+    env_var = "LALPULSAR_DATADIR"
+    please = "Please provide the ephemerides paths when initialising searches."
     if os.path.isfile(config_file):
         d = {}
-        with open(config_file, 'r') as f:
+        with open(config_file, "r") as f:
             for line in f:
-                k, v = line.split('=')
-                k = k.replace(' ', '')
-                for item in [' ', "'", '"', '\n']:
-                    v = v.replace(item, '')
+                k, v = line.split("=")
+                k = k.replace(" ", "")
+                for item in [" ", "'", '"', "\n"]:
+                    v = v.replace(item, "")
                 d[k] = v
         try:
-            earth_ephem = d['earth_ephem']
-            sun_ephem = d['sun_ephem']
+            earth_ephem = d["earth_ephem"]
+            sun_ephem = d["sun_ephem"]
         except KeyError:
-            logging.warning('No [earth/sun]_ephem found in '+config_file+'. '+please)
+            logging.warning(
+                "No [earth/sun]_ephem found in " + config_file + ". " + please
+            )
             earth_ephem = None
             sun_ephem = None
     elif env_var in list(os.environ.keys()):
-        earth_ephem = os.path.join(os.environ[env_var],'earth00-40-DE421.dat.gz')
-        sun_ephem = os.path.join(os.environ[env_var],'sun00-40-DE421.dat.gz')
-        if not ( os.path.isfile(earth_ephem) and os.path.isfile(sun_ephem) ):
-            earth_ephem = os.path.join(os.environ[env_var],'earth00-19-DE421.dat.gz')
-            sun_ephem = os.path.join(os.environ[env_var],'sun00-19-DE421.dat.gz')
-            if not ( os.path.isfile(earth_ephem) and os.path.isfile(sun_ephem) ):
-                logging.warning('No [earth/sun]00-[19/40]-DE421 ephemerides '
-                                'found in the '+os.environ[env_var]+' directory. '+please)
+        earth_ephem = os.path.join(os.environ[env_var], "earth00-40-DE421.dat.gz")
+        sun_ephem = os.path.join(os.environ[env_var], "sun00-40-DE421.dat.gz")
+        if not (os.path.isfile(earth_ephem) and os.path.isfile(sun_ephem)):
+            earth_ephem = os.path.join(os.environ[env_var], "earth00-19-DE421.dat.gz")
+            sun_ephem = os.path.join(os.environ[env_var], "sun00-19-DE421.dat.gz")
+            if not (os.path.isfile(earth_ephem) and os.path.isfile(sun_ephem)):
+                logging.warning(
+                    "No [earth/sun]00-[19/40]-DE421 ephemerides "
+                    "found in the " + os.environ[env_var] + " directory. " + please
+                )
                 earth_ephem = None
                 sun_ephem = None
     else:
-        logging.warning('No '+config_file+' file or $'+env_var+' environment '
-                        'variable found. '+please)
+        logging.warning(
+            "No " + config_file + " file or $" + env_var + " environment "
+            "variable found. " + please
+        )
         earth_ephem = None
         sun_ephem = None
     return earth_ephem, sun_ephem
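
The config parser above accepts simple key = value lines, stripping spaces and quotes from the values. A hypothetical ~/.pyfstat.conf (paths are illustrative) would look like:

    earth_ephem = '/home/user/ephemerides/earth00-40-DE421.dat.gz'
    sun_ephem = '/home/user/ephemerides/sun00-40-DE421.dat.gz'
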
@@ -133,7 +172,7 @@ def round_to_n(x, n):
     if not x:
         return 0
     power = -int(np.floor(np.log10(abs(x)))) + (n - 1)
-    factor = (10 ** power)
+    factor = 10 ** power
     return round(x * factor) / factor
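
round_to_n keeps n significant figures rather than n decimal places: the power shift moves the n-th significant digit next to the units position before rounding. Two quick checks of the expression above:

    print(round_to_n(0.012345, 2))  # power = 3, factor = 1000 -> 0.012
    print(round_to_n(987.6, 2))  # power = -1, factor = 0.1 -> 990.0
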
 
 
@@ -147,10 +186,10 @@ def texify_float(x, d=2):
         return str(x)
     else:
         power = int(np.floor(np.log10(abs(x))))
-        stem = np.round(x / 10**power, d)
+        stem = np.round(x / 10 ** power, d)
         if d == 1:
             stem = int(stem)
-        return r'${}{{\times}}10^{{{}}}$'.format(stem, power)
+        return r"${}{{\times}}10^{{{}}}$".format(stem, power)
 
 
 def initializer(func):
@@ -176,54 +215,58 @@ def get_peak_values(frequencies, twoF, threshold_2F, F0=None, F0range=None):
         cut_idxs = np.abs(frequencies - F0) < F0range
         frequencies = frequencies[cut_idxs]
         twoF = twoF[cut_idxs]
-    idxs = peakutils.indexes(twoF, thres=1.*threshold_2F/np.max(twoF))
+    idxs = peakutils.indexes(twoF, thres=1.0 * threshold_2F / np.max(twoF))
     F0maxs = frequencies[idxs]
     twoFmaxs = twoF[idxs]
     freq_err = frequencies[1] - frequencies[0]
-    return F0maxs, twoFmaxs, freq_err*np.ones(len(idxs))
+    return F0maxs, twoFmaxs, freq_err * np.ones(len(idxs))
 
 
 def get_comb_values(F0, frequencies, twoF, period, N=4):
-    if period == 'sidereal':
-        period = 23*60*60 + 56*60 + 4.0616
-    elif period == 'terrestrial':
+    if period == "sidereal":
+        period = 23 * 60 * 60 + 56 * 60 + 4.0616
+    elif period == "terrestrial":
         period = 86400
     freq_err = frequencies[1] - frequencies[0]
-    comb_frequencies = [n*1/period for n in range(-N, N+1)]
-    comb_idxs = [np.argmin(np.abs(frequencies-F0-F)) for F in comb_frequencies]
-    return comb_frequencies, twoF[comb_idxs], freq_err*np.ones(len(comb_idxs))
+    comb_frequencies = [n * 1 / period for n in range(-N, N + 1)]
+    comb_idxs = [np.argmin(np.abs(frequencies - F0 - F)) for F in comb_frequencies]
+    return comb_frequencies, twoF[comb_idxs], freq_err * np.ones(len(comb_idxs))
 
 
 def compute_P_twoFstarcheck(twoFstarcheck, twoFcheck, M0, plot=False):
     """ Returns the unnormalised pdf of twoFstarcheck given twoFcheck """
-    upper = 4+twoFstarcheck + 0.5*(2*(4*M0+2*twoFcheck))
+    upper = 4 + twoFstarcheck + 0.5 * (2 * (4 * M0 + 2 * twoFcheck))
     rho2starcheck = np.linspace(1e-1, upper, 500)
-    integrand = (ncx2.pdf(twoFstarcheck, 4*M0, rho2starcheck)
-                 * ncx2.pdf(twoFcheck, 4, rho2starcheck))
+    integrand = ncx2.pdf(twoFstarcheck, 4 * M0, rho2starcheck) * ncx2.pdf(
+        twoFcheck, 4, rho2starcheck
+    )
     if plot:
         fig, ax = plt.subplots()
         ax.plot(rho2starcheck, integrand)
-        fig.savefig('test')
+        fig.savefig("test")
     return np.trapz(integrand, rho2starcheck)
 
 
 def compute_pstar(twoFcheck_obs, twoFstarcheck_obs, m0, plot=False):
-    M0 = 2*m0 + 1
-    upper = 4+twoFcheck_obs + (2*(4*M0+2*twoFcheck_obs))
+    M0 = 2 * m0 + 1
+    upper = 4 + twoFcheck_obs + (2 * (4 * M0 + 2 * twoFcheck_obs))
     twoFstarcheck_vals = np.linspace(1e-1, upper, 500)
     P_twoFstarcheck = np.array(
-        [compute_P_twoFstarcheck(twoFstarcheck, twoFcheck_obs, M0)
-         for twoFstarcheck in twoFstarcheck_vals])
+        [
+            compute_P_twoFstarcheck(twoFstarcheck, twoFcheck_obs, M0)
+            for twoFstarcheck in twoFstarcheck_vals
+        ]
+    )
     C = np.trapz(P_twoFstarcheck, twoFstarcheck_vals)
     idx = np.argmin(np.abs(twoFstarcheck_vals - twoFstarcheck_obs))
     if plot:
         fig, ax = plt.subplots()
         ax.plot(twoFstarcheck_vals, P_twoFstarcheck)
-        ax.fill_between(twoFstarcheck_vals[:idx+1], 0, P_twoFstarcheck[:idx+1])
+        ax.fill_between(twoFstarcheck_vals[: idx + 1], 0, P_twoFstarcheck[: idx + 1])
         ax.axvline(twoFstarcheck_vals[idx])
-        fig.savefig('test')
-    pstar_l = np.trapz(P_twoFstarcheck[:idx+1]/C, twoFstarcheck_vals[:idx+1])
-    return 2*np.min([pstar_l, 1-pstar_l])
+        fig.savefig("test")
+    pstar_l = np.trapz(P_twoFstarcheck[: idx + 1] / C, twoFstarcheck_vals[: idx + 1])
+    return 2 * np.min([pstar_l, 1 - pstar_l])
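
compute_pstar converts the unnormalised pdf into a two-sided p-value: normalise by the full trapezoidal integral C, integrate up to the observed value, then fold via 2*min(p, 1-p). The same recipe on a toy pdf (a self-contained sketch, not the ncx2 mixture used above):

    import numpy as np

    x = np.linspace(0, 10, 500)
    pdf = np.exp(-0.5 * (x - 4) ** 2)  # toy unnormalised pdf
    C = np.trapz(pdf, x)  # normalisation constant
    idx = np.argmin(np.abs(x - 6.0))  # observed value: 6.0
    p_l = np.trapz(pdf[: idx + 1] / C, x[: idx + 1])
    pstar = 2 * np.min([p_l, 1 - p_l])  # two-sided p-value
    print(pstar)
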
 
 
 def run_commandline(cl, log_level=20, raise_error=True, return_output=True):
@@ -239,28 +282,28 @@ def run_commandline(cl, log_level=20, raise_error=True, return_output=True):
 
     """
 
-    logging.log(log_level, 'Now executing: ' + cl)
+    logging.log(log_level, "Now executing: " + cl)
     if return_output:
         try:
-            out = subprocess.check_output(cl,                       # what to run
-                                          stderr=subprocess.STDOUT, # catch errors
-                                          shell=True,               # proper environment etc
-                                          universal_newlines=True,  # properly display linebreaks in error/output printing
-                                         )
+            out = subprocess.check_output(
+                cl,  # what to run
+                stderr=subprocess.STDOUT,  # catch errors
+                shell=True,  # proper environment etc
+                universal_newlines=True,  # properly display linebreaks in error/output printing
+            )
         except subprocess.CalledProcessError as e:
-            logging.log(log_level, 'Execution failed: {}'.format(e.output))
+            logging.log(log_level, "Execution failed: {}".format(e.output))
             if raise_error:
                 raise
             else:
                 out = 0
-        os.system('\n')
-        return(out)
+        os.system("\n")
+        return out
     else:
         process = subprocess.Popen(cl, shell=True)
         process.communicate()
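
With return_output=True (the default) the helper returns the captured stdout/stderr of the shell command, or 0 when the command fails and raise_error=False; with return_output=False it simply blocks until the process exits. For example:

    out = run_commandline("echo hello", log_level=10)
    print(out)  # -> "hello\n"
    run_commandline("sleep 1", return_output=False)  # wait for exit, no capture
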
 
 
-
 def convert_array_to_gsl_matrix(array):
     gsl_matrix = lal.gsl_matrix(*array.shape)
     gsl_matrix.data = array
@@ -270,16 +313,15 @@ def convert_array_to_gsl_matrix(array):
 def get_sft_array(sftfilepattern, data_duration, F0, dF0):
     """ Return the raw data from a set of sfts """
 
-    SFTCatalog = lalpulsar.SFTdataFind(
-        sftfilepattern, lalpulsar.SFTConstraints())
-    MultiSFTs = lalpulsar.LoadMultiSFTs(SFTCatalog, F0-dF0, F0+dF0)
+    SFTCatalog = lalpulsar.SFTdataFind(sftfilepattern, lalpulsar.SFTConstraints())
+    MultiSFTs = lalpulsar.LoadMultiSFTs(SFTCatalog, F0 - dF0, F0 + dF0)
     SFTs = MultiSFTs.data[0]
     data = []
     for sft in SFTs.data:
         data.append(np.abs(sft.data.data))
     data = np.array(data).T
     n, nsfts = data.shape
-    freqs = np.linspace(sft.f0, sft.f0+n*sft.deltaF, n)
+    freqs = np.linspace(sft.f0, sft.f0 + n * sft.deltaF, n)
     times = np.linspace(0, data_duration, nsfts)
 
     return times, freqs, data
@@ -318,10 +360,11 @@ def get_covering_band(tref, tstart, tend, F0, F1, F2):
     return lalpulsar.CWSignalCoveringBand(tstart, tend, psr, 0, 0, 0)
 
 
-def twoFDMoffThreshold(twoFon, knee=400, twoFDMoffthreshold_below_threshold=62,
-                       prefactor=0.9, offset=0.5):
+def twoFDMoffThreshold(
+    twoFon, knee=400, twoFDMoffthreshold_below_threshold=62, prefactor=0.9, offset=0.5
+):
     """ Calculation of the 2F_DMoff threshold, see Eq 2 of arXiv:1707.5286 """
     if twoFon <= knee:
         return twoFDMoffthreshold_below_threshold
     else:
-        return 10**(prefactor*np.log10(twoFon-offset))
+        return 10 ** (prefactor * np.log10(twoFon - offset))
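
Numerically, the threshold is flat at 62 up to the knee and then follows the power law 10**(0.9 * log10(twoFon - 0.5)), i.e. roughly twoFon**0.9. Two spot checks of the expression above:

    print(twoFDMoffThreshold(300))  # below the knee -> 62
    print(twoFDMoffThreshold(1000))  # 10 ** (0.9 * log10(999.5)) -> ~501
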
diff --git a/pyfstat/injection_helper_functions.py b/pyfstat/injection_helper_functions.py
index ee038595e1c458232a7aa0a4fac27be2486c6de0..cba3c253b39cc7d75f7bae39fc2b189b4ee1c983 100644
--- a/pyfstat/injection_helper_functions.py
+++ b/pyfstat/injection_helper_functions.py
@@ -5,12 +5,13 @@
 
 import numpy as np
 import logging
+
 try:
     from astropy import units as u
     from astropy.coordinates import SkyCoord
     from astropy.time import Time
 except ImportError:
-    logging.warning('Python module astropy not installed')
+    logging.warning("Python module astropy not installed")
 import lal
 
 # Assume Earth goes around Sun in a non-wobbling circle at constant speed;
@@ -20,22 +21,20 @@ import lal
 
 
 def _eqToEcl(alpha, delta):
-    source = SkyCoord(alpha*u.radian, delta*u.radian, frame='gcrs')
-    out = source.transform_to('geocentrictrueecliptic')
+    source = SkyCoord(alpha * u.radian, delta * u.radian, frame="gcrs")
+    out = source.transform_to("geocentrictrueecliptic")
     return np.array([out.lon.radian, out.lat.radian])
 
 
 def _eclToEq(lon, lat):
-    source = SkyCoord(lon*u.radian, lat*u.radian,
-                      frame='geocentrictrueecliptic')
-    out = source.transform_to('gcrs')
+    source = SkyCoord(lon * u.radian, lat * u.radian, frame="geocentrictrueecliptic")
+    out = source.transform_to("gcrs")
     return np.array([out.ra.radian, out.dec.radian])
 
 
-def _calcDopplerWings(
-        s_freq, s_alpha, s_delta, lonStart, lonStop, numTimes=100):
+def _calcDopplerWings(s_freq, s_alpha, s_delta, lonStart, lonStop, numTimes=100):
     e_longitudes = np.linspace(lonStart, lonStop, numTimes)
-    v_over_c = 2*np.pi*lal.AU_SI/lal.YRSID_SI/lal.C_SI
+    v_over_c = 2 * np.pi * lal.AU_SI / lal.YRSID_SI / lal.C_SI
     s_lon, s_lat = _eqToEcl(s_alpha, s_delta)
 
     vertical = s_lat
@@ -45,8 +44,8 @@ def _calcDopplerWings(
     F0min, F0max = np.amin(dopplerShifts), np.amax(dopplerShifts)
 
     # Add twice the spin-modulation
-    SpinModulationMax = 2*np.pi*lal.REARTH_SI/lal.DAYSID_SI/lal.C_SI * s_freq
-    return F0min - 2*SpinModulationMax, F0max + 2*SpinModulationMax
+    SpinModulationMax = 2 * np.pi * lal.REARTH_SI / lal.DAYSID_SI / lal.C_SI * s_freq
+    return F0min - 2 * SpinModulationMax, F0max + 2 * SpinModulationMax
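
The magnitudes involved: the orbital Doppler term uses v/c = 2*pi*AU/(1 yr * c) ~ 1e-4, while the diurnal spin-modulation term 2*pi*Re/(1 sidereal day * c) ~ 1.55e-6 is doubled as a safety margin. A numeric sketch, assuming lal is importable and an illustrative 100 Hz signal:

    import numpy as np
    import lal

    v_over_c = 2 * np.pi * lal.AU_SI / lal.YRSID_SI / lal.C_SI  # ~9.9e-5
    spin_mod = 2 * np.pi * lal.REARTH_SI / lal.DAYSID_SI / lal.C_SI  # ~1.55e-6
    F0 = 100.0
    print(v_over_c * F0, 2 * spin_mod * F0)  # worst-case wing widths in Hz
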
 
 
 def _calcSpindownWings(freq, fdot, minStartTime, maxStartTime):
@@ -54,8 +53,7 @@ def _calcSpindownWings(freq, fdot, minStartTime, maxStartTime):
     return 0.5 * timespan * np.abs(fdot) * np.array([-1, 1])
 
 
-def get_frequency_range_of_signal(F0, F1, Alpha, Delta, minStartTime,
-                                  maxStartTime):
+def get_frequency_range_of_signal(F0, F1, Alpha, Delta, minStartTime, maxStartTime):
     """ Calculate the frequency range that a signal will occupy
 
     Parameters
@@ -78,14 +76,14 @@ def get_frequency_range_of_signal(F0, F1, Alpha, Delta, minStartTime,
     YEAR_IN_DAYS = lal.YRSID_SI / lal.DAYSID_SI
     tEquinox = 79
 
-    minStartTime_t = Time(minStartTime, format='gps').to_datetime().timetuple()
-    maxStartTime_t = Time(maxStartTime, format='gps').to_datetime().timetuple()
+    minStartTime_t = Time(minStartTime, format="gps").to_datetime().timetuple()
+    maxStartTime_t = Time(maxStartTime, format="gps").to_datetime().timetuple()
     tStart_days = minStartTime_t.tm_yday - tEquinox
     tStop_days = maxStartTime_t.tm_yday - tEquinox
-    tStop_days += (maxStartTime_t.tm_year-minStartTime_t.tm_year)*YEAR_IN_DAYS
+    tStop_days += (maxStartTime_t.tm_year - minStartTime_t.tm_year) * YEAR_IN_DAYS
 
-    lonStart = 2*np.pi*tStart_days/YEAR_IN_DAYS - np.pi
-    lonStop = 2*np.pi*tStop_days/YEAR_IN_DAYS - np.pi
+    lonStart = 2 * np.pi * tStart_days / YEAR_IN_DAYS - np.pi
+    lonStop = 2 * np.pi * tStop_days / YEAR_IN_DAYS - np.pi
 
     dopplerWings = _calcDopplerWings(F0, Alpha, Delta, lonStart, lonStop)
     spindownWings = _calcSpindownWings(F0, F1, minStartTime, maxStartTime)
diff --git a/pyfstat/make_sfts.py b/pyfstat/make_sfts.py
index 2a65fa9069cc55903036d90ce82427e324cc58ca..189ab7b24ef3ef614aff2a8ef6595d80a4975c6a 100644
--- a/pyfstat/make_sfts.py
+++ b/pyfstat/make_sfts.py
@@ -20,13 +20,33 @@ class KeyboardInterruptError(Exception):
 
 class Writer(BaseSearchClass):
     """ Instance object for generating SFTs """
+
     @helper_functions.initializer
-    def __init__(self, label='Test', tstart=700000000, duration=100*86400,
-                 tref=None, F0=30, F1=1e-10, F2=0, Alpha=5e-3,
-                 Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, phi=0, Tsft=1800,
-                 outdir=".", sqrtSX=1, Band=4, detectors='H1',
-                 minStartTime=None, maxStartTime=None, add_noise=True,
-                 transientWindowType='none'):
+    def __init__(
+        self,
+        label="Test",
+        tstart=700000000,
+        duration=100 * 86400,
+        tref=None,
+        F0=30,
+        F1=1e-10,
+        F2=0,
+        Alpha=5e-3,
+        Delta=6e-2,
+        h0=0.1,
+        cosi=0.0,
+        psi=0.0,
+        phi=0,
+        Tsft=1800,
+        outdir=".",
+        sqrtSX=1,
+        Band=4,
+        detectors="H1",
+        minStartTime=None,
+        maxStartTime=None,
+        add_noise=True,
+        transientWindowType="none",
+    ):
         """
         Parameters
         ----------
@@ -53,7 +73,7 @@ class Writer(BaseSearchClass):
         self.calculate_fmin_Band()
 
         self.tbounds = [self.tstart, self.tend]
-        logging.info('Using segment boundaries {}'.format(self.tbounds))
+        logging.info("Using segment boundaries {}".format(self.tbounds))
 
     def basic_setup(self):
         self.tstart = int(self.tstart)
@@ -65,7 +85,7 @@ class Writer(BaseSearchClass):
             self.maxStartTime = self.tend
         self.minStartTime = int(self.minStartTime)
         self.maxStartTime = int(self.maxStartTime)
-        self.duration_days = (self.tend-self.tstart) / 86400
+        self.duration_days = (self.tend - self.tstart) / 86400
 
         self.data_duration = self.maxStartTime - self.minStartTime
         numSFTs = int(float(self.data_duration) / self.Tsft)
@@ -79,23 +99,28 @@ class Writer(BaseSearchClass):
         self.config_file_name = "{}/{}.cff".format(self.outdir, self.label)
         self.sftfilenames = [
             lalpulsar.OfficialSFTFilename(
-                dets[0], dets[1], numSFTs, self.Tsft, self.minStartTime,
-                self.data_duration, self.label)
-            for dets in self.detectors.split(',')]
-        self.sftfilepath = ';'.join([
-            '{}/{}'.format(self.outdir, fn) for fn in self.sftfilenames])
-        self.IFOs = (
-            ",".join(['"{}"'.format(d) for d in self.detectors.split(",")]))
+                dets[0],
+                dets[1],
+                numSFTs,
+                self.Tsft,
+                self.minStartTime,
+                self.data_duration,
+                self.label,
+            )
+            for dets in self.detectors.split(",")
+        ]
+        self.sftfilepath = ";".join(
+            ["{}/{}".format(self.outdir, fn) for fn in self.sftfilenames]
+        )
+        self.IFOs = ",".join(['"{}"'.format(d) for d in self.detectors.split(",")])
 
     def make_data(self):
-        ''' A convienience wrapper to generate a cff file then sfts '''
+        """ A convienience wrapper to generate a cff file then sfts """
         self.make_cff()
         self.run_makefakedata()
 
-    def get_base_template(self, i, Alpha, Delta, h0, cosi, psi, phi, F0,
-                          F1, F2, tref):
-        return (
-"""[TS{}]
+    def get_base_template(self, i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref):
+        return """[TS{}]
 Alpha = {:1.18e}
 Delta = {:1.18e}
 h0 = {:1.18e}
@@ -105,35 +130,100 @@ phi0 = {:1.18e}
 Freq = {:1.18e}
 f1dot = {:1.18e}
 f2dot = {:1.18e}
-refTime = {:10.6f}""")
+refTime = {:10.6f}"""
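
Rendered through get_single_config_line_cw, the template above yields a lalapps_Makefakedata_v5 injection section of the following form (all values hypothetical, chosen to be exactly representable):

    [TS0]
    Alpha = 5.000000000000000000e-01
    Delta = 2.500000000000000000e-01
    h0 = 1.000000000000000000e+00
    cosi = 0.000000000000000000e+00
    psi = 0.000000000000000000e+00
    phi0 = 0.000000000000000000e+00
    Freq = 3.000000000000000000e+01
    f1dot = 0.000000000000000000e+00
    f2dot = 0.000000000000000000e+00
    refTime = 1000000000.000000
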
 
     def get_single_config_line_cw(
-            self, i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref):
-        template = (self.get_base_template(
-            i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref) + """\n""")
-        return template.format(
-            i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref)
+        self, i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref
+    ):
+        template = (
+            self.get_base_template(
+                i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref
+            )
+            + """\n"""
+        )
+        return template.format(i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref)
 
     def get_single_config_line_tcw(
-            self, i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref,
-            window, tstart, duration_days):
-        template = (self.get_base_template(
-            i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref) + """
+        self,
+        i,
+        Alpha,
+        Delta,
+        h0,
+        cosi,
+        psi,
+        phi,
+        F0,
+        F1,
+        F2,
+        tref,
+        window,
+        tstart,
+        duration_days,
+    ):
+        template = (
+            self.get_base_template(
+                i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref
+            )
+            + """
 transientWindowType = {:s}
 transientStartTime = {:10.0f}
-transientTau = {:10.0f}\n""")
-        return template.format(i, Alpha, Delta, h0, cosi, psi, phi, F0, F1,
-                               F2, tref, window, tstart, duration_days*86400)
-
-    def get_single_config_line(self, i, Alpha, Delta, h0, cosi, psi, phi, F0,
-                               F1, F2, tref, window, tstart, duration_days):
-        if window == 'none':
+transientTau = {:10.0f}\n"""
+        )
+        return template.format(
+            i,
+            Alpha,
+            Delta,
+            h0,
+            cosi,
+            psi,
+            phi,
+            F0,
+            F1,
+            F2,
+            tref,
+            window,
+            tstart,
+            duration_days * 86400,
+        )
+
+    def get_single_config_line(
+        self,
+        i,
+        Alpha,
+        Delta,
+        h0,
+        cosi,
+        psi,
+        phi,
+        F0,
+        F1,
+        F2,
+        tref,
+        window,
+        tstart,
+        duration_days,
+    ):
+        if window == "none":
             return self.get_single_config_line_cw(
-                i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref)
+                i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref
+            )
         else:
             return self.get_single_config_line_tcw(
-                i, Alpha, Delta, h0, cosi, psi, phi, F0, F1, F2, tref, window,
-                tstart, duration_days)
+                i,
+                Alpha,
+                Delta,
+                h0,
+                cosi,
+                psi,
+                phi,
+                F0,
+                F1,
+                F2,
+                tref,
+                window,
+                tstart,
+                duration_days,
+            )
 
     def make_cff(self):
         """
@@ -142,9 +232,21 @@ transientTau = {:10.0f}\n""")
         """
 
         content = self.get_single_config_line(
-            0, self.Alpha, self.Delta, self.h0, self.cosi, self.psi,
-            self.phi, self.F0, self.F1, self.F2, self.tref,
-            self.transientWindowType, self.tstart, self.duration_days)
+            0,
+            self.Alpha,
+            self.Delta,
+            self.h0,
+            self.cosi,
+            self.psi,
+            self.phi,
+            self.F0,
+            self.F1,
+            self.F2,
+            self.tref,
+            self.transientWindowType,
+            self.tstart,
+            self.duration_days,
+        )
 
         if self.check_if_cff_file_needs_rewritting(content):
             config_file = open(self.config_file_name, "w+")
@@ -152,7 +254,7 @@ transientTau = {:10.0f}\n""")
             config_file.close()
 
     def calculate_fmin_Band(self):
-        self.fmin = self.F0 - .5 * self.Band
+        self.fmin = self.F0 - 0.5 * self.Band
 
     def check_cached_data_okay_to_use(self, cl_mfd):
         """ Check if cached data exists and, if it does, if it can be used """
@@ -160,32 +262,34 @@ transientTau = {:10.0f}\n""")
         getmtime = os.path.getmtime
 
         if os.path.isfile(self.sftfilepath) is False:
-            logging.info('No SFT file matching {} found'.format(
-                self.sftfilepath))
+            logging.info("No SFT file matching {} found".format(self.sftfilepath))
             return False
         else:
-            logging.info('Matching SFT file found')
+            logging.info("Matching SFT file found")
 
         if getmtime(self.sftfilepath) < getmtime(self.config_file_name):
             logging.info(
-                ('The config file {} has been modified since the sft file {} '
-                 + 'was created').format(
-                    self.config_file_name, self.sftfilepath))
+                (
+                    "The config file {} has been modified since the sft file {} "
+                    + "was created"
+                ).format(self.config_file_name, self.sftfilepath)
+            )
             return False
 
         logging.info(
-            'The config file {} is older than the sft file {}'.format(
-                self.config_file_name, self.sftfilepath))
-        logging.info('Checking contents of cff file')
-        cl_dump = 'lalapps_SFTdumpheader {} | head -n 20'.format(
-            self.sftfilepath)
+            "The config file {} is older than the sft file {}".format(
+                self.config_file_name, self.sftfilepath
+            )
+        )
+        logging.info("Checking contents of cff file")
+        cl_dump = "lalapps_SFTdumpheader {} | head -n 20".format(self.sftfilepath)
         output = helper_functions.run_commandline(cl_dump)
-        calls = [line for line in output.split('\n') if line[:3] == 'lal']
+        calls = [line for line in output.split("\n") if line[:3] == "lal"]
         if calls[0] == cl_mfd:
-            logging.info('Contents matched, use old sft file')
+            logging.info("Contents matched, use old sft file")
             return True
         else:
-            logging.info('Contents unmatched, create new sft file')
+            logging.info("Contents unmatched, create new sft file")
             return False
 
     def check_if_cff_file_needs_rewritting(self, content):
@@ -195,24 +299,24 @@ transientTau = {:10.0f}\n""")
         overwriting to allow cached data to be used
         """
         if os.path.isfile(self.config_file_name) is False:
-            logging.info('No config file {} found'.format(
-                self.config_file_name))
+            logging.info("No config file {} found".format(self.config_file_name))
             return True
         else:
-            logging.info('Config file {} already exists'.format(
-                self.config_file_name))
+            logging.info("Config file {} already exists".format(self.config_file_name))
 
-        with open(self.config_file_name, 'r') as f:
+        with open(self.config_file_name, "r") as f:
             file_content = f.read()
             if file_content == content:
                 logging.info(
-                    'File contents match, no update of {} required'.format(
-                        self.config_file_name))
+                    "File contents match, no update of {} required".format(
+                        self.config_file_name
+                    )
+                )
                 return False
             else:
                 logging.info(
-                    'File contents unmatched, updating {}'.format(
-                        self.config_file_name))
+                    "File contents unmatched, updating {}".format(self.config_file_name)
+                )
                 return True
 
     def run_makefakedata(self):
@@ -225,29 +329,27 @@ transientTau = {:10.0f}\n""")
             pass
 
         cl_mfd = []
-        cl_mfd.append('lalapps_Makefakedata_v5')
-        cl_mfd.append('--outSingleSFT=TRUE')
+        cl_mfd.append("lalapps_Makefakedata_v5")
+        cl_mfd.append("--outSingleSFT=TRUE")
         cl_mfd.append('--outSFTdir="{}"'.format(self.outdir))
         cl_mfd.append('--outLabel="{}"'.format(self.label))
-        cl_mfd.append('--IFOs={}'.format(self.IFOs))
+        cl_mfd.append("--IFOs={}".format(self.IFOs))
         if self.add_noise:
             cl_mfd.append('--sqrtSX="{}"'.format(self.sqrtSX))
         if self.minStartTime is None:
-            cl_mfd.append('--startTime={:0.0f}'.format(float(self.tstart)))
+            cl_mfd.append("--startTime={:0.0f}".format(float(self.tstart)))
         else:
-            cl_mfd.append('--startTime={:0.0f}'.format(
-                float(self.minStartTime)))
+            cl_mfd.append("--startTime={:0.0f}".format(float(self.minStartTime)))
         if self.maxStartTime is None:
-            cl_mfd.append('--duration={}'.format(int(self.duration)))
+            cl_mfd.append("--duration={}".format(int(self.duration)))
         else:
             data_duration = self.maxStartTime - self.minStartTime
-            cl_mfd.append('--duration={}'.format(int(data_duration)))
-        cl_mfd.append('--fmin={:.16g}'.format(self.fmin))
-        cl_mfd.append('--Band={:.16g}'.format(self.Band))
-        cl_mfd.append('--Tsft={}'.format(self.Tsft))
+            cl_mfd.append("--duration={}".format(int(data_duration)))
+        cl_mfd.append("--fmin={:.16g}".format(self.fmin))
+        cl_mfd.append("--Band={:.16g}".format(self.Band))
+        cl_mfd.append("--Tsft={}".format(self.Tsft))
         if self.h0 != 0:
-            cl_mfd.append('--injectionSources="{}"'.format(
-                self.config_file_name))
+            cl_mfd.append('--injectionSources="{}"'.format(self.config_file_name))
 
         cl_mfd = " ".join(cl_mfd)
 
@@ -257,23 +359,56 @@ transientTau = {:10.0f}\n""")
     def predict_fstat(self):
         """ Wrapper to lalapps_PredictFstat """
         twoF_expected, twoF_sigma = predict_fstat(
-            self.h0, self.cosi, self.psi, self.Alpha, self.Delta, self.F0,
-            self.sftfilepath, self.minStartTime, self.maxStartTime,
-            self.detectors, self.sqrtSX,
-            tempory_filename='{}.tmp'.format(self.label)) # detectors OR IFO?
+            self.h0,
+            self.cosi,
+            self.psi,
+            self.Alpha,
+            self.Delta,
+            self.F0,
+            self.sftfilepath,
+            self.minStartTime,
+            self.maxStartTime,
+            self.detectors,
+            self.sqrtSX,
+            tempory_filename="{}.tmp".format(self.label),
+        )  # detectors OR IFO?
         return twoF_expected
 
 
 class GlitchWriter(Writer):
     """ Instance object for generating SFTs containing glitch signals """
+
     @helper_functions.initializer
-    def __init__(self, label='Test', tstart=700000000, duration=100*86400,
-                 dtglitch=None, delta_phi=0, delta_F0=0, delta_F1=0,
-                 delta_F2=0, tref=None, F0=30, F1=1e-10, F2=0, Alpha=5e-3,
-                 Delta=6e-2, h0=0.1, cosi=0.0, psi=0.0, phi=0, Tsft=1800,
-                 outdir=".", sqrtSX=1, Band=4, detectors='H1',
-                 minStartTime=None, maxStartTime=None, add_noise=True,
-                 transientWindowType='rect'):
+    def __init__(
+        self,
+        label="Test",
+        tstart=700000000,
+        duration=100 * 86400,
+        dtglitch=None,
+        delta_phi=0,
+        delta_F0=0,
+        delta_F1=0,
+        delta_F2=0,
+        tref=None,
+        F0=30,
+        F1=1e-10,
+        F2=0,
+        Alpha=5e-3,
+        Delta=6e-2,
+        h0=0.1,
+        cosi=0.0,
+        psi=0.0,
+        phi=0,
+        Tsft=1800,
+        outdir=".",
+        sqrtSX=1,
+        Band=4,
+        detectors="H1",
+        minStartTime=None,
+        maxStartTime=None,
+        add_noise=True,
+        transientWindowType="rect",
+    ):
         """
         Parameters
         ----------
@@ -303,12 +438,14 @@ class GlitchWriter(Writer):
         self.basic_setup()
         self.calculate_fmin_Band()
 
-        shapes = np.array([np.shape(x) for x in [self.delta_phi, self.delta_F0,
-                                                 self.delta_F1, self.delta_F2]]
-                          )
+        shapes = np.array(
+            [
+                np.shape(x)
+                for x in [self.delta_phi, self.delta_F0, self.delta_F1, self.delta_F2]
+            ]
+        )
         if not np.all(shapes == shapes[0]):
-            raise ValueError('all delta_* must be the same shape: {}'.format(
-                shapes))
+            raise ValueError("all delta_* must be the same shape: {}".format(shapes))
 
         for d in self.delta_phi, self.delta_F0, self.delta_F1, self.delta_F2:
             if np.size(d) == 1:
@@ -319,15 +456,15 @@ class GlitchWriter(Writer):
         else:
             self.dtglitch = np.atleast_1d(self.dtglitch)
             self.tglitch = self.tstart + self.dtglitch
-            self.tbounds = np.concatenate((
-                [self.tstart], self.tglitch, [self.tend]))
-        logging.info('Using segment boundaries {}'.format(self.tbounds))
+            self.tbounds = np.concatenate(([self.tstart], self.tglitch, [self.tend]))
+        logging.info("Using segment boundaries {}".format(self.tbounds))
 
         tbs = np.array(self.tbounds)
         self.durations_days = (tbs[1:] - tbs[:-1]) / 86400
 
         self.delta_thetas = np.atleast_2d(
-                np.array([delta_phi, delta_F0, delta_F1, delta_F2]).T)
+            np.array([delta_phi, delta_F0, delta_F1, delta_F2]).T
+        )
 
     def make_cff(self):
         """
@@ -335,16 +472,28 @@ class GlitchWriter(Writer):
 
         """
 
-        thetas = self._calculate_thetas(self.theta, self.delta_thetas,
-                                        self.tbounds)
+        thetas = self._calculate_thetas(self.theta, self.delta_thetas, self.tbounds)
 
-        content = ''
-        for i, (t, d, ts) in enumerate(zip(thetas, self.durations_days,
-                                           self.tbounds[:-1])):
+        content = ""
+        for i, (t, d, ts) in enumerate(
+            zip(thetas, self.durations_days, self.tbounds[:-1])
+        ):
             line = self.get_single_config_line(
-                i, self.Alpha, self.Delta, self.h0, self.cosi, self.psi,
-                t[0], t[1], t[2], t[3], self.tref, self.transientWindowType,
-                ts, d)
+                i,
+                self.Alpha,
+                self.Delta,
+                self.h0,
+                self.cosi,
+                self.psi,
+                t[0],
+                t[1],
+                t[2],
+                t[3],
+                self.tref,
+                self.transientWindowType,
+                ts,
+                d,
+            )
 
             content += line
 
@@ -358,11 +507,29 @@ class FrequencyModulatedArtifactWriter(Writer):
     """ Instance object for generating SFTs containing artifacts """
 
     @helper_functions.initializer
-    def __init__(self, label, outdir=".", tstart=700000000,
-                 duration=86400, F0=30, F1=0, tref=None, h0=10, Tsft=1800,
-                 sqrtSX=1, Band=4, Pmod=lal.DAYSID_SI, Pmod_phi=0, Pmod_amp=1,
-                 Alpha=None, Delta=None, IFO='H1', minStartTime=None,
-                 maxStartTime=None, detectors='H1'):
+    def __init__(
+        self,
+        label,
+        outdir=".",
+        tstart=700000000,
+        duration=86400,
+        F0=30,
+        F1=0,
+        tref=None,
+        h0=10,
+        Tsft=1800,
+        sqrtSX=1,
+        Band=4,
+        Pmod=lal.DAYSID_SI,
+        Pmod_phi=0,
+        Pmod_amp=1,
+        Alpha=None,
+        Delta=None,
+        IFO="H1",
+        minStartTime=None,
+        maxStartTime=None,
+        detectors="H1",
+    ):
         """
         Parameters
         ----------
@@ -391,24 +558,28 @@ class FrequencyModulatedArtifactWriter(Writer):
         if os.path.isdir(self.outdir) is False:
             os.makedirs(self.outdir)
         if tref is None:
-            raise ValueError('Input `tref` not specified')
+            raise ValueError("Input `tref` not specified")
 
         self.nsfts = int(np.ceil(self.duration / self.Tsft))
-        self.duration = self.duration / 86400.
+        self.duration = self.duration / 86400.0
         self.calculate_fmin_Band()
 
         self.cosi = 0
         self.Fmax = F0
 
         if Alpha is not None and Delta is not None:
-            self.n = np.array([np.cos(Alpha)*np.cos(Delta),
-                               np.sin(Alpha)*np.cos(Delta),
-                               np.sin(Delta)])
+            self.n = np.array(
+                [
+                    np.cos(Alpha) * np.cos(Delta),
+                    np.sin(Alpha) * np.cos(Delta),
+                    np.sin(Delta),
+                ]
+            )
 
     def get_frequency(self, t):
         DeltaFDrift = self.F1 * (t - self.tref)
 
-        phir = 2*np.pi*t/self.Pmod + self.Pmod_phi
+        phir = 2 * np.pi * t / self.Pmod + self.Pmod_phi
 
         if self.Alpha is not None and self.Delta is not None:
             spin_posvel = lalpulsar.PosVel3D_t()
@@ -416,23 +587,34 @@ class FrequencyModulatedArtifactWriter(Writer):
             det = lal.CachedDetectors[4]
             ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)
             lalpulsar.DetectorPosVel(
-                spin_posvel, orbit_posvel, lal.LIGOTimeGPS(t), det, ephems,
-                lalpulsar.DETMOTION_ORBIT)
+                spin_posvel,
+                orbit_posvel,
+                lal.LIGOTimeGPS(t),
+                det,
+                ephems,
+                lalpulsar.DETMOTION_ORBIT,
+            )
             # Pos and vel returned in units of c
             DeltaFOrbital = np.dot(self.n, orbit_posvel.vel) * self.Fmax
 
-            if self.IFO == 'H1':
+            if self.IFO == "H1":
                 Lambda = lal.LHO_4K_DETECTOR_LATITUDE_RAD
-            elif self.IFO == 'L1':
+            elif self.IFO == "L1":
                 Lambda = lal.LLO_4K_DETECTOR_LATITUDE_RAD
 
             DeltaFSpin = (
-                self.Pmod_amp*lal.REARTH_SI/lal.C_SI*2*np.pi/self.Pmod*(
-                   np.cos(self.Delta)*np.cos(Lambda)*np.sin(self.Alpha-phir)
-                   ) * self.Fmax)
+                self.Pmod_amp
+                * lal.REARTH_SI
+                / lal.C_SI
+                * 2
+                * np.pi
+                / self.Pmod
+                * (np.cos(self.Delta) * np.cos(Lambda) * np.sin(self.Alpha - phir))
+                * self.Fmax
+            )
         else:
             DeltaFOrbital = 0
-            DeltaFSpin = 2*np.pi*self.Pmod_amp/self.Pmod*np.cos(phir)
+            DeltaFSpin = 2 * np.pi * self.Pmod_amp / self.Pmod * np.cos(phir)
 
         f = self.F0 + DeltaFDrift + DeltaFOrbital + DeltaFSpin
         return f
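
In the branch without a sky position, the instantaneous artifact frequency reduces to a linear drift plus a sinusoid: f(t) = F0 + F1*(t - tref) + (2*pi*Pmod_amp/Pmod)*cos(2*pi*t/Pmod + Pmod_phi). A standalone sketch of that branch (defaults are illustrative):

    import numpy as np

    def artifact_frequency(t, F0=30.0, F1=0.0, tref=0.0,
                           Pmod=86164.0905, Pmod_amp=1.0, Pmod_phi=0.0):
        """Line-frequency model without the sky-position (orbital) terms."""
        phir = 2 * np.pi * t / Pmod + Pmod_phi
        return F0 + F1 * (t - tref) + 2 * np.pi * Pmod_amp / Pmod * np.cos(phir)

    print(artifact_frequency(0.0))  # -> 30 Hz + ~7.3e-5 Hz peak modulation
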
@@ -442,34 +624,43 @@ class FrequencyModulatedArtifactWriter(Writer):
 
     def concatenate_sft_files(self):
         SFTFilename = lalpulsar.OfficialSFTFilename(
-            self.IFO[0], self.IFO[1], self.nsfts, self.Tsft, int(self.tstart),
-            int(self.duration), self.label)
+            self.IFO[0],
+            self.IFO[1],
+            self.nsfts,
+            self.Tsft,
+            int(self.tstart),
+            int(self.duration),
+            self.label,
+        )
 
         # If the file already exists, simply remove it for now (no caching
         # implemented)
         helper_functions.run_commandline(
-            'rm {}/{}'.format(self.outdir, SFTFilename), raise_error=False,
-            log_level=10)
-
-        cl_splitSFTS = (
-            'lalapps_splitSFTs -fs {} -fb {} -fe {} -o {}/{} -i {}/*sft'
-            .format(self.fmin, self.Band, self.fmin+self.Band, self.outdir,
-                    SFTFilename, self.tmp_outdir))
+            "rm {}/{}".format(self.outdir, SFTFilename), raise_error=False, log_level=10
+        )
+
+        cl_splitSFTS = "lalapps_splitSFTs -fs {} -fb {} -fe {} -o {}/{} -i {}/*sft".format(
+            self.fmin,
+            self.Band,
+            self.fmin + self.Band,
+            self.outdir,
+            SFTFilename,
+            self.tmp_outdir,
+        )
         helper_functions.run_commandline(cl_splitSFTS)
-        helper_functions.run_commandline('rm {} -r'.format(self.tmp_outdir))
-        files = glob.glob('{}/{}*'.format(self.outdir, SFTFilename))
+        helper_functions.run_commandline("rm {} -r".format(self.tmp_outdir))
+        files = glob.glob("{}/{}*".format(self.outdir, SFTFilename))
         if len(files) == 1:
             fn = files[0]
-            fn_new = fn.split('.')[0] + '.sft'
-            helper_functions.run_commandline('mv {} {}'.format(
-                fn, fn_new))
+            fn_new = fn.split(".")[0] + ".sft"
+            helper_functions.run_commandline("mv {} {}".format(fn, fn_new))
         else:
             raise IOError(
-                'Attempted to rename file, but multiple files found: {}'
-                .format(files))
+                "Attempted to rename file, but multiple files found: {}".format(files)
+            )
 
     def pre_compute_evolution(self):
-        logging.info('Precomputing evolution parameters')
+        logging.info("Precomputing evolution parameters")
         self.lineFreqs = []
         self.linePhis = []
         self.lineh0s = []
@@ -478,21 +669,27 @@ class FrequencyModulatedArtifactWriter(Writer):
         linePhi = 0
         lineFreq_old = 0
         for i in tqdm(list(range(self.nsfts))):
-            mid_time = self.tstart + (i+.5)*self.Tsft
+            mid_time = self.tstart + (i + 0.5) * self.Tsft
             lineFreq = self.get_frequency(mid_time)
 
             self.mid_times.append(mid_time)
             self.lineFreqs.append(lineFreq)
-            self.linePhis.append(linePhi + np.pi*self.Tsft*(lineFreq_old+lineFreq))
+            self.linePhis.append(
+                linePhi + np.pi * self.Tsft * (lineFreq_old + lineFreq)
+            )
             self.lineh0s.append(self.get_h0(mid_time))
 
             lineFreq_old = lineFreq
 
     def make_ith_sft(self, i):
         try:
-            self.run_makefakedata_v4(self.mid_times[i], self.lineFreqs[i],
-                                     self.linePhis[i], self.lineh0s[i],
-                                     self.tmp_outdir)
+            self.run_makefakedata_v4(
+                self.mid_times[i],
+                self.lineFreqs[i],
+                self.linePhis[i],
+                self.lineh0s[i],
+                self.tmp_outdir,
+            )
         except KeyboardInterrupt:
             raise KeyboardInterruptError()
 
@@ -500,31 +697,39 @@ class FrequencyModulatedArtifactWriter(Writer):
         self.maxStartTime = None
         self.duration = self.Tsft
 
-        self.tmp_outdir = '{}/{}_tmp'.format(self.outdir, self.label)
+        self.tmp_outdir = "{}/{}_tmp".format(self.outdir, self.label)
         if os.path.isdir(self.tmp_outdir) is True:
             raise ValueError(
-                'Temporary directory {} already exists, please rename'.format(
-                    self.tmp_outdir))
+                "Temporary directory {} already exists, please rename".format(
+                    self.tmp_outdir
+                )
+            )
         else:
             os.makedirs(self.tmp_outdir)
 
         self.pre_compute_evolution()
 
-        logging.info('Generating SFTs')
+        logging.info("Generating SFTs")
 
-        if args.N > 1 and pkgutil.find_loader('pathos') is not None:
+        if args.N > 1 and pkgutil.find_loader("pathos") is not None:
             import pathos.pools
-            logging.info('Using {} threads'.format(args.N))
+
+            logging.info("Using {} threads".format(args.N))
             try:
                 with pathos.pools.ProcessPool(args.N) as p:
-                    list(tqdm(p.imap(self.make_ith_sft, list(range(self.nsfts))),
-                              total=self.nsfts))
+                    list(
+                        tqdm(
+                            p.imap(self.make_ith_sft, list(range(self.nsfts))),
+                            total=self.nsfts,
+                        )
+                    )
             except KeyboardInterrupt:
                 p.terminate()
         else:
             logging.info(
                 "No multiprocessing requested or `pathos` not install, cont."
-                " without multiprocessing")
+                " without multiprocessing"
+            )
             for i in tqdm(list(range(self.nsfts))):
                 self.make_ith_sft(i)
 
@@ -533,27 +738,28 @@ class FrequencyModulatedArtifactWriter(Writer):
     def run_makefakedata_v4(self, mid_time, lineFreq, linePhi, h0, tmp_outdir):
         """ Generate the sft data using the --lineFeature option """
         cl_mfd = []
-        cl_mfd.append('lalapps_Makefakedata_v4')
-        cl_mfd.append('--outSingleSFT=FALSE')
+        cl_mfd.append("lalapps_Makefakedata_v4")
+        cl_mfd.append("--outSingleSFT=FALSE")
         cl_mfd.append('--outSFTbname="{}"'.format(tmp_outdir))
-        cl_mfd.append('--IFO={}'.format(self.IFO))
+        cl_mfd.append("--IFO={}".format(self.IFO))
         cl_mfd.append('--noiseSqrtSh="{}"'.format(self.sqrtSX))
-        cl_mfd.append('--startTime={:0.0f}'.format(mid_time-self.Tsft/2.0))
-        cl_mfd.append('--refTime={:0.0f}'.format(mid_time))
-        cl_mfd.append('--duration={}'.format(int(self.duration)))
-        cl_mfd.append('--fmin={:.16g}'.format(self.fmin))
-        cl_mfd.append('--Band={:.16g}'.format(self.Band))
-        cl_mfd.append('--Tsft={}'.format(self.Tsft))
-        cl_mfd.append('--Freq={}'.format(lineFreq))
-        cl_mfd.append('--phi0={}'.format(linePhi))
-        cl_mfd.append('--h0={}'.format(h0))
-        cl_mfd.append('--cosi={}'.format(self.cosi))
-        cl_mfd.append('--lineFeature=TRUE')
+        cl_mfd.append("--startTime={:0.0f}".format(mid_time - self.Tsft / 2.0))
+        cl_mfd.append("--refTime={:0.0f}".format(mid_time))
+        cl_mfd.append("--duration={}".format(int(self.duration)))
+        cl_mfd.append("--fmin={:.16g}".format(self.fmin))
+        cl_mfd.append("--Band={:.16g}".format(self.Band))
+        cl_mfd.append("--Tsft={}".format(self.Tsft))
+        cl_mfd.append("--Freq={}".format(lineFreq))
+        cl_mfd.append("--phi0={}".format(linePhi))
+        cl_mfd.append("--h0={}".format(h0))
+        cl_mfd.append("--cosi={}".format(self.cosi))
+        cl_mfd.append("--lineFeature=TRUE")
         cl_mfd = " ".join(cl_mfd)
         helper_functions.run_commandline(cl_mfd, log_level=10)
 
 
 class FrequencyAmplitudeModulatedArtifactWriter(FrequencyModulatedArtifactWriter):
     """ Instance object for generating SFTs containing artifacts """
+
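+    # Amplitude-modulated variant: h0(t) = h0 * sin(2*pi*t / Pmod + Pmod_phi).
+    # Pmod and Pmod_phi are assumed to be set up by the parent class.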
     def get_h0(self, t):
-            return self.h0*np.sin(2*np.pi*t/self.Pmod+self. Pmod_phi)
+        return self.h0 * np.sin(2 * np.pi * t / self.Pmod + self.Pmod_phi)
diff --git a/pyfstat/mcmc_based_searches.py b/pyfstat/mcmc_based_searches.py
index f30c9672e71a58053d3cbe03815f3cdadee4a046..78a1af122445ed56720a66769c7d88d1d1d8cba8 100644
--- a/pyfstat/mcmc_based_searches.py
+++ b/pyfstat/mcmc_based_searches.py
@@ -102,22 +102,57 @@ class MCMCSearch(core.BaseSearchClass):
     """
 
     symbol_dictionary = dict(
-        F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', Alpha=r'$\alpha$',
-        Delta='$\delta$', asini='asini', period='P', ecc='ecc', tp='tp',
-        argp='argp')
+        F0="$f$",
+        F1="$\dot{f}$",
+        F2="$\ddot{f}$",
+        Alpha=r"$\alpha$",
+        Delta="$\delta$",
+        asini="asini",
+        period="P",
+        ecc="ecc",
+        tp="tp",
+        argp="argp",
+    )
     unit_dictionary = dict(
-        F0='Hz', F1='Hz/s', F2='Hz/s$^2$', Alpha=r'rad', Delta='rad',
-        asini='', period='s', ecc='', tp='', argp='')
+        F0="Hz",
+        F1="Hz/s",
+        F2="Hz/s$^2$",
+        Alpha=r"rad",
+        Delta="rad",
+        asini="",
+        period="s",
+        ecc="",
+        tp="",
+        argp="",
+    )
     transform_dictionary = {}
 
-    def __init__(self, theta_prior, tref, label, outdir='data',
-                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
-                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
-                 log10beta_min=-5, theta_initial=None,
-                 rhohatmax=1000, binary=False, BSGL=False,
-                 SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
-                 injectSources=None, assumeSqrtSX=None,
-                 transientWindowType=None, tCWFstatMapVersion='lal'):
+    def __init__(
+        self,
+        theta_prior,
+        tref,
+        label,
+        outdir="data",
+        minStartTime=None,
+        maxStartTime=None,
+        sftfilepattern=None,
+        detectors=None,
+        nsteps=[100, 100],
+        nwalkers=100,
+        ntemps=1,
+        log10beta_min=-5,
+        theta_initial=None,
+        rhohatmax=1000,
+        binary=False,
+        BSGL=False,
+        SSBprec=None,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        injectSources=None,
+        assumeSqrtSX=None,
+        transientWindowType=None,
+        tCWFstatMapVersion="lal",
+    ):
 
         self.theta_prior = theta_prior
         self.tref = tref
@@ -146,14 +181,14 @@ class MCMCSearch(core.BaseSearchClass):
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
         self._add_log_file()
-        logging.info('Set-up MCMC search for model {}'.format(self.label))
+        logging.info("Set-up MCMC search for model {}".format(self.label))
         if sftfilepattern:
-            logging.info('Using data {}'.format(self.sftfilepattern))
+            logging.info("Using data {}".format(self.sftfilepattern))
         else:
-            logging.info('No sftfilepattern given')
+            logging.info("No sftfilepattern given")
         if injectSources:
-            logging.info('Inject sources: {}'.format(injectSources))
-        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
+            logging.info("Inject sources: {}".format(injectSources))
+        self.pickle_path = "{}/{}_saved_data.p".format(self.outdir, self.label)
         self._unpack_input_theta()
         self.ndim = len(self.theta_keys)
         if self.log10beta_min:
@@ -162,62 +197,74 @@ class MCMCSearch(core.BaseSearchClass):
             self.betas = None
 
         if args.clean and os.path.isfile(self.pickle_path):
-            os.rename(self.pickle_path, self.pickle_path+".old")
+            os.rename(self.pickle_path, self.pickle_path + ".old")
 
         self._set_likelihoodcoef()
         self._log_input()
 
     def _set_likelihoodcoef(self):
-        self.likelihoodcoef = np.log(70./self.rhohatmax**4)
+        self.likelihoodcoef = np.log(70.0 / self.rhohatmax ** 4)
 
     def _log_input(self):
-        logging.info('theta_prior = {}'.format(self.theta_prior))
-        logging.info('nwalkers={}'.format(self.nwalkers))
-        logging.info('nsteps = {}'.format(self.nsteps))
-        logging.info('ntemps = {}'.format(self.ntemps))
-        logging.info('log10beta_min = {}'.format(
-            self.log10beta_min))
+        logging.info("theta_prior = {}".format(self.theta_prior))
+        logging.info("nwalkers={}".format(self.nwalkers))
+        logging.info("nsteps = {}".format(self.nsteps))
+        logging.info("ntemps = {}".format(self.ntemps))
+        logging.info("log10beta_min = {}".format(self.log10beta_min))
 
     def _initiate_search_object(self):
-        logging.info('Setting up search object')
+        logging.info("Setting up search object")
         self.search = core.ComputeFstat(
-            tref=self.tref, sftfilepattern=self.sftfilepattern,
-            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
-            detectors=self.detectors, BSGL=self.BSGL,
+            tref=self.tref,
+            sftfilepattern=self.sftfilepattern,
+            minCoverFreq=self.minCoverFreq,
+            maxCoverFreq=self.maxCoverFreq,
+            detectors=self.detectors,
+            BSGL=self.BSGL,
             transientWindowType=self.transientWindowType,
-            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
-            binary=self.binary, injectSources=self.injectSources,
-            assumeSqrtSX=self.assumeSqrtSX, SSBprec=self.SSBprec,
-            tCWFstatMapVersion=self.tCWFstatMapVersion)
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+            binary=self.binary,
+            injectSources=self.injectSources,
+            assumeSqrtSX=self.assumeSqrtSX,
+            SSBprec=self.SSBprec,
+            tCWFstatMapVersion=self.tCWFstatMapVersion,
+        )
         if self.minStartTime is None:
             self.minStartTime = self.search.minStartTime
         if self.maxStartTime is None:
             self.maxStartTime = self.search.maxStartTime
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
-        H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
-             zip(theta_vals, theta_keys)]
+        H = [
+            self._generic_lnprior(**theta_prior[key])(p)
+            for p, key in zip(theta_vals, theta_keys)
+        ]
         return np.sum(H)
 
     def logl(self, theta, search):
         for j, theta_i in enumerate(self.theta_idxs):
             self.fixed_theta[theta_i] = theta[j]
         twoF = search.get_fullycoherent_twoF(
-            self.minStartTime, self.maxStartTime, *self.fixed_theta)
-        return twoF/2.0 + self.likelihoodcoef
+            self.minStartTime, self.maxStartTime, *self.fixed_theta
+        )
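+        # The log-likelihood is 2F/2 plus the constant normalisation
+        # log(70 / rhohatmax**4) set in _set_likelihoodcoef().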
+        return twoF / 2.0 + self.likelihoodcoef
 
     def _unpack_input_theta(self):
-        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']
+        full_theta_keys = ["F0", "F1", "F2", "Alpha", "Delta"]
         if self.binary:
-            full_theta_keys += [
-                'asini', 'period', 'ecc', 'tp', 'argp']
+            full_theta_keys += ["asini", "period", "ecc", "tp", "argp"]
         full_theta_keys_copy = copy.copy(full_theta_keys)
 
-        full_theta_symbols = ['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
-                              r'$\delta$']
+        full_theta_symbols = [
+            "$f$",
+            "$\dot{f}$",
+            "$\ddot{f}$",
+            r"$\alpha$",
+            r"$\delta$",
+        ]
         if self.binary:
-            full_theta_symbols += [
-                'asini', 'period', 'ecc', 'tp', 'argp']
+            full_theta_symbols += ["asini", "period", "ecc", "tp", "argp"]
 
         self.theta_keys = []
         fixed_theta_dict = {}
@@ -229,14 +276,16 @@ class MCMCSearch(core.BaseSearchClass):
                 fixed_theta_dict[key] = val
             else:
                 raise ValueError(
-                    'Type {} of {} in theta not recognised'.format(
-                        type(val), key))
+                    "Type {} of {} in theta not recognised".format(type(val), key)
+                )
             full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
 
         if len(full_theta_keys_copy) > 0:
-            raise ValueError(('Input dictionary `theta` is missing the'
-                              'following keys: {}').format(
-                                  full_theta_keys_copy))
+            raise ValueError(
+                ("Input dictionary `theta` is missing the" "following keys: {}").format(
+                    full_theta_keys_copy
+                )
+            )
 
         self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
         self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
@@ -248,41 +297,44 @@ class MCMCSearch(core.BaseSearchClass):
         self.theta_keys = [self.theta_keys[i] for i in idxs]
 
     def _evaluate_logpost(self, p0vec):
-        init_logp = np.array([
-            self.logp(p, self.theta_prior, self.theta_keys, self.search)
-            for p in p0vec])
-        init_logl = np.array([
-            self.logl(p, self.search)
-            for p in p0vec])
+        init_logp = np.array(
+            [
+                self.logp(p, self.theta_prior, self.theta_keys, self.search)
+                for p in p0vec
+            ]
+        )
+        init_logl = np.array([self.logl(p, self.search) for p in p0vec])
         return init_logl + init_logp
 
     def _check_initial_points(self, p0):
         for nt in range(self.ntemps):
-            logging.info('Checking temperature {} chains'.format(nt))
+            logging.info("Checking temperature {} chains".format(nt))
             num = sum(self._evaluate_logpost(p0[nt]) == -np.inf)
             if num > 0:
                 logging.warning(
-                    'Of {} initial values, {} are -np.inf due to the prior'
-                    .format(len(p0[0]), num))
-                p0 = self._generate_new_p0_to_fix_initial_points(
-                    p0, nt)
+                    "Of {} initial values, {} are -np.inf due to the prior".format(
+                        len(p0[0]), num
+                    )
+                )
+                p0 = self._generate_new_p0_to_fix_initial_points(p0, nt)
 
     def _generate_new_p0_to_fix_initial_points(self, p0, nt):
-        logging.info('Attempting to correct intial values')
+        logging.info("Attempting to correct intial values")
         init_logpost = self._evaluate_logpost(p0[nt])
         idxs = np.arange(self.nwalkers)[init_logpost == -np.inf]
         count = 0
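+        # Replace each -inf walker with a copy of a randomly chosen walker,
+        # perturbed at the 1e-10 relative level; retry up to 100 times.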
         while sum(init_logpost == -np.inf) > 0 and count < 100:
             for j in idxs:
-                p0[nt][j] = (p0[nt][np.random.randint(0, self.nwalkers)]*(
-                             1+np.random.normal(0, 1e-10, self.ndim)))
+                p0[nt][j] = p0[nt][np.random.randint(0, self.nwalkers)] * (
+                    1 + np.random.normal(0, 1e-10, self.ndim)
+                )
             init_logpost = self._evaluate_logpost(p0[nt])
             count += 1
 
         if sum(init_logpost == -np.inf) > 0:
-            logging.info('Failed to fix initial priors')
+            logging.info("Failed to fix initial priors")
         else:
-            logging.info('Suceeded to fix initial priors')
+            logging.info("Suceeded to fix initial priors")
 
         return p0
 
@@ -305,136 +357,142 @@ class MCMCSearch(core.BaseSearchClass):
 
         """
 
-        logging.info('Setting up initialisation with nburn0={}, scatter_val={}'
-                     .format(nburn0, scatter_val))
+        logging.info(
+            "Setting up initialisation with nburn0={}, scatter_val={}".format(
+                nburn0, scatter_val
+            )
+        )
         self.nsteps = [nburn0] + self.nsteps
         self.scatter_val = scatter_val
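+        # nsteps is now interpreted as [nburn0, ..., nburn, nprod]: all but
+        # the last two entries are re-initialisation stages (see run()).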
 
-#    def setup_burnin_convergence_testing(
-#            self, n=10, test_type='autocorr', windowed=False, **kwargs):
-#        """ Set up convergence testing during the MCMC simulation
-#
-#        Parameters
-#        ----------
-#        n: int
-#            Number of steps after which to test convergence
-#        test_type: str ['autocorr', 'GR']
-#            If 'autocorr' use the exponential autocorrelation time (kwargs
-#            passed to `get_autocorr_convergence`). If 'GR' use the Gelman-Rubin
-#            statistic (kwargs passed to `get_GR_convergence`)
-#        windowed: bool
-#            If True, only calculate the convergence test in a window of length
-#            `n`
-#        **kwargs:
-#            Passed to either `_test_autocorr_convergence()` or
-#            `_test_GR_convergence()` depending on `test_type`.
-#
-#        """
-#        logging.info('Setting up convergence testing')
-#        self.convergence_n = n
-#        self.convergence_windowed = windowed
-#        self.convergence_test_type = test_type
-#        self.convergence_kwargs = kwargs
-#        self.convergence_diagnostic = []
-#        self.convergence_diagnosticx = []
-#        if test_type in ['autocorr']:
-#            self._get_convergence_test = self._test_autocorr_convergence
-#        elif test_type in ['GR']:
-#            self._get_convergence_test = self._test_GR_convergence
-#        else:
-#            raise ValueError('test_type {} not understood'.format(test_type))
-#
-#
-#    def _test_autocorr_convergence(self, i, sampler, test=True, n_cut=5):
-#        try:
-#            acors = np.zeros((self.ntemps, self.ndim))
-#            for temp in range(self.ntemps):
-#                if self.convergence_windowed:
-#                    j = i-self.convergence_n
-#                else:
-#                    j = 0
-#                x = np.mean(sampler.chain[temp, :, j:i, :], axis=0)
-#                acors[temp, :] = emcee.autocorr.exponential_time(x)
-#            c = np.max(acors, axis=0)
-#        except emcee.autocorr.AutocorrError:
-#            logging.info('Failed to calculate exponential autocorrelation')
-#            c = np.zeros(self.ndim) + np.nan
-#        except AttributeError:
-#            logging.info('Unable to calculate exponential autocorrelation')
-#            c = np.zeros(self.ndim) + np.nan
-#
-#        self.convergence_diagnosticx.append(i - self.convergence_n/2.)
-#        self.convergence_diagnostic.append(list(c))
-#
-#        if test:
-#            return i > n_cut * np.max(c)
-#
-#    def _test_GR_convergence(self, i, sampler, test=True, R=1.1):
-#        if self.convergence_windowed:
-#            s = sampler.chain[0, :, i-self.convergence_n+1:i+1, :]
-#        else:
-#            s = sampler.chain[0, :, :i+1, :]
-#        N = float(self.convergence_n)
-#        M = float(self.nwalkers)
-#        W = np.mean(np.var(s, axis=1), axis=0)
-#        per_walker_mean = np.mean(s, axis=1)
-#        mean = np.mean(per_walker_mean, axis=0)
-#        B = N / (M-1.) * np.sum((per_walker_mean-mean)**2, axis=0)
-#        Vhat = (N-1)/N * W + (M+1)/(M*N) * B
-#        c = np.sqrt(Vhat/W)
-#        self.convergence_diagnostic.append(c)
-#        self.convergence_diagnosticx.append(i - self.convergence_n/2.)
-#
-#        if test and np.max(c) < R:
-#            return True
-#        else:
-#            return False
-#
-#    def _test_convergence(self, i, sampler, **kwargs):
-#        if np.mod(i+1, self.convergence_n) == 0:
-#            return self._get_convergence_test(i, sampler, **kwargs)
-#        else:
-#            return False
-#
-#    def _run_sampler_with_conv_test(self, sampler, p0, nprod=0, nburn=0):
-#        logging.info('Running {} burn-in steps with convergence testing'
-#                     .format(nburn))
-#        iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
-#        for i, output in enumerate(iterator):
-#            if self._test_convergence(i, sampler, test=True,
-#                                      **self.convergence_kwargs):
-#                logging.info(
-#                    'Converged at {} before max number {} of steps reached'
-#                    .format(i, nburn))
-#                self.convergence_idx = i
-#                break
-#        iterator.close()
-#        logging.info('Running {} production steps'.format(nprod))
-#        j = nburn
-#        iterator = tqdm(sampler.sample(output[0], iterations=nprod),
-#                        total=nprod)
-#        for result in iterator:
-#            self._test_convergence(j, sampler, test=False,
-#                                   **self.convergence_kwargs)
-#            j += 1
-#        return sampler
+    #    def setup_burnin_convergence_testing(
+    #            self, n=10, test_type='autocorr', windowed=False, **kwargs):
+    #        """ Set up convergence testing during the MCMC simulation
+    #
+    #        Parameters
+    #        ----------
+    #        n: int
+    #            Number of steps after which to test convergence
+    #        test_type: str ['autocorr', 'GR']
+    #            If 'autocorr' use the exponential autocorrelation time (kwargs
+    #            passed to `get_autocorr_convergence`). If 'GR' use the Gelman-Rubin
+    #            statistic (kwargs passed to `get_GR_convergence`)
+    #        windowed: bool
+    #            If True, only calculate the convergence test in a window of length
+    #            `n`
+    #        **kwargs:
+    #            Passed to either `_test_autocorr_convergence()` or
+    #            `_test_GR_convergence()` depending on `test_type`.
+    #
+    #        """
+    #        logging.info('Setting up convergence testing')
+    #        self.convergence_n = n
+    #        self.convergence_windowed = windowed
+    #        self.convergence_test_type = test_type
+    #        self.convergence_kwargs = kwargs
+    #        self.convergence_diagnostic = []
+    #        self.convergence_diagnosticx = []
+    #        if test_type in ['autocorr']:
+    #            self._get_convergence_test = self._test_autocorr_convergence
+    #        elif test_type in ['GR']:
+    #            self._get_convergence_test = self._test_GR_convergence
+    #        else:
+    #            raise ValueError('test_type {} not understood'.format(test_type))
+    #
+    #
+    #    def _test_autocorr_convergence(self, i, sampler, test=True, n_cut=5):
+    #        try:
+    #            acors = np.zeros((self.ntemps, self.ndim))
+    #            for temp in range(self.ntemps):
+    #                if self.convergence_windowed:
+    #                    j = i-self.convergence_n
+    #                else:
+    #                    j = 0
+    #                x = np.mean(sampler.chain[temp, :, j:i, :], axis=0)
+    #                acors[temp, :] = emcee.autocorr.exponential_time(x)
+    #            c = np.max(acors, axis=0)
+    #        except emcee.autocorr.AutocorrError:
+    #            logging.info('Failed to calculate exponential autocorrelation')
+    #            c = np.zeros(self.ndim) + np.nan
+    #        except AttributeError:
+    #            logging.info('Unable to calculate exponential autocorrelation')
+    #            c = np.zeros(self.ndim) + np.nan
+    #
+    #        self.convergence_diagnosticx.append(i - self.convergence_n/2.)
+    #        self.convergence_diagnostic.append(list(c))
+    #
+    #        if test:
+    #            return i > n_cut * np.max(c)
+    #
+    #    def _test_GR_convergence(self, i, sampler, test=True, R=1.1):
+    #        if self.convergence_windowed:
+    #            s = sampler.chain[0, :, i-self.convergence_n+1:i+1, :]
+    #        else:
+    #            s = sampler.chain[0, :, :i+1, :]
+    #        N = float(self.convergence_n)
+    #        M = float(self.nwalkers)
+    #        W = np.mean(np.var(s, axis=1), axis=0)
+    #        per_walker_mean = np.mean(s, axis=1)
+    #        mean = np.mean(per_walker_mean, axis=0)
+    #        B = N / (M-1.) * np.sum((per_walker_mean-mean)**2, axis=0)
+    #        Vhat = (N-1)/N * W + (M+1)/(M*N) * B
+    #        c = np.sqrt(Vhat/W)
+    #        self.convergence_diagnostic.append(c)
+    #        self.convergence_diagnosticx.append(i - self.convergence_n/2.)
+    #
+    #        if test and np.max(c) < R:
+    #            return True
+    #        else:
+    #            return False
+    #
+    #    def _test_convergence(self, i, sampler, **kwargs):
+    #        if np.mod(i+1, self.convergence_n) == 0:
+    #            return self._get_convergence_test(i, sampler, **kwargs)
+    #        else:
+    #            return False
+    #
+    #    def _run_sampler_with_conv_test(self, sampler, p0, nprod=0, nburn=0):
+    #        logging.info('Running {} burn-in steps with convergence testing'
+    #                     .format(nburn))
+    #        iterator = tqdm(sampler.sample(p0, iterations=nburn), total=nburn)
+    #        for i, output in enumerate(iterator):
+    #            if self._test_convergence(i, sampler, test=True,
+    #                                      **self.convergence_kwargs):
+    #                logging.info(
+    #                    'Converged at {} before max number {} of steps reached'
+    #                    .format(i, nburn))
+    #                self.convergence_idx = i
+    #                break
+    #        iterator.close()
+    #        logging.info('Running {} production steps'.format(nprod))
+    #        j = nburn
+    #        iterator = tqdm(sampler.sample(output[0], iterations=nprod),
+    #                        total=nprod)
+    #        for result in iterator:
+    #            self._test_convergence(j, sampler, test=False,
+    #                                   **self.convergence_kwargs)
+    #            j += 1
+    #        return sampler
 
     def _run_sampler(self, sampler, p0, nprod=0, nburn=0, window=50):
-        for result in tqdm(sampler.sample(p0, iterations=nburn+nprod),
-                           total=nburn+nprod):
+        for result in tqdm(
+            sampler.sample(p0, iterations=nburn + nprod), total=nburn + nprod
+        ):
             pass
 
-        self.mean_acceptance_fraction = np.mean(
-            sampler.acceptance_fraction, axis=1)
-        logging.info("Mean acceptance fraction: {}"
-                     .format(self.mean_acceptance_fraction))
+        self.mean_acceptance_fraction = np.mean(sampler.acceptance_fraction, axis=1)
+        logging.info(
+            "Mean acceptance fraction: {}".format(self.mean_acceptance_fraction)
+        )
         if self.ntemps > 1:
             self.tswap_acceptance_fraction = sampler.tswap_acceptance_fraction
-            logging.info("Tswap acceptance fraction: {}"
-                         .format(sampler.tswap_acceptance_fraction))
+            logging.info(
+                "Tswap acceptance fraction: {}".format(
+                    sampler.tswap_acceptance_fraction
+                )
+            )
         self.autocorr_time = sampler.get_autocorr_time(window=window)
-        logging.info("Autocorrelation length: {}".format(
-            self.autocorr_time))
+        logging.info("Autocorrelation length: {}".format(self.autocorr_time))
 
         return sampler
 
@@ -447,8 +505,10 @@ class MCMCSearch(core.BaseSearchClass):
         """
         # Todo: add option to time on a machine, and move coefficients to
         # ~/.pyfstat.conf
-        if (type(self.theta_prior['Alpha']) == dict or
-                type(self.theta_prior['Delta']) == dict):
+        if (
+            type(self.theta_prior["Alpha"]) == dict
+            or type(self.theta_prior["Delta"]) == dict
+        ):
             tau0LD = 5.2e-7
             tau0T = 1.5e-8
             tau0S = 1.2e-4
@@ -458,29 +518,31 @@ class MCMCSearch(core.BaseSearchClass):
             tau0T = 1.5e-8
             tau0S = 9.1e-5
             tau0C = 5.5e-6
-        Nsfts = (self.maxStartTime - self.minStartTime) / 1800.
-        if hasattr(self, 'run_setup'):
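+        # Nsfts assumes the canonical 1800 s SFT duration.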
+        Nsfts = (self.maxStartTime - self.minStartTime) / 1800.0
+        if hasattr(self, "run_setup"):
             ts = []
             for row in self.run_setup:
                 nsteps = row[0]
                 nsegs = row[1]
-                numb_evals = np.sum(nsteps)*self.nwalkers*self.ntemps
-                t = (tau0S + tau0LD*Nsfts) * numb_evals
+                numb_evals = np.sum(nsteps) * self.nwalkers * self.ntemps
+                t = (tau0S + tau0LD * Nsfts) * numb_evals
                 if nsegs > 1:
-                    t += (tau0C + tau0T*Nsfts)*nsegs*numb_evals
+                    t += (tau0C + tau0T * Nsfts) * nsegs * numb_evals
                 ts.append(t)
             time = np.sum(ts)
         else:
-            numb_evals = np.sum(self.nsteps)*self.nwalkers*self.ntemps
-            time = (tau0S + tau0LD*Nsfts) * numb_evals
-            if getattr(self, 'nsegs', 1) > 1:
-                time += (tau0C + tau0T*Nsfts)*self.nsegs*numb_evals
-
-        logging.info('Estimated run-time = {} s = {:1.0f}:{:1.0f} m'.format(
-            time, *divmod(time, 60)))
+            numb_evals = np.sum(self.nsteps) * self.nwalkers * self.ntemps
+            time = (tau0S + tau0LD * Nsfts) * numb_evals
+            if getattr(self, "nsegs", 1) > 1:
+                time += (tau0C + tau0T * Nsfts) * self.nsegs * numb_evals
+
+        logging.info(
+            "Estimated run-time = {} s = {:1.0f}:{:1.0f} m".format(
+                time, *divmod(time, 60)
+            )
+        )
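+        # Worked example (a sketch, not a measurement): a fully coherent sky
+        # search over 100 days of 1800 s SFTs has Nsfts = 4800; with
+        # nwalkers = 100, ntemps = 2 and nsteps = [300, 300], numb_evals =
+        # 600 * 100 * 2 = 1.2e5, so time ~ (1.2e-4 + 5.2e-7 * 4800) * 1.2e5
+        # ~ 314 s, i.e. roughly five minutes.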
 
-    def run(self, proposal_scale_factor=2, create_plots=True, window=50,
-            **kwargs):
+    def run(self, proposal_scale_factor=2, create_plots=True, window=50, **kwargs):
         """ Run the MCMC simulatation
 
         Parameters
@@ -508,24 +570,29 @@ class MCMCSearch(core.BaseSearchClass):
 
         self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
         if self.old_data_is_okay_to_use is True:
-            logging.warning('Using saved data from {}'.format(
-                self.pickle_path))
+            logging.warning("Using saved data from {}".format(self.pickle_path))
             d = self.get_saved_data_dictionary()
-            self.samples = d['samples']
-            self.lnprobs = d['lnprobs']
-            self.lnlikes = d['lnlikes']
-            self.all_lnlikelihood = d['all_lnlikelihood']
-            self.chain = d['chain']
+            self.samples = d["samples"]
+            self.lnprobs = d["lnprobs"]
+            self.lnlikes = d["lnlikes"]
+            self.all_lnlikelihood = d["all_lnlikelihood"]
+            self.chain = d["chain"]
             return
 
         self._initiate_search_object()
         self._estimate_run_time()
 
         sampler = PTSampler(
-            ntemps=self.ntemps, nwalkers=self.nwalkers, dim=self.ndim,
-            logl=self.logl, logp=self.logp,
+            ntemps=self.ntemps,
+            nwalkers=self.nwalkers,
+            dim=self.ndim,
+            logl=self.logl,
+            logp=self.logp,
             logpargs=(self.theta_prior, self.theta_keys, self.search),
-            loglargs=(self.search,), betas=self.betas, a=proposal_scale_factor)
+            loglargs=(self.search,),
+            betas=self.betas,
+            a=proposal_scale_factor,
+        )
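+        # Parallel-tempered ensemble sampler: ntemps rungs of nwalkers walkers
+        # each; betas is the inverse-temperature ladder (None leaves the
+        # choice to the sampler) and `a` is the stretch-move proposal scale.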
 
         p0 = self._generate_initial_p0()
         p0 = self._apply_corrections_to_p0(p0)
@@ -534,15 +601,16 @@ class MCMCSearch(core.BaseSearchClass):
         # Run initialisation steps if required
         ninit_steps = len(self.nsteps) - 2
         for j, n in enumerate(self.nsteps[:-2]):
-            logging.info('Running {}/{} initialisation with {} steps'.format(
-                j, ninit_steps, n))
+            logging.info(
+                "Running {}/{} initialisation with {} steps".format(j, ninit_steps, n)
+            )
             sampler = self._run_sampler(sampler, p0, nburn=n, window=window)
             if create_plots:
-                fig, axes = self._plot_walkers(sampler,
-                                               **kwargs)
+                fig, axes = self._plot_walkers(sampler, **kwargs)
                 fig.tight_layout()
-                fig.savefig('{}/{}_init_{}_walkers.png'.format(
-                    self.outdir, self.label, j))
+                fig.savefig(
+                    "{}/{}_init_{}_walkers.png".format(self.outdir, self.label, j)
+                )
 
             p0 = self._get_new_p0(sampler)
             p0 = self._apply_corrections_to_p0(p0)
@@ -554,18 +622,16 @@ class MCMCSearch(core.BaseSearchClass):
         else:
             nburn = 0
         nprod = self.nsteps[-1]
-        logging.info('Running final burn and prod with {} steps'.format(
-            nburn+nprod))
+        logging.info("Running final burn and prod with {} steps".format(nburn + nprod))
         sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod)
 
         if create_plots:
             try:
                 fig, axes = self._plot_walkers(sampler, nprod=nprod, **kwargs)
                 fig.tight_layout()
-                fig.savefig('{}/{}_walkers.png'.format(self.outdir, self.label))
+                fig.savefig("{}/{}_walkers.png".format(self.outdir, self.label))
             except RuntimeError as e:
-                logging.warning("Failed to save walker plots due to Erro {}"
-                                .format(e))
+                logging.warning("Failed to save walker plots due to Erro {}".format(e))
 
         samples = sampler.chain[0, :, nburn:, :].reshape((-1, self.ndim))
         lnprobs = sampler.logprobability[0, :, nburn:].reshape((-1))
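+        # Index 0 keeps only the zero-temperature (beta = 1) chains; the
+        # first nburn steps are discarded as burn-in.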
@@ -576,8 +642,9 @@ class MCMCSearch(core.BaseSearchClass):
         self.lnprobs = lnprobs
         self.lnlikes = lnlikes
         self.all_lnlikelihood = all_lnlikelihood
-        self._save_data(sampler, samples, lnprobs, lnlikes, all_lnlikelihood,
-                        sampler.chain)
+        self._save_data(
+            sampler, samples, lnprobs, lnlikes, all_lnlikelihood, sampler.chain
+        )
         return sampler
 
     def _get_rescale_multiplier_for_key(self, key):
@@ -590,15 +657,15 @@ class MCMCSearch(core.BaseSearchClass):
         if key not in self.transform_dictionary:
             return 1
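+        # A transform entry may be a number or the name of a class attribute;
+        # e.g. (an assumed illustration) transform_dictionary["F0"] =
+        # {"multiplier": 1e3} would display F0 in mHz.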
 
-        if 'multiplier' in self.transform_dictionary[key]:
-            val = self.transform_dictionary[key]['multiplier']
+        if "multiplier" in self.transform_dictionary[key]:
+            val = self.transform_dictionary[key]["multiplier"]
             if type(val) == str:
                 if hasattr(self, val):
                     multiplier = getattr(
-                        self, self.transform_dictionary[key]['multiplier'])
+                        self, self.transform_dictionary[key]["multiplier"]
+                    )
                 else:
-                    raise ValueError(
-                        "multiplier {} not a class attribute".format(val))
+                    raise ValueError("multiplier {} not a class attribute".format(val))
             else:
                 multiplier = val
         else:
@@ -615,15 +682,15 @@ class MCMCSearch(core.BaseSearchClass):
         if key not in self.transform_dictionary:
             return 0
 
-        if 'subtractor' in self.transform_dictionary[key]:
-            val = self.transform_dictionary[key]['subtractor']
+        if "subtractor" in self.transform_dictionary[key]:
+            val = self.transform_dictionary[key]["subtractor"]
             if type(val) == str:
                 if hasattr(self, val):
                     subtractor = getattr(
-                        self, self.transform_dictionary[key]['subtractor'])
+                        self, self.transform_dictionary[key]["subtractor"]
+                    )
                 else:
-                    raise ValueError(
-                        "subtractor {} not a class attribute".format(val))
+                    raise ValueError("subtractor {} not a class attribute".format(val))
             else:
                 subtractor = val
         else:
@@ -651,27 +718,36 @@ class MCMCSearch(core.BaseSearchClass):
         for key in self.theta_keys:
             label = None
             s = self.symbol_dictionary[key]
-            s.replace('_{glitch}', r'_\textrm{glitch}')
+            s = s.replace("_{glitch}", r"_\textrm{glitch}")
             u = self.unit_dictionary[key]
             if key in self.transform_dictionary:
-                if 'symbol' in self.transform_dictionary[key]:
-                    s = self.transform_dictionary[key]['symbol']
-                if 'label' in self.transform_dictionary[key]:
-                    label = self.transform_dictionary[key]['label']
-                if 'unit' in self.transform_dictionary[key]:
-                    u = self.transform_dictionary[key]['unit']
+                if "symbol" in self.transform_dictionary[key]:
+                    s = self.transform_dictionary[key]["symbol"]
+                if "label" in self.transform_dictionary[key]:
+                    label = self.transform_dictionary[key]["label"]
+                if "unit" in self.transform_dictionary[key]:
+                    u = self.transform_dictionary[key]["unit"]
             if label is None:
                 if newline_units:
-                    label = '{} \n [{}]'.format(s, u)
+                    label = "{} \n [{}]".format(s, u)
                 else:
-                    label = '{} [{}]'.format(s, u)
+                    label = "{} [{}]".format(s, u)
             labels.append(label)
         return labels
 
-    def plot_corner(self, figsize=(7, 7), add_prior=False, nstds=None,
-                    label_offset=0.4, dpi=300, rc_context={},
-                    tglitch_ratio=False, fig_and_axes=None, save_fig=True,
-                    **kwargs):
+    def plot_corner(
+        self,
+        figsize=(7, 7),
+        add_prior=False,
+        nstds=None,
+        label_offset=0.4,
+        dpi=300,
+        rc_context={},
+        tglitch_ratio=False,
+        fig_and_axes=None,
+        save_fig=True,
+        **kwargs
+    ):
         """ Generate a corner plot of the posterior
 
         Using the `corner` package (https://pypi.python.org/pypi/corner/),
@@ -713,9 +789,9 @@ class MCMCSearch(core.BaseSearchClass):
 
         """
 
-        if 'truths' in kwargs and len(kwargs['truths']) != self.ndim:
-            logging.warning('len(Truths) != ndim, Truths will be ignored')
-            kwargs['truths'] = None
+        if "truths" in kwargs and len(kwargs["truths"]) != self.ndim:
+            logging.warning("len(Truths) != ndim, Truths will be ignored")
+            kwargs["truths"] = None
 
         if self.ndim < 2:
             with plt.rc_context(rc_context):
@@ -723,17 +799,15 @@ class MCMCSearch(core.BaseSearchClass):
                     fig, ax = plt.subplots(figsize=figsize)
                 else:
                     fig, ax = fig_and_axes
-                ax.hist(self.samples, bins=50, histtype='stepfilled')
+                ax.hist(self.samples, bins=50, histtype="stepfilled")
                 ax.set_xlabel(self.theta_symbols[0])
 
-            fig.savefig('{}/{}_corner.png'.format(
-                self.outdir, self.label), dpi=dpi)
+            fig.savefig("{}/{}_corner.png".format(self.outdir, self.label), dpi=dpi)
             return
 
         with plt.rc_context(rc_context):
             if fig_and_axes is None:
-                fig, axes = plt.subplots(self.ndim, self.ndim,
-                                         figsize=figsize)
+                fig, axes = plt.subplots(self.ndim, self.ndim, figsize=figsize)
             else:
                 fig, axes = fig_and_axes
 
@@ -744,41 +818,42 @@ class MCMCSearch(core.BaseSearchClass):
 
             if tglitch_ratio:
                 for j, k in enumerate(self.theta_keys):
-                    if k == 'tglitch':
+                    if k == "tglitch":
                         s = samples_plt[:, j]
-                        samples_plt[:, j] = (
-                            s - self.minStartTime)/(
-                                self.maxStartTime - self.minStartTime)
-                        labels[j] = r'$R_{\textrm{glitch}}$'
+                        samples_plt[:, j] = (s - self.minStartTime) / (
+                            self.maxStartTime - self.minStartTime
+                        )
+                        labels[j] = r"$R_{\textrm{glitch}}$"
 
-            if type(nstds) is int and 'range' not in kwargs:
+            if type(nstds) is int and "range" not in kwargs:
                 _range = []
                 for j, s in enumerate(samples_plt.T):
                     median = np.median(s)
                     std = np.std(s)
-                    _range.append((median - nstds*std, median + nstds*std))
-            elif 'range' in kwargs:
-                _range = kwargs.pop('range')
+                    _range.append((median - nstds * std, median + nstds * std))
+            elif "range" in kwargs:
+                _range = kwargs.pop("range")
             else:
                 _range = None
 
-            hist_kwargs = kwargs.pop('hist_kwargs', dict())
-            if 'normed' not in hist_kwargs:
-                hist_kwargs['normed'] = True
-
-            fig_triangle = corner.corner(samples_plt,
-                                         labels=labels,
-                                         fig=fig,
-                                         bins=50,
-                                         max_n_ticks=4,
-                                         plot_contours=True,
-                                         plot_datapoints=True,
-                                         #label_kwargs={'fontsize': 12},
-                                         data_kwargs={'alpha': 0.1,
-                                                      'ms': 0.5},
-                                         range=_range,
-                                         hist_kwargs=hist_kwargs,
-                                         **kwargs)
+            hist_kwargs = kwargs.pop("hist_kwargs", dict())
+            if "normed" not in hist_kwargs:
+                hist_kwargs["normed"] = True
+
+            fig_triangle = corner.corner(
+                samples_plt,
+                labels=labels,
+                fig=fig,
+                bins=50,
+                max_n_ticks=4,
+                plot_contours=True,
+                plot_datapoints=True,
+                # label_kwargs={'fontsize': 12},
+                data_kwargs={"alpha": 0.1, "ms": 0.5},
+                range=_range,
+                hist_kwargs=hist_kwargs,
+                **kwargs
+            )
 
             axes_list = fig_triangle.get_axes()
             axes = np.array(axes_list).reshape(self.ndim, self.ndim)
@@ -792,11 +867,11 @@ class MCMCSearch(core.BaseSearchClass):
                 ax.set_rasterization_zorder(-10)
 
                 for tick in ax.xaxis.get_major_ticks():
-                    #tick.label.set_fontsize(8)
-                    tick.label.set_rotation('horizontal')
+                    # tick.label.set_fontsize(8)
+                    tick.label.set_rotation("horizontal")
                 for tick in ax.yaxis.get_major_ticks():
-                    #tick.label.set_fontsize(8)
-                    tick.label.set_rotation('vertical')
+                    # tick.label.set_fontsize(8)
+                    tick.label.set_rotation("vertical")
 
             plt.tight_layout(h_pad=0.0, w_pad=0.0)
             fig.subplots_adjust(hspace=0.05, wspace=0.05)
@@ -805,13 +880,13 @@ class MCMCSearch(core.BaseSearchClass):
                 self._add_prior_to_corner(axes, self.samples, add_prior)
 
             if save_fig:
-                fig_triangle.savefig('{}/{}_corner.png'.format(
-                    self.outdir, self.label), dpi=dpi)
+                fig_triangle.savefig(
+                    "{}/{}_corner.png".format(self.outdir, self.label), dpi=dpi
+                )
             else:
                 return fig, axes
 
-    def plot_chainconsumer(
-            self, save_fig=True, label_offset=0.25, dpi=300, **kwargs):
+    def plot_chainconsumer(self, save_fig=True, label_offset=0.25, dpi=300, **kwargs):
         """ Generate a corner plot of the posterior using chainconsumer
 
         Parameters
@@ -823,9 +898,9 @@ class MCMCSearch(core.BaseSearchClass):
 
         """
 
-        if 'truths' in kwargs and len(kwargs['truths']) != self.ndim:
-            logging.warning('len(Truths) != ndim, Truths will be ignored')
-            kwargs['truths'] = None
+        if "truths" in kwargs and len(kwargs["truths"]) != self.ndim:
+            logging.warning("len(Truths) != ndim, Truths will be ignored")
+            kwargs["truths"] = None
 
         samples_plt = copy.copy(self.samples)
         labels = self._get_labels(newline_units=True)
@@ -833,6 +908,7 @@ class MCMCSearch(core.BaseSearchClass):
         samples_plt = self._scale_samples(samples_plt, self.theta_keys)
 
         import chainconsumer
+
         c = chainconsumer.ChainConsumer()
         c.add_chain(samples_plt, parameters=labels)
         c.configure(smooth=0, summary=False, sigma2d=True)
@@ -849,10 +925,10 @@ class MCMCSearch(core.BaseSearchClass):
             ax.set_rasterized(True)
             ax.set_rasterization_zorder(-10)
 
-            #for tick in ax.xaxis.get_major_ticks():
+            # for tick in ax.xaxis.get_major_ticks():
             #    #tick.label.set_fontsize(8)
             #    tick.label.set_rotation('horizontal')
-            #for tick in ax.yaxis.get_major_ticks():
+            # for tick in ax.yaxis.get_major_ticks():
             #    #tick.label.set_fontsize(8)
             #    tick.label.set_rotation('vertical')
 
@@ -860,8 +936,7 @@ class MCMCSearch(core.BaseSearchClass):
             fig.subplots_adjust(hspace=0.05, wspace=0.05)
 
         if save_fig:
-            fig.savefig('{}/{}_corner.png'.format(
-                self.outdir, self.label), dpi=dpi)
+            fig.savefig("{}/{}_corner.png".format(self.outdir, self.label), dpi=dpi)
         else:
             return fig
 
@@ -870,20 +945,23 @@ class MCMCSearch(core.BaseSearchClass):
             ax = axes[i][i]
             s = samples[:, i]
             lnprior = self._generic_lnprior(**self.theta_prior[key])
-            if add_prior == 'full' and self.theta_prior[key]['type'] == 'unif':
-                lower = self.theta_prior[key]['lower']
-                upper = self.theta_prior[key]['upper']
-                r = upper-lower
-                xlim = [lower-0.05*r, upper+0.05*r]
+            if add_prior == "full" and self.theta_prior[key]["type"] == "unif":
+                lower = self.theta_prior[key]["lower"]
+                upper = self.theta_prior[key]["upper"]
+                r = upper - lower
+                xlim = [lower - 0.05 * r, upper + 0.05 * r]
                 x = np.linspace(xlim[0], xlim[1], 1000)
             else:
                 xlim = ax.get_xlim()
                 x = np.linspace(s.min(), s.max(), 1000)
             multiplier = self._get_rescale_multiplier_for_key(key)
             subtractor = self._get_rescale_subtractor_for_key(key)
-            ax.plot((x-subtractor)*multiplier,
-                    [np.exp(lnprior(xi)) for xi in x], '-C3',
-                    label='prior')
+            ax.plot(
+                (x - subtractor) * multiplier,
+                [np.exp(lnprior(xi)) for xi in x],
+                "-C3",
+                label="prior",
+            )
 
             for j in range(i, self.ndim):
                 axes[j][i].set_xlim(xlim[0], xlim[1])
@@ -892,51 +970,52 @@ class MCMCSearch(core.BaseSearchClass):
 
     def plot_prior_posterior(self, normal_stds=2):
         """ Plot the posterior in the context of the prior """
-        fig, axes = plt.subplots(nrows=self.ndim, figsize=(8, 4*self.ndim))
+        fig, axes = plt.subplots(nrows=self.ndim, figsize=(8, 4 * self.ndim))
         N = 1000
         from scipy.stats import gaussian_kde
 
         for i, (ax, key) in enumerate(zip(axes, self.theta_keys)):
             prior_dict = self.theta_prior[key]
             prior_func = self._generic_lnprior(**prior_dict)
-            if prior_dict['type'] == 'unif':
-                x = np.linspace(prior_dict['lower'], prior_dict['upper'], N)
+            if prior_dict["type"] == "unif":
+                x = np.linspace(prior_dict["lower"], prior_dict["upper"], N)
                 prior = prior_func(x)
                 prior[0] = 0
                 prior[-1] = 0
-            elif prior_dict['type'] == 'log10unif':
-                upper = prior_dict['log10upper']
-                lower = prior_dict['log10lower']
+            elif prior_dict["type"] == "log10unif":
+                upper = prior_dict["log10upper"]
+                lower = prior_dict["log10lower"]
                 x = np.linspace(lower, upper, N)
                 prior = [prior_func(xi) for xi in x]
-            elif prior_dict['type'] == 'norm':
-                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
-                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
+            elif prior_dict["type"] == "norm":
+                lower = prior_dict["loc"] - normal_stds * prior_dict["scale"]
+                upper = prior_dict["loc"] + normal_stds * prior_dict["scale"]
                 x = np.linspace(lower, upper, N)
                 prior = prior_func(x)
-            elif prior_dict['type'] == 'halfnorm':
-                lower = prior_dict['loc']
-                upper = prior_dict['loc'] + normal_stds * prior_dict['scale']
+            elif prior_dict["type"] == "halfnorm":
+                lower = prior_dict["loc"]
+                upper = prior_dict["loc"] + normal_stds * prior_dict["scale"]
                 x = np.linspace(lower, upper, N)
                 prior = [prior_func(xi) for xi in x]
-            elif prior_dict['type'] == 'neghalfnorm':
-                upper = prior_dict['loc']
-                lower = prior_dict['loc'] - normal_stds * prior_dict['scale']
+            elif prior_dict["type"] == "neghalfnorm":
+                upper = prior_dict["loc"]
+                lower = prior_dict["loc"] - normal_stds * prior_dict["scale"]
                 x = np.linspace(lower, upper, N)
                 prior = [prior_func(xi) for xi in x]
             else:
-                raise ValueError('Not implemented for prior type {}'.format(
-                    prior_dict['type']))
-            priorln = ax.plot(x, prior, 'C3', label='prior')
+                raise ValueError(
+                    "Not implemented for prior type {}".format(prior_dict["type"])
+                )
+            priorln = ax.plot(x, prior, "C3", label="prior")
             ax.set_xlabel(self.theta_symbols[i])
 
             s = self.samples[:, i]
-            while len(s) > 10**4:
+            while len(s) > 10 ** 4:
                 # random downsample to avoid slow calculation of kde
-                s = np.random.choice(s, size=int(len(s)/2.))
+                s = np.random.choice(s, size=int(len(s) / 2.0))
             kde = gaussian_kde(s)
             ax2 = ax.twinx()
-            postln = ax2.plot(x, kde.pdf(x), 'k', label='posterior')
+            postln = ax2.plot(x, kde.pdf(x), "k", label="posterior")
             ax2.set_yticklabels([])
             ax.set_yticklabels([])
 
@@ -944,8 +1023,7 @@ class MCMCSearch(core.BaseSearchClass):
         labs = [l.get_label() for l in lns]
         axes[0].legend(lns, labs, loc=1, framealpha=0.8)
 
-        fig.savefig('{}/{}_prior_posterior.png'.format(
-            self.outdir, self.label))
+        fig.savefig("{}/{}_prior_posterior.png".format(self.outdir, self.label))
 
     def plot_cumulative_max(self, **kwargs):
         """ Plot the cumulative twoF for the maximum posterior estimate
@@ -957,23 +1035,42 @@ class MCMCSearch(core.BaseSearchClass):
             if key not in d:
                 d[key] = val
 
-        if 'add_pfs' in kwargs:
+        if "add_pfs" in kwargs:
             self.generate_loudest()
 
-        if hasattr(self, 'search') is False:
+        if hasattr(self, "search") is False:
             self._initiate_search_object()
         if self.binary is False:
             self.search.plot_twoF_cumulative(
-                self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
-                Alpha=d['Alpha'], Delta=d['Delta'],
-                tstart=self.minStartTime, tend=self.maxStartTime,
-                **kwargs)
+                self.label,
+                self.outdir,
+                F0=d["F0"],
+                F1=d["F1"],
+                F2=d["F2"],
+                Alpha=d["Alpha"],
+                Delta=d["Delta"],
+                tstart=self.minStartTime,
+                tend=self.maxStartTime,
+                **kwargs
+            )
         else:
             self.search.plot_twoF_cumulative(
-                self.label, self.outdir, F0=d['F0'], F1=d['F1'], F2=d['F2'],
-                Alpha=d['Alpha'], Delta=d['Delta'], asini=d['asini'],
-                period=d['period'], ecc=d['ecc'], argp=d['argp'], tp=d['argp'],
-                tstart=self.minStartTime, tend=self.maxStartTime, **kwargs)
+                self.label,
+                self.outdir,
+                F0=d["F0"],
+                F1=d["F1"],
+                F2=d["F2"],
+                Alpha=d["Alpha"],
+                Delta=d["Delta"],
+                asini=d["asini"],
+                period=d["period"],
+                ecc=d["ecc"],
+                argp=d["argp"],
+                tp=d["argp"],
+                tstart=self.minStartTime,
+                tend=self.maxStartTime,
+                **kwargs
+            )
 
     def _generic_lnprior(self, **kwargs):
         """ Return a lambda function of the pdf
@@ -990,13 +1087,13 @@ class MCMCSearch(core.BaseSearchClass):
             below = x > a
             if type(above) is not np.ndarray:
                 if above and below:
-                    return -np.log(b-a)
+                    return -np.log(b - a)
                 else:
                     return -np.inf
             else:
                 idxs = np.array([all(tup) for tup in zip(above, below)])
                 p = np.zeros(len(x)) - np.inf
-                p[idxs] = -np.log(b-a)
+                p[idxs] = -np.log(b - a)
                 return p
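+        # e.g. log_of_unif(x, a, b) gives -log(b - a) for a < x < b and -inf
+        # otherwise; the strict inequalities above exclude the endpoints.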
 
         def log_of_log10unif(x, log10lower, log10upper):
@@ -1005,80 +1102,100 @@ class MCMCSearch(core.BaseSearchClass):
             below = log10x > log10lower
             if type(above) is not np.ndarray:
                 if above and below:
-                    return -np.log(x*np.log(10)*(log10upper-log10lower))
+                    return -np.log(x * np.log(10) * (log10upper - log10lower))
                 else:
                     return -np.inf
             else:
                 idxs = np.array([all(tup) for tup in zip(above, below)])
                 p = np.zeros(len(x)) - np.inf
-                p[idxs] = -np.log(x*np.log(10)*(log10upper-log10lower))
+                p[idxs] = -np.log(x * np.log(10) * (log10upper - log10lower))
                 return p
 
         def log_of_halfnorm(x, loc, scale):
             if x < loc:
                 return -np.inf
             else:
-                return -0.5*((x-loc)**2/scale**2+np.log(0.5*np.pi*scale**2))
+                return -0.5 * (
+                    (x - loc) ** 2 / scale ** 2 + np.log(0.5 * np.pi * scale ** 2)
+                )
 
         def cauchy(x, x0, gamma):
-            return 1.0/(np.pi*gamma*(1+((x-x0)/gamma)**2))
+            return 1.0 / (np.pi * gamma * (1 + ((x - x0) / gamma) ** 2))
 
         def exp(x, x0, gamma):
             if x > x0:
-                return np.log(gamma) - gamma*(x - x0)
+                return np.log(gamma) - gamma * (x - x0)
             else:
                 return -np.inf
 
-        if kwargs['type'] == 'unif':
-            return lambda x: log_of_unif(x, kwargs['lower'], kwargs['upper'])
-        if kwargs['type'] == 'log10unif':
+        if kwargs["type"] == "unif":
+            return lambda x: log_of_unif(x, kwargs["lower"], kwargs["upper"])
+        if kwargs["type"] == "log10unif":
             return lambda x: log_of_log10unif(
-                x, kwargs['log10lower'], kwargs['log10upper'])
-        elif kwargs['type'] == 'halfnorm':
-            return lambda x: log_of_halfnorm(x, kwargs['loc'], kwargs['scale'])
-        elif kwargs['type'] == 'neghalfnorm':
-            return lambda x: log_of_halfnorm(
-                -x, kwargs['loc'], kwargs['scale'])
-        elif kwargs['type'] == 'norm':
-            return lambda x: -0.5*((x - kwargs['loc'])**2/kwargs['scale']**2
-                                   + np.log(2*np.pi*kwargs['scale']**2))
+                x, kwargs["log10lower"], kwargs["log10upper"]
+            )
+        elif kwargs["type"] == "halfnorm":
+            return lambda x: log_of_halfnorm(x, kwargs["loc"], kwargs["scale"])
+        elif kwargs["type"] == "neghalfnorm":
+            return lambda x: log_of_halfnorm(-x, kwargs["loc"], kwargs["scale"])
+        elif kwargs["type"] == "norm":
+            return lambda x: -0.5 * (
+                (x - kwargs["loc"]) ** 2 / kwargs["scale"] ** 2
+                + np.log(2 * np.pi * kwargs["scale"] ** 2)
+            )
         else:
             logging.info("kwargs:", kwargs)
             raise ValueError("Print unrecognise distribution")
 
     def _generate_rv(self, **kwargs):
-        dist_type = kwargs.pop('type')
+        dist_type = kwargs.pop("type")
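+        # Draw one sample from the prior specification, e.g. (sketch)
+        # self._generate_rv(type="unif", lower=0.0, upper=1.0).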
         if dist_type == "unif":
-            return np.random.uniform(low=kwargs['lower'], high=kwargs['upper'])
+            return np.random.uniform(low=kwargs["lower"], high=kwargs["upper"])
         if dist_type == "log10unif":
-            return 10**(np.random.uniform(low=kwargs['log10lower'],
-                                          high=kwargs['log10upper']))
+            return 10 ** (
+                np.random.uniform(low=kwargs["log10lower"], high=kwargs["log10upper"])
+            )
         if dist_type == "norm":
-            return np.random.normal(loc=kwargs['loc'], scale=kwargs['scale'])
+            return np.random.normal(loc=kwargs["loc"], scale=kwargs["scale"])
         if dist_type == "halfnorm":
-            return np.abs(np.random.normal(loc=kwargs['loc'],
-                                           scale=kwargs['scale']))
+            return np.abs(np.random.normal(loc=kwargs["loc"], scale=kwargs["scale"]))
         if dist_type == "neghalfnorm":
-            return -1 * np.abs(np.random.normal(loc=kwargs['loc'],
-                                                scale=kwargs['scale']))
+            return -1 * np.abs(
+                np.random.normal(loc=kwargs["loc"], scale=kwargs["scale"])
+            )
         if dist_type == "lognorm":
-            return np.random.lognormal(
-                mean=kwargs['loc'], sigma=kwargs['scale'])
+            return np.random.lognormal(mean=kwargs["loc"], sigma=kwargs["scale"])
         else:
             raise ValueError("dist_type {} unknown".format(dist_type))
 
-    def _plot_walkers(self, sampler, symbols=None, alpha=0.8, color="k",
-                      temp=0, lw=0.1, nprod=0, add_det_stat_burnin=False,
-                      fig=None, axes=None, xoffset=0, plot_det_stat=False,
-                      context='ggplot', labelpad=5):
+    def _plot_walkers(
+        self,
+        sampler,
+        symbols=None,
+        alpha=0.8,
+        color="k",
+        temp=0,
+        lw=0.1,
+        nprod=0,
+        add_det_stat_burnin=False,
+        fig=None,
+        axes=None,
+        xoffset=0,
+        plot_det_stat=False,
+        context="ggplot",
+        labelpad=5,
+    ):
         """ Plot all the chains from a sampler """
 
         if symbols is None:
             symbols = self._get_labels()
         if context not in plt.style.available:
-            raise ValueError((
-                'The requested context {} is not available; please select a'
-                ' context from `plt.style.available`').format(context))
+            raise ValueError(
+                (
+                    "The requested context {} is not available; please select a"
+                    " context from `plt.style.available`"
+                ).format(context)
+            )
 
         if np.ndim(axes) > 1:
             axes = axes.flatten()
@@ -1092,11 +1209,14 @@ class MCMCSearch(core.BaseSearchClass):
             if temp < ntemps:
                 logging.info("Plotting temperature {} chains".format(temp))
             else:
-                raise ValueError(("Requested temperature {} outside of"
-                                  "available range").format(temp))
+                raise ValueError(
+                    ("Requested temperature {} outside of" "available range").format(
+                        temp
+                    )
+                )
             chain = sampler.chain[temp, :, :, :].copy()
 
-        samples = chain.reshape((nwalkers*nsteps, ndim))
+        samples = chain.reshape((nwalkers * nsteps, ndim))
         samples = self._scale_samples(samples, self.theta_keys)
         chain = chain.reshape((nwalkers, nsteps, ndim))
 
@@ -1105,110 +1225,128 @@ class MCMCSearch(core.BaseSearchClass):
         else:
             extra_subplots = 0
         with plt.style.context((context)):
-            plt.rcParams['text.usetex'] = True
+            plt.rcParams["text.usetex"] = True
             if fig is None and axes is None:
-                fig = plt.figure(figsize=(4, 3.0*ndim))
-                ax = fig.add_subplot(ndim+extra_subplots, 1, 1)
-                axes = [ax] + [fig.add_subplot(ndim+extra_subplots, 1, i)
-                               for i in range(2, ndim+1)]
+                fig = plt.figure(figsize=(4, 3.0 * ndim))
+                ax = fig.add_subplot(ndim + extra_subplots, 1, 1)
+                axes = [ax] + [
+                    fig.add_subplot(ndim + extra_subplots, 1, i)
+                    for i in range(2, ndim + 1)
+                ]
 
             idxs = np.arange(chain.shape[1])
             burnin_idx = chain.shape[1] - nprod
-            #if hasattr(self, 'convergence_idx'):
+            # if hasattr(self, 'convergence_idx'):
             #    last_idx = self.convergence_idx
-            #else:
+            # else:
             last_idx = burnin_idx
             if ndim > 1:
                 for i in range(ndim):
-                    axes[i].ticklabel_format(useOffset=False, axis='y')
+                    axes[i].ticklabel_format(useOffset=False, axis="y")
                     cs = chain[:, :, i].T
                     if burnin_idx > 0:
-                        axes[i].plot(xoffset+idxs[:last_idx+1],
-                                     cs[:last_idx+1],
-                                     color="C3", alpha=alpha,
-                                     lw=lw)
-                        axes[i].axvline(xoffset+last_idx,
-                                        color='k', ls='--', lw=0.5)
-                    axes[i].plot(xoffset+idxs[burnin_idx:],
-                                 cs[burnin_idx:],
-                                 color="k", alpha=alpha, lw=lw)
-
-                    axes[i].set_xlim(0, xoffset+idxs[-1])
+                        axes[i].plot(
+                            xoffset + idxs[: last_idx + 1],
+                            cs[: last_idx + 1],
+                            color="C3",
+                            alpha=alpha,
+                            lw=lw,
+                        )
+                        axes[i].axvline(xoffset + last_idx, color="k", ls="--", lw=0.5)
+                    axes[i].plot(
+                        xoffset + idxs[burnin_idx:],
+                        cs[burnin_idx:],
+                        color="k",
+                        alpha=alpha,
+                        lw=lw,
+                    )
+
+                    axes[i].set_xlim(0, xoffset + idxs[-1])
                     if symbols:
                         axes[i].set_ylabel(symbols[i], labelpad=labelpad)
-                        #if subtractions[i] == 0:
+                        # if subtractions[i] == 0:
                         #    axes[i].set_ylabel(symbols[i], labelpad=labelpad)
-                        #else:
+                        # else:
                         #    axes[i].set_ylabel(
                         #        symbols[i]+'$-$'+symbols[i]+'$^\mathrm{s}$',
                         #        labelpad=labelpad)
 
-#                    if hasattr(self, 'convergence_diagnostic'):
-#                        ax = axes[i].twinx()
-#                        axes[i].set_zorder(ax.get_zorder()+1)
-#                        axes[i].patch.set_visible(False)
-#                        c_x = np.array(self.convergence_diagnosticx)
-#                        c_y = np.array(self.convergence_diagnostic)
-#                        break_idx = np.argmin(np.abs(c_x - burnin_idx))
-#                        ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-C0',
-#                                zorder=-10)
-#                        ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0',
-#                                zorder=-10)
-#                        if self.convergence_test_type == 'autocorr':
-#                            ax.set_ylabel(r'$\tau_\mathrm{exp}$')
-#                        elif self.convergence_test_type == 'GR':
-#                            ax.set_ylabel('PSRF')
-#                        ax.ticklabel_format(useOffset=False)
+            #                    if hasattr(self, 'convergence_diagnostic'):
+            #                        ax = axes[i].twinx()
+            #                        axes[i].set_zorder(ax.get_zorder()+1)
+            #                        axes[i].patch.set_visible(False)
+            #                        c_x = np.array(self.convergence_diagnosticx)
+            #                        c_y = np.array(self.convergence_diagnostic)
+            #                        break_idx = np.argmin(np.abs(c_x - burnin_idx))
+            #                        ax.plot(c_x[:break_idx], c_y[:break_idx, i], '-C0',
+            #                                zorder=-10)
+            #                        ax.plot(c_x[break_idx:], c_y[break_idx:, i], '-C0',
+            #                                zorder=-10)
+            #                        if self.convergence_test_type == 'autocorr':
+            #                            ax.set_ylabel(r'$\tau_\mathrm{exp}$')
+            #                        elif self.convergence_test_type == 'GR':
+            #                            ax.set_ylabel('PSRF')
+            #                        ax.ticklabel_format(useOffset=False)
             else:
-                axes[0].ticklabel_format(useOffset=False, axis='y')
+                axes[0].ticklabel_format(useOffset=False, axis="y")
                 cs = chain[:, :, temp].T
                 if burnin_idx:
-                    axes[0].plot(idxs[:burnin_idx], cs[:burnin_idx],
-                                 color="C3", alpha=alpha, lw=lw)
-                axes[0].plot(idxs[burnin_idx:], cs[burnin_idx:], color="k",
-                             alpha=alpha, lw=lw)
+                    axes[0].plot(
+                        idxs[:burnin_idx],
+                        cs[:burnin_idx],
+                        color="C3",
+                        alpha=alpha,
+                        lw=lw,
+                    )
+                axes[0].plot(
+                    idxs[burnin_idx:], cs[burnin_idx:], color="k", alpha=alpha, lw=lw
+                )
                 if symbols:
                     axes[0].set_ylabel(symbols[0], labelpad=labelpad)
 
-            axes[-1].set_xlabel(r'$\textrm{Number of steps}$', labelpad=0.2)
+            axes[-1].set_xlabel(r"$\textrm{Number of steps}$", labelpad=0.2)
 
             if plot_det_stat:
                 if len(axes) == ndim:
-                    axes.append(fig.add_subplot(ndim+1, 1, ndim+1))
+                    axes.append(fig.add_subplot(ndim + 1, 1, ndim + 1))
 
                 lnl = sampler.loglikelihood[temp, :, :]
                 if burnin_idx and add_det_stat_burnin:
                     burn_in_vals = lnl[:, :burnin_idx].flatten()
                     try:
-                        twoF_burnin = (burn_in_vals[~np.isnan(burn_in_vals)]
-                                       - self.likelihoodcoef)
-                        axes[-1].hist(twoF_burnin, bins=50, histtype='step',
-                                      color='C3')
+                        twoF_burnin = (
+                            burn_in_vals[~np.isnan(burn_in_vals)] - self.likelihoodcoef
+                        )
+                        axes[-1].hist(twoF_burnin, bins=50, histtype="step", color="C3")
                     except ValueError:
-                        logging.info('Det. Stat. hist failed, most likely all '
-                                     'values where the same')
+                        logging.info(
+                            "Det. Stat. hist failed, most likely all "
+                            "values where the same"
+                        )
                         pass
                 else:
                     twoF_burnin = []
                 prod_vals = lnl[:, burnin_idx:].flatten()
                 try:
-                    twoF = prod_vals[~np.isnan(prod_vals)]-self.likelihoodcoef
-                    axes[-1].hist(twoF, bins=50, histtype='step', color='k')
+                    twoF = prod_vals[~np.isnan(prod_vals)] - self.likelihoodcoef
+                    axes[-1].hist(twoF, bins=50, histtype="step", color="k")
                 except ValueError:
-                    logging.info('Det. Stat. hist failed, most likely all '
-                                 'values where the same')
+                    logging.info(
+                        "Det. Stat. hist failed, most likely all "
+                        "values where the same"
+                    )
                     pass
                 if self.BSGL:
-                    axes[-1].set_xlabel(r'$\mathcal{B}_\mathrm{S/GL}$')
+                    axes[-1].set_xlabel(r"$\mathcal{B}_\mathrm{S/GL}$")
                 else:
-                    axes[-1].set_xlabel(r'$\widetilde{2\mathcal{F}}$')
-                axes[-1].set_ylabel(r'$\textrm{Counts}$')
+                    axes[-1].set_xlabel(r"$\widetilde{2\mathcal{F}}$")
+                axes[-1].set_ylabel(r"$\textrm{Counts}$")
                 combined_vals = np.append(twoF_burnin, twoF)
                 if len(combined_vals) > 0:
                     minv = np.min(combined_vals)
                     maxv = np.max(combined_vals)
-                    Range = abs(maxv-minv)
-                    axes[-1].set_xlim(minv-0.1*Range, maxv+0.1*Range)
+                    Range = abs(maxv - minv)
+                    axes[-1].set_xlim(minv - 0.1 * Range, maxv + 0.1 * Range)
 
                 xfmt = matplotlib.ticker.ScalarFormatter()
                 xfmt.set_powerlimits((-4, 4))
@@ -1222,30 +1360,46 @@ class MCMCSearch(core.BaseSearchClass):
 
     def _generate_scattered_p0(self, p):
         """ Generate a set of p0s scattered about p """
-        p0 = [[p + self.scatter_val * p * np.random.randn(self.ndim)
-               for i in range(self.nwalkers)]
-              for j in range(self.ntemps)]
+        p0 = [
+            [
+                p + self.scatter_val * p * np.random.randn(self.ndim)
+                for i in range(self.nwalkers)
+            ]
+            for j in range(self.ntemps)
+        ]
         return p0
 
     def _generate_initial_p0(self):
         """ Generate a set of init vals for the walkers """
 
         if type(self.theta_initial) == dict:
-            logging.info('Generate initial values from initial dictionary')
-            if hasattr(self, 'nglitch') and self.nglitch > 1:
-                raise ValueError('Initial dict not implemented for nglitch>1')
-            p0 = [[[self._generate_rv(**self.theta_initial[key])
-                    for key in self.theta_keys]
-                   for i in range(self.nwalkers)]
-                  for j in range(self.ntemps)]
+            logging.info("Generate initial values from initial dictionary")
+            if hasattr(self, "nglitch") and self.nglitch > 1:
+                raise ValueError("Initial dict not implemented for nglitch>1")
+            p0 = [
+                [
+                    [
+                        self._generate_rv(**self.theta_initial[key])
+                        for key in self.theta_keys
+                    ]
+                    for i in range(self.nwalkers)
+                ]
+                for j in range(self.ntemps)
+            ]
         elif self.theta_initial is None:
-            logging.info('Generate initial values from prior dictionary')
-            p0 = [[[self._generate_rv(**self.theta_prior[key])
-                    for key in self.theta_keys]
-                   for i in range(self.nwalkers)]
-                  for j in range(self.ntemps)]
+            logging.info("Generate initial values from prior dictionary")
+            p0 = [
+                [
+                    [
+                        self._generate_rv(**self.theta_prior[key])
+                        for key in self.theta_keys
+                    ]
+                    for i in range(self.nwalkers)
+                ]
+                for j in range(self.ntemps)
+            ]
         else:
-            raise ValueError('theta_initial not understood')
+            raise ValueError("theta_initial not understood")
 
         return p0
 
@@ -1264,16 +1418,20 @@ class MCMCSearch(core.BaseSearchClass):
         # General warnings about the state of lnp
         if np.any(np.isnan(lnp)):
             logging.warning(
-                "Of {} lnprobs {} are nan".format(
-                    np.shape(lnp), np.sum(np.isnan(lnp))))
+                "Of {} lnprobs {} are nan".format(np.shape(lnp), np.sum(np.isnan(lnp)))
+            )
         if np.any(np.isposinf(lnp)):
             logging.warning(
                 "Of {} lnprobs {} are +np.inf".format(
-                    np.shape(lnp), np.sum(np.isposinf(lnp))))
+                    np.shape(lnp), np.sum(np.isposinf(lnp))
+                )
+            )
         if np.any(np.isneginf(lnp)):
             logging.warning(
                 "Of {} lnprobs {} are -np.inf".format(
-                    np.shape(lnp), np.sum(np.isneginf(lnp))))
+                    np.shape(lnp), np.sum(np.isneginf(lnp))
+                )
+            )
 
         lnp_finite = copy.copy(lnp)
         lnp_finite[np.isinf(lnp)] = np.nan
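+        # With infinities masked to NaN, the best finite-lnp walker can be
+        # located below and used to seed the regenerated p0.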
@@ -1285,34 +1443,44 @@ class MCMCSearch(core.BaseSearchClass):
         twoF = self.logl(p, self.search)
         self.search.BSGL = self.BSGL
 
-        logging.info(('Gen. new p0 from pos {} which had det. stat.={:2.1f},'
-                      ' twoF={:2.1f} and lnp={:2.1f}')
-                     .format(idx[1], lnl[idx], twoF, lnp_finite[idx]))
+        logging.info(
+            (
+                "Gen. new p0 from pos {} which had det. stat.={:2.1f},"
+                " twoF={:2.1f} and lnp={:2.1f}"
+            ).format(idx[1], lnl[idx], twoF, lnp_finite[idx])
+        )
 
         return p0
 
     def _get_data_dictionary_to_save(self):
-        d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
-                 ntemps=self.ntemps, theta_keys=self.theta_keys,
-                 theta_prior=self.theta_prior,
-                 log10beta_min=self.log10beta_min,
-                 BSGL=self.BSGL, minStartTime=self.minStartTime,
-                 maxStartTime=self.maxStartTime)
+        d = dict(
+            nsteps=self.nsteps,
+            nwalkers=self.nwalkers,
+            ntemps=self.ntemps,
+            theta_keys=self.theta_keys,
+            theta_prior=self.theta_prior,
+            log10beta_min=self.log10beta_min,
+            BSGL=self.BSGL,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+        )
         return d
 
-    def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood, 
-                   chain):
+    def _save_data(self, sampler, samples, lnprobs, lnlikes, all_lnlikelihood, chain):
         d = self._get_data_dictionary_to_save()
-        d['samples'] = samples
-        d['lnprobs'] = lnprobs
-        d['lnlikes'] = lnlikes
-        d['chain'] = chain
-        d['all_lnlikelihood'] = all_lnlikelihood
+        d["samples"] = samples
+        d["lnprobs"] = lnprobs
+        d["lnlikes"] = lnlikes
+        d["chain"] = chain
+        d["all_lnlikelihood"] = all_lnlikelihood
 
         if os.path.isfile(self.pickle_path):
-            logging.info('Saving backup of {} as {}.old'.format(
-                self.pickle_path, self.pickle_path))
-            os.rename(self.pickle_path, self.pickle_path+".old")
+            logging.info(
+                "Saving backup of {} as {}.old".format(
+                    self.pickle_path, self.pickle_path
+                )
+            )
+            os.rename(self.pickle_path, self.pickle_path + ".old")
         with open(self.pickle_path, "wb") as File:
             pickle.dump(d, File)
 
@@ -1324,26 +1492,27 @@ class MCMCSearch(core.BaseSearchClass):
 
     def _check_old_data_is_okay_to_use(self):
         if os.path.isfile(self.pickle_path) is False:
-            logging.info('No pickled data found')
+            logging.info("No pickled data found")
             return False
 
         if self.sftfilepattern is not None:
-            oldest_sft = min([os.path.getmtime(f) for f in
-                              self._get_list_of_matching_sfts()])
+            oldest_sft = min(
+                [os.path.getmtime(f) for f in self._get_list_of_matching_sfts()]
+            )
             if os.path.getmtime(self.pickle_path) < oldest_sft:
-                logging.info('Pickled data outdates sft files')
+                logging.info("Pickled data outdates sft files")
                 return False
 
         old_d = self.get_saved_data_dictionary().copy()
         new_d = self._get_data_dictionary_to_save().copy()
 
-        old_d.pop('samples')
-        old_d.pop('lnprobs')
-        old_d.pop('lnlikes')
-        old_d.pop('all_lnlikelihood')
-        old_d.pop('chain')
+        old_d.pop("samples")
+        old_d.pop("lnprobs")
+        old_d.pop("lnlikes")
+        old_d.pop("all_lnlikelihood")
+        old_d.pop("chain")
 
-        for key in 'minStartTime', 'maxStartTime':
+        for key in "minStartTime", "maxStartTime":
             if new_d[key] is None:
                 new_d[key] = old_d[key]
                 setattr(self, key, new_d[key])
@@ -1354,7 +1523,7 @@ class MCMCSearch(core.BaseSearchClass):
                 if new_d[key] != old_d[key]:
                     mod_keys.append((key, old_d[key], new_d[key]))
             else:
-                raise ValueError('Keys {} not in old dictionary'.format(key))
+                raise ValueError("Keys {} not in old dictionary".format(key))
 
         if len(mod_keys) == 0:
             return True
@@ -1363,7 +1532,7 @@ class MCMCSearch(core.BaseSearchClass):
             logging.info("Differences found in following keys:")
             for key in mod_keys:
                 if len(key) == 3:
-                    if np.isscalar(key[1]) or key[0] == 'nsteps':
+                    if np.isscalar(key[1]) or key[0] == "nsteps":
                         logging.info("    {} : {} -> {}".format(*key))
                     else:
                         logging.info("    " + key[0])
@@ -1380,37 +1549,37 @@ class MCMCSearch(core.BaseSearchClass):
 
         """
         if any(np.isposinf(self.lnlikes)):
-            logging.info('lnlike values contain positive infinite values')
+            logging.info("lnlike values contain positive infinite values")
         if any(np.isneginf(self.lnlikes)):
-            logging.info('lnlike values contain negative infinite values')
+            logging.info("lnlike values contain negative infinite values")
         if any(np.isnan(self.lnlikes)):
-            logging.info('lnlike values contain nan')
+            logging.info("lnlike values contain nan")
         idxs = np.isfinite(self.lnlikes)
         jmax = np.nanargmax(self.lnlikes[idxs])
         maxlogl = self.lnlikes[jmax]
         d = OrderedDict()
 
         if self.BSGL:
-            if hasattr(self, 'search') is False:
+            if hasattr(self, "search") is False:
                 self._initiate_search_object()
             p = self.samples[jmax]
             self.search.BSGL = False
             maxtwoF = self.logl(p, self.search)
             self.search.BSGL = self.BSGL
         else:
-            maxtwoF = (maxlogl - self.likelihoodcoef)*2
+            maxtwoF = (maxlogl - self.likelihoodcoef) * 2
 
         repeats = []
         for i, k in enumerate(self.theta_keys):
             if k in d and k not in repeats:
-                d[k+'_0'] = d[k]  # relabel the old key
+                d[k + "_0"] = d[k]  # relabel the old key
                 d.pop(k)
                 repeats.append(k)
             if k in repeats:
-                k = k + '_0'
+                k = k + "_0"
                 count = 1
                 while k in d:
-                    k = k.replace('_{}'.format(count-1), '_{}'.format(count))
+                    k = k.replace("_{}".format(count - 1), "_{}".format(count))
                     count += 1
             d[k] = self.samples[jmax][i]
         return d, maxtwoF
@@ -1421,20 +1590,20 @@ class MCMCSearch(core.BaseSearchClass):
         repeats = []
         for s, k in zip(self.samples.T, self.theta_keys):
             if k in d and k not in repeats:
-                d[k+'_0'] = d[k]  # relabel the old key
-                d[k+'_0_std'] = d[k+'_std']
+                d[k + "_0"] = d[k]  # relabel the old key
+                d[k + "_0_std"] = d[k + "_std"]
                 d.pop(k)
-                d.pop(k+'_std')
+                d.pop(k + "_std")
                 repeats.append(k)
             if k in repeats:
-                k = k + '_0'
+                k = k + "_0"
                 count = 1
                 while k in d:
-                    k = k.replace('_{}'.format(count-1), '_{}'.format(count))
+                    k = k.replace("_{}".format(count - 1), "_{}".format(count))
                     count += 1
 
             d[k] = np.median(s)
-            d[k+'_std'] = np.std(s)
+            d[k + "_std"] = np.std(s)
         return d
 
     def check_if_samples_are_railing(self, threshold=0.01):
@@ -1454,80 +1623,96 @@ class MCMCSearch(core.BaseSearchClass):
         return_flag = False
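+        # Flag any uniform-prior parameter with samples within `threshold`
+        # (as a fraction of the prior range) of either edge, and report the
+        # percentage of such samples.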
         for s, k in zip(self.samples.T, self.theta_keys):
             prior = self.theta_prior[k]
-            if prior['type'] == 'unif':
-                prior_range = prior['upper'] - prior['lower']
+            if prior["type"] == "unif":
+                prior_range = prior["upper"] - prior["lower"]
                 edges = []
                 fracs = []
-                for l in ['lower', 'upper']:
-                    bools = np.abs(s - prior[l])/prior_range < threshold
+                for l in ["lower", "upper"]:
+                    bools = np.abs(s - prior[l]) / prior_range < threshold
                     if np.any(bools):
                         edges.append(l)
-                        fracs.append(str(100*float(np.sum(bools))/len(bools)))
+                        fracs.append(str(100 * float(np.sum(bools)) / len(bools)))
                 if len(edges) > 0:
                     logging.warning(
-                        '{}% of the {} posterior is railing on the {} edges'
-                        .format('% & '.join(fracs), k, ' & '.join(edges)))
+                        "{}% of the {} posterior is railing on the {} edges".format(
+                            "% & ".join(fracs), k, " & ".join(edges)
+                        )
+                    )
                     return_flag = True
         return return_flag
 
-    def write_par(self, method='med'):
+    def write_par(self, method="med"):
         """ Writes a .par of the best-fit params with an estimated std """
-        logging.info('Writing {}/{}.par using the {} method'.format(
-            self.outdir, self.label, method))
+        logging.info(
+            "Writing {}/{}.par using the {} method".format(
+                self.outdir, self.label, method
+            )
+        )
 
         median_std_d = self.get_median_stds()
         max_twoF_d, max_twoF = self.get_max_twoF()
 
-        logging.info('Writing par file with max twoF = {}'.format(max_twoF))
-        filename = '{}/{}.par'.format(self.outdir, self.label)
-        with open(filename, 'w+') as f:
-            f.write('MaxtwoF = {}\n'.format(max_twoF))
-            f.write('tref = {}\n'.format(self.tref))
-            if hasattr(self, 'theta0_index'):
-                f.write('theta0_index = {}\n'.format(self.theta0_idx))
-            if method == 'med':
+        logging.info("Writing par file with max twoF = {}".format(max_twoF))
+        filename = "{}/{}.par".format(self.outdir, self.label)
+        with open(filename, "w+") as f:
+            f.write("MaxtwoF = {}\n".format(max_twoF))
+            f.write("tref = {}\n".format(self.tref))
+            if hasattr(self, "theta0_index"):
+                f.write("theta0_index = {}\n".format(self.theta0_idx))
+            if method == "med":
                 for key, val in median_std_d.items():
-                    f.write('{} = {:1.16e}\n'.format(key, val))
-            if method == 'twoFmax':
+                    f.write("{} = {:1.16e}\n".format(key, val))
+            if method == "twoFmax":
                 for key, val in max_twoF_d.items():
-                    f.write('{} = {:1.16e}\n'.format(key, val))
+                    f.write("{} = {:1.16e}\n".format(key, val))
 
     def generate_loudest(self):
         """ Use lalapps_ComputeFstatistic_v2 to produce a .loudest file """
         self.write_par()
         params = read_par(label=self.label, outdir=self.outdir)
-        for key in ['Alpha', 'Delta', 'F0', 'F1']:
+        for key in ["Alpha", "Delta", "F0", "F1"]:
             if key not in params:
                 params[key] = self.theta_prior[key]
-        cmd = ('lalapps_ComputeFstatistic_v2 -a {} -d {} -f {} -s {} -D "{}"'
-               ' --refTime={} --outputLoudest="{}/{}.loudest" '
-               '--minStartTime={} --maxStartTime={}').format(
-                    params['Alpha'], params['Delta'], params['F0'],
-                    params['F1'], self.sftfilepattern, params['tref'],
-                    self.outdir, self.label, self.minStartTime,
-                    self.maxStartTime)
+        cmd = (
+            'lalapps_ComputeFstatistic_v2 -a {} -d {} -f {} -s {} -D "{}"'
+            ' --refTime={} --outputLoudest="{}/{}.loudest" '
+            "--minStartTime={} --maxStartTime={}"
+        ).format(
+            params["Alpha"],
+            params["Delta"],
+            params["F0"],
+            params["F1"],
+            self.sftfilepattern,
+            params["tref"],
+            self.outdir,
+            self.label,
+            self.minStartTime,
+            self.maxStartTime,
+        )
         subprocess.call([cmd], shell=True)
 
     def write_prior_table(self):
         """ Generate a .tex file of the prior """
-        with open('{}/{}_prior.tex'.format(self.outdir, self.label), 'w') as f:
-            f.write(r"\begin{tabular}{c l c} \hline" + '\n'
-                    r"Parameter & & &  \\ \hhline{====}")
+        with open("{}/{}_prior.tex".format(self.outdir, self.label), "w") as f:
+            f.write(
+                r"\begin{tabular}{c l c} \hline" + "\n"
+                r"Parameter & & &  \\ \hhline{====}"
+            )
 
             for key, prior in self.theta_prior.items():
                 if type(prior) is dict:
-                    Type = prior['type']
+                    Type = prior["type"]
                     if Type == "unif":
-                        a = prior['lower']
-                        b = prior['upper']
+                        a = prior["lower"]
+                        b = prior["upper"]
                         line = r"{} & $\mathrm{{Unif}}$({}, {}) & {}\\"
                     elif Type == "norm":
-                        a = prior['loc']
-                        b = prior['scale']
+                        a = prior["loc"]
+                        b = prior["scale"]
                         line = r"{} & $\mathcal{{N}}$({}, {}) & {}\\"
                     elif Type == "halfnorm":
-                        a = prior['loc']
-                        b = prior['scale']
+                        a = prior["loc"]
+                        b = prior["scale"]
                         line = r"{} & $|\mathcal{{N}}$({}, {})| & {}\\"
 
                     u = self.unit_dictionary[key]
@@ -1542,37 +1727,56 @@ class MCMCSearch(core.BaseSearchClass):
         """ Prints a summary of the max twoF found to the terminal """
         max_twoFd, max_twoF = self.get_max_twoF()
         median_std_d = self.get_median_stds()
-        logging.info('Summary:')
-        if hasattr(self, 'theta0_idx'):
-            logging.info('theta0 index: {}'.format(self.theta0_idx))
-        logging.info('Max twoF: {} with parameters:'.format(max_twoF))
+        logging.info("Summary:")
+        if hasattr(self, "theta0_idx"):
+            logging.info("theta0 index: {}".format(self.theta0_idx))
+        logging.info("Max twoF: {} with parameters:".format(max_twoF))
         for k in np.sort(list(max_twoFd.keys())):
-            print('  {:10s} = {:1.9e}'.format(k, max_twoFd[k]))
-        logging.info('Median +/- std for production values')
+            print("  {:10s} = {:1.9e}".format(k, max_twoFd[k]))
+        logging.info("Median +/- std for production values")
         for k in np.sort(list(median_std_d.keys())):
-            if 'std' not in k:
-                logging.info('  {:10s} = {:1.9e} +/- {:1.9e}'.format(
-                    k, median_std_d[k], median_std_d[k+'_std']))
-        logging.info('\n')
+            if "std" not in k:
+                logging.info(
+                    "  {:10s} = {:1.9e} +/- {:1.9e}".format(
+                        k, median_std_d[k], median_std_d[k + "_std"]
+                    )
+                )
+        logging.info("\n")
 
     def _CF_twoFmax(self, theta, twoFmax, ntrials):
-        Fmax = twoFmax/2.0
-        return (np.exp(1j*theta*twoFmax)*ntrials/2.0
-                * Fmax*np.exp(-Fmax)*(1-(1+Fmax)*np.exp(-Fmax))**(ntrials-1))
+        Fmax = twoFmax / 2.0
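+        # Characteristic function integrand for the loudest twoF out of
+        # ntrials independent trials: ntrials * pdf * cdf**(ntrials - 1),
+        # with twoF chi-squared distributed (4 degrees of freedom) under the
+        # noise hypothesis.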
+        return (
+            np.exp(1j * theta * twoFmax)
+            * ntrials
+            / 2.0
+            * Fmax
+            * np.exp(-Fmax)
+            * (1 - (1 + Fmax) * np.exp(-Fmax)) ** (ntrials - 1)
+        )
 
     def _pdf_twoFhat(self, twoFhat, nglitch, ntrials, twoFmax=100, dtwoF=0.1):
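+        # The pdf of the summed loudest-twoF statistic is recovered by an
+        # inverse Fourier transform (trapezoidal quadrature) of the product
+        # of the per-segment characteristic functions.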
         if np.ndim(ntrials) == 0:
-            ntrials = np.zeros(nglitch+1) + ntrials
+            ntrials = np.zeros(nglitch + 1) + ntrials
         twoFmax_int = np.arange(0, twoFmax, dtwoF)
-        theta_int = np.arange(-1/dtwoF, 1./dtwoF, 1./twoFmax)
+        theta_int = np.arange(-1 / dtwoF, 1.0 / dtwoF, 1.0 / twoFmax)
         CF_twoFmax_theta = np.array(
-            [[np.trapz(self._CF_twoFmax(t, twoFmax_int, ntrial), twoFmax_int)
-              for t in theta_int]
-             for ntrial in ntrials])
+            [
+                [
+                    np.trapz(self._CF_twoFmax(t, twoFmax_int, ntrial), twoFmax_int)
+                    for t in theta_int
+                ]
+                for ntrial in ntrials
+            ]
+        )
         CF_twoFhat_theta = np.prod(CF_twoFmax_theta, axis=0)
-        pdf = (1/(2*np.pi)) * np.array(
-            [np.trapz(np.exp(-1j*theta_int*twoFhat_val)
-             * CF_twoFhat_theta, theta_int) for twoFhat_val in twoFhat])
+        pdf = (1 / (2 * np.pi)) * np.array(
+            [
+                np.trapz(
+                    np.exp(-1j * theta_int * twoFhat_val) * CF_twoFhat_theta, theta_int
+                )
+                for twoFhat_val in twoFhat
+            ]
+        )
         return pdf.real
 
     def _p_val_twoFhat(self, twoFhat, ntrials, twoFhatmax=500, Npoints=1000):
@@ -1593,15 +1797,14 @@ class MCMCSearch(core.BaseSearchClass):
         """ Get's the p-value for the maximum twoFhat value """
         d, max_twoF = self.get_max_twoF()
         if self.nglitch == 1:
-            tglitches = [d['tglitch']]
+            tglitches = [d["tglitch"]]
         else:
-            tglitches = [d['tglitch_{}'.format(i)]
-                         for i in range(self.nglitch)]
+            tglitches = [d["tglitch_{}".format(i)] for i in range(self.nglitch)]
         tboundaries = [self.minStartTime] + tglitches + [self.maxStartTime]
         deltaTs = np.diff(tboundaries)
         ntrials = [time_trials + delta_F0 * dT for dT in deltaTs]
         p_val = self._p_val_twoFhat(max_twoF, ntrials)
-        print('p-value = {}'.format(p_val))
+        print("p-value = {}".format(p_val))
         return p_val
 
     def compute_evidence(self, make_plots=False, write_to_file=None):
@@ -1613,20 +1816,24 @@ class MCMCSearch(core.BaseSearchClass):
         betas = betas[::-1]
 
         if any(np.isinf(mean_lnlikes)):
-            print("WARNING mean_lnlikes contains inf: recalculating without"
-                  " the {} infs".format(len(betas[np.isinf(mean_lnlikes)])))
+            print(
+                "WARNING mean_lnlikes contains inf: recalculating without"
+                " the {} infs".format(len(betas[np.isinf(mean_lnlikes)]))
+            )
             idxs = np.isinf(mean_lnlikes)
             mean_lnlikes = mean_lnlikes[~idxs]
             betas = betas[~idxs]
 
-        log10evidence = np.trapz(mean_lnlikes, betas)/np.log(10)
+        log10evidence = np.trapz(mean_lnlikes, betas) / np.log(10)
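+        # Thermodynamic integration: log10(Z) is the integral of <ln L>_beta
+        # over beta, divided by ln(10); the error estimate compares the
+        # trapezoid rule at full (z1) and half (z2) resolution in beta.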
         z1 = np.trapz(mean_lnlikes, betas)
-        z2 = np.trapz(mean_lnlikes[::-1][::2][::-1],
-                      betas[::-1][::2][::-1])
+        z2 = np.trapz(mean_lnlikes[::-1][::2][::-1], betas[::-1][::2][::-1])
         log10evidence_err = np.abs(z1 - z2) / np.log(10)
 
-        logging.info("log10 evidence for {} = {} +/- {}".format(
-              self.label, log10evidence, log10evidence_err))
+        logging.info(
+            "log10 evidence for {} = {} +/- {}".format(
+                self.label, log10evidence, log10evidence_err
+            )
+        )
 
         if write_to_file:
             EvidenceDict = self.read_evidence_file_to_dict(write_to_file)
@@ -1640,15 +1847,17 @@ class MCMCSearch(core.BaseSearchClass):
             ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$")
             min_betas = []
             evidence = []
-            for i in range(int(len(betas)/2.)):
+            for i in range(int(len(betas) / 2.0)):
                 min_betas.append(betas[i])
                 lnZ = np.trapz(mean_lnlikes[i:], betas[i:])
-                evidence.append(lnZ/np.log(10))
+                evidence.append(lnZ / np.log(10))
 
             ax2.semilogx(min_betas, evidence, "-o")
-            ax2.set_ylabel(r"$\int_{\beta_{\textrm{Min}}}^{\beta=1}" +
-                           r"\langle \log(\mathcal{L})\rangle d\beta$",
-                           size=16)
+            ax2.set_ylabel(
+                r"$\int_{\beta_{\textrm{Min}}}^{\beta=1}"
+                + r"\langle \log(\mathcal{L})\rangle d\beta$",
+                size=16,
+            )
             ax2.set_xlabel(r"$\beta_{\textrm{min}}$")
             plt.tight_layout()
             fig.savefig("{}/{}_beta_lnl.png".format(self.outdir, self.label))
@@ -1656,20 +1865,19 @@ class MCMCSearch(core.BaseSearchClass):
         return log10evidence, log10evidence_err
 
     @staticmethod
-    def read_evidence_file_to_dict(evidence_file_name='Evidences.txt'):
+    def read_evidence_file_to_dict(evidence_file_name="Evidences.txt"):
         EvidenceDict = OrderedDict()
         if os.path.isfile(evidence_file_name):
-            with open(evidence_file_name, 'r') as f:
+            with open(evidence_file_name, "r") as f:
                 for line in f:
-                    key, log10evidence, log10evidence_err = line.split(' ')
-                    EvidenceDict[key] = [
-                        float(log10evidence), float(log10evidence_err)]
+                    key, log10evidence, log10evidence_err = line.split(" ")
+                    EvidenceDict[key] = [float(log10evidence), float(log10evidence_err)]
         return EvidenceDict
 
     def write_evidence_file_from_dict(self, EvidenceDict, evidence_file_name):
-        with open(evidence_file_name, 'w+') as f:
+        with open(evidence_file_name, "w+") as f:
             for key, val in EvidenceDict.items():
-                f.write('{} {} {}\n'.format(key, val[0], val[1]))
+                f.write("{} {} {}\n".format(key, val[0], val[1]))
 
 
 class MCMCGlitchSearch(MCMCSearch):
@@ -1693,37 +1901,72 @@ class MCMCGlitchSearch(MCMCSearch):
     """
 
     symbol_dictionary = dict(
-        F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$', Alpha=r'$\alpha$',
-        Delta='$\delta$', delta_F0='$\delta f$',
-        delta_F1='$\delta \dot{f}$', tglitch='$t_\mathrm{glitch}$')
+        F0="$f$",
+        F1="$\dot{f}$",
+        F2="$\ddot{f}$",
+        Alpha=r"$\alpha$",
+        Delta="$\delta$",
+        delta_F0="$\delta f$",
+        delta_F1="$\delta \dot{f}$",
+        tglitch="$t_\mathrm{glitch}$",
+    )
     unit_dictionary = dict(
-        F0='Hz', F1='Hz/s', F2='Hz/s$^2$', Alpha=r'rad', Delta='rad',
-        delta_F0='Hz', delta_F1='Hz/s', tglitch='s')
+        F0="Hz",
+        F1="Hz/s",
+        F2="Hz/s$^2$",
+        Alpha=r"rad",
+        Delta="rad",
+        delta_F0="Hz",
+        delta_F1="Hz/s",
+        tglitch="s",
+    )
     transform_dictionary = dict(
         tglitch={
-            'multiplier': 1/86400.,
-            'subtractor': 'minStartTime',
-            'unit': 'day',
-            'label': '$t^{g}_0$ \n [d]'}
-            )
+            "multiplier": 1 / 86400.0,
+            "subtractor": "minStartTime",
+            "unit": "day",
+            "label": "$t^{g}_0$ \n [d]",
+        }
+    )
 
     @helper_functions.initializer
-    def __init__(self, theta_prior, tref, label, outdir='data',
-                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
-                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
-                 log10beta_min=-5, theta_initial=None,
-                 rhohatmax=1000, binary=False, BSGL=False,
-                 SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
-                 injectSources=None, assumeSqrtSX=None,
-                 dtglitchmin=1*86400, theta0_idx=0, nglitch=1):
+    def __init__(
+        self,
+        theta_prior,
+        tref,
+        label,
+        outdir="data",
+        minStartTime=None,
+        maxStartTime=None,
+        sftfilepattern=None,
+        detectors=None,
+        nsteps=[100, 100],
+        nwalkers=100,
+        ntemps=1,
+        log10beta_min=-5,
+        theta_initial=None,
+        rhohatmax=1000,
+        binary=False,
+        BSGL=False,
+        SSBprec=None,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        injectSources=None,
+        assumeSqrtSX=None,
+        dtglitchmin=1 * 86400,
+        theta0_idx=0,
+        nglitch=1,
+    ):
 
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
         self._add_log_file()
-        logging.info(('Set-up MCMC glitch search with {} glitches for model {}'
-                      ' on data {}').format(self.nglitch, self.label,
-                                            self.sftfilepattern))
-        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
+        logging.info(
+            (
+                "Set-up MCMC glitch search with {} glitches for model {}" " on data {}"
+            ).format(self.nglitch, self.label, self.sftfilepattern)
+        )
+        self.pickle_path = "{}/{}_saved_data.p".format(self.outdir, self.label)
         self._unpack_input_theta()
         self.ndim = len(self.theta_keys)
         if self.log10beta_min:
@@ -1731,24 +1974,32 @@ class MCMCGlitchSearch(MCMCSearch):
         else:
             self.betas = None
         if args.clean and os.path.isfile(self.pickle_path):
-            os.rename(self.pickle_path, self.pickle_path+".old")
+            os.rename(self.pickle_path, self.pickle_path + ".old")
 
         self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
         self._log_input()
         self._set_likelihoodcoef()
 
     def _set_likelihoodcoef(self):
-        self.likelihoodcoef = (self.nglitch+1)*np.log(70./self.rhohatmax**4)
+        self.likelihoodcoef = (self.nglitch + 1) * np.log(70.0 / self.rhohatmax ** 4)
 
     def _initiate_search_object(self):
-        logging.info('Setting up search object')
+        logging.info("Setting up search object")
         self.search = core.SemiCoherentGlitchSearch(
-            label=self.label, outdir=self.outdir,
-            sftfilepattern=self.sftfilepattern, tref=self.tref,
-            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
-            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
-            detectors=self.detectors, BSGL=self.BSGL, nglitch=self.nglitch,
-            theta0_idx=self.theta0_idx, injectSources=self.injectSources)
+            label=self.label,
+            outdir=self.outdir,
+            sftfilepattern=self.sftfilepattern,
+            tref=self.tref,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+            minCoverFreq=self.minCoverFreq,
+            maxCoverFreq=self.maxCoverFreq,
+            detectors=self.detectors,
+            BSGL=self.BSGL,
+            nglitch=self.nglitch,
+            theta0_idx=self.theta0_idx,
+            injectSources=self.injectSources,
+        )
         if self.minStartTime is None:
             self.minStartTime = self.search.minStartTime
         if self.maxStartTime is None:
@@ -1756,49 +2007,65 @@ class MCMCGlitchSearch(MCMCSearch):
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
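+        # Reject (log-prior of -inf) any sample whose glitch times are
+        # unsorted or closer together than dtglitchmin.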
         if self.nglitch > 1:
-            ts = ([self.minStartTime] + list(theta_vals[-self.nglitch:])
-                  + [self.maxStartTime])
+            ts = (
+                [self.minStartTime]
+                + list(theta_vals[-self.nglitch :])
+                + [self.maxStartTime]
+            )
             if np.array_equal(ts, np.sort(ts)) is False:
                 return -np.inf
             if any(np.diff(ts) < self.dtglitchmin):
                 return -np.inf
 
-        H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
-             zip(theta_vals, theta_keys)]
+        H = [
+            self._generic_lnprior(**theta_prior[key])(p)
+            for p, key in zip(theta_vals, theta_keys)
+        ]
         return np.sum(H)
 
     def logl(self, theta, search):
         if self.nglitch > 1:
-            ts = ([self.minStartTime] + list(theta[-self.nglitch:])
-                  + [self.maxStartTime])
+            ts = (
+                [self.minStartTime] + list(theta[-self.nglitch :]) + [self.maxStartTime]
+            )
             if np.array_equal(ts, np.sort(ts)) is False:
                 return -np.inf
 
         for j, theta_i in enumerate(self.theta_idxs):
             self.fixed_theta[theta_i] = theta[j]
         twoF = search.get_semicoherent_nglitch_twoF(*self.fixed_theta)
-        return twoF/2.0 + self.likelihoodcoef
+        return twoF / 2.0 + self.likelihoodcoef
 
     def _unpack_input_theta(self):
-        glitch_keys = ['delta_F0', 'delta_F1', 'tglitch']
-        full_glitch_keys = list(np.array(
-            [[gk]*self.nglitch for gk in glitch_keys]).flatten())
-
-        if 'tglitch_0' in self.theta_prior:
-            full_glitch_keys[-self.nglitch:] = [
-                'tglitch_{}'.format(i) for i in range(self.nglitch)]
-            full_glitch_keys[-2*self.nglitch:-1*self.nglitch] = [
-                'delta_F1_{}'.format(i) for i in range(self.nglitch)]
-            full_glitch_keys[-4*self.nglitch:-2*self.nglitch] = [
-                'delta_F0_{}'.format(i) for i in range(self.nglitch)]
-        full_theta_keys = ['F0', 'F1', 'F2', 'Alpha', 'Delta']+full_glitch_keys
+        glitch_keys = ["delta_F0", "delta_F1", "tglitch"]
+        full_glitch_keys = list(
+            np.array([[gk] * self.nglitch for gk in glitch_keys]).flatten()
+        )
+
+        if "tglitch_0" in self.theta_prior:
+            full_glitch_keys[-self.nglitch :] = [
+                "tglitch_{}".format(i) for i in range(self.nglitch)
+            ]
+            full_glitch_keys[-2 * self.nglitch : -1 * self.nglitch] = [
+                "delta_F1_{}".format(i) for i in range(self.nglitch)
+            ]
+            full_glitch_keys[-4 * self.nglitch : -2 * self.nglitch] = [
+                "delta_F0_{}".format(i) for i in range(self.nglitch)
+            ]
+        full_theta_keys = ["F0", "F1", "F2", "Alpha", "Delta"] + full_glitch_keys
         full_theta_keys_copy = copy.copy(full_theta_keys)
 
-        glitch_symbols = ['$\delta f$', '$\delta \dot{f}$', r'$t_{glitch}$']
-        full_glitch_symbols = list(np.array(
-            [[gs]*self.nglitch for gs in glitch_symbols]).flatten())
-        full_theta_symbols = (['$f$', '$\dot{f}$', '$\ddot{f}$', r'$\alpha$',
-                               r'$\delta$'] + full_glitch_symbols)
+        glitch_symbols = ["$\delta f$", "$\delta \dot{f}$", r"$t_{glitch}$"]
+        full_glitch_symbols = list(
+            np.array([[gs] * self.nglitch for gs in glitch_symbols]).flatten()
+        )
+        full_theta_symbols = [
+            "$f$",
+            "$\dot{f}$",
+            "$\ddot{f}$",
+            r"$\alpha$",
+            r"$\delta$",
+        ] + full_glitch_symbols
         self.theta_keys = []
         fixed_theta_dict = {}
         for key, val in self.theta_prior.items():
@@ -1813,8 +2080,8 @@ class MCMCGlitchSearch(MCMCSearch):
                 fixed_theta_dict[key] = val
             else:
                 raise ValueError(
-                    'Type {} of {} in theta not recognised'.format(
-                        type(val), key))
+                    "Type {} of {} in theta not recognised".format(type(val), key)
+                )
             if key in glitch_keys:
                 for i in range(self.nglitch):
                     full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
@@ -1822,9 +2089,11 @@ class MCMCGlitchSearch(MCMCSearch):
                 full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
 
         if len(full_theta_keys_copy) > 0:
-            raise ValueError(('Input dictionary `theta` is missing the'
-                              'following keys: {}').format(
-                                  full_theta_keys_copy))
+            raise ValueError(
+                ("Input dictionary `theta` is missing the" "following keys: {}").format(
+                    full_theta_keys_copy
+                )
+            )
 
         self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
         self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
@@ -1843,20 +2112,24 @@ class MCMCGlitchSearch(MCMCSearch):
                     self.theta_idxs[i] += 1
 
     def _get_data_dictionary_to_save(self):
-        d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
-                 ntemps=self.ntemps, theta_keys=self.theta_keys,
-                 theta_prior=self.theta_prior,
-                 log10beta_min=self.log10beta_min,
-                 theta0_idx=self.theta0_idx, BSGL=self.BSGL,
-                 minStartTime=self.minStartTime,
-                 maxStartTime=self.maxStartTime)
+        d = dict(
+            nsteps=self.nsteps,
+            nwalkers=self.nwalkers,
+            ntemps=self.ntemps,
+            theta_keys=self.theta_keys,
+            theta_prior=self.theta_prior,
+            log10beta_min=self.log10beta_min,
+            theta0_idx=self.theta0_idx,
+            BSGL=self.BSGL,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+        )
         return d
 
     def _apply_corrections_to_p0(self, p0):
         p0 = np.array(p0)
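+        # Sort the glitch-time block of each walker so that every initial
+        # sample satisfies the ordered-tglitch prior.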
         if self.nglitch > 1:
-            p0[:, :, -self.nglitch:] = np.sort(p0[:, :, -self.nglitch:],
-                                               axis=2)
+            p0[:, :, -self.nglitch :] = np.sort(p0[:, :, -self.nglitch :], axis=2)
         return p0
 
     def plot_cumulative_max(self):
@@ -1868,45 +2141,55 @@ class MCMCGlitchSearch(MCMCSearch):
                 d[key] = val
 
         if self.nglitch > 1:
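+            # Collect per-glitch frequency offsets; those before the reference
+            # segment theta0_idx enter with opposite sign.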
-            delta_F0s = [d['delta_F0_{}'.format(i)] for i in
-                         range(self.nglitch)]
+            delta_F0s = [d["delta_F0_{}".format(i)] for i in range(self.nglitch)]
             delta_F0s.insert(self.theta0_idx, 0)
             delta_F0s = np.array(delta_F0s)
-            delta_F0s[:self.theta0_idx] *= -1
-            tglitches = [d['tglitch_{}'.format(i)] for i in
-                         range(self.nglitch)]
+            delta_F0s[: self.theta0_idx] *= -1
+            tglitches = [d["tglitch_{}".format(i)] for i in range(self.nglitch)]
         elif self.nglitch == 1:
-            delta_F0s = [d['delta_F0']]
+            delta_F0s = [d["delta_F0"]]
             delta_F0s.insert(self.theta0_idx, 0)
             delta_F0s = np.array(delta_F0s)
-            delta_F0s[:self.theta0_idx] *= -1
-            tglitches = [d['tglitch']]
+            delta_F0s[: self.theta0_idx] *= -1
+            tglitches = [d["tglitch"]]
 
         tboundaries = [self.minStartTime] + tglitches + [self.maxStartTime]
 
-        for j in range(self.nglitch+1):
+        for j in range(self.nglitch + 1):
             ts = tboundaries[j]
-            te = tboundaries[j+1]
-            if (te - ts)/86400 < 5:
-                logging.info('Period too short to perform cumulative search')
+            te = tboundaries[j + 1]
+            if (te - ts) / 86400 < 5:
+                logging.info("Period too short to perform cumulative search")
                 continue
             if j < self.theta0_idx:
-                summed_deltaF0 = np.sum(delta_F0s[j:self.theta0_idx])
-                F0_j = d['F0'] - summed_deltaF0
+                summed_deltaF0 = np.sum(delta_F0s[j : self.theta0_idx])
+                F0_j = d["F0"] - summed_deltaF0
                 taus, twoFs = self.search.calculate_twoF_cumulative(
-                    F0_j, F1=d['F1'], F2=d['F2'], Alpha=d['Alpha'],
-                    Delta=d['Delta'], tstart=ts, tend=te)
+                    F0_j,
+                    F1=d["F1"],
+                    F2=d["F2"],
+                    Alpha=d["Alpha"],
+                    Delta=d["Delta"],
+                    tstart=ts,
+                    tend=te,
+                )
 
             elif j >= self.theta0_idx:
-                summed_deltaF0 = np.sum(delta_F0s[self.theta0_idx:j+1])
-                F0_j = d['F0'] + summed_deltaF0
+                summed_deltaF0 = np.sum(delta_F0s[self.theta0_idx : j + 1])
+                F0_j = d["F0"] + summed_deltaF0
                 taus, twoFs = self.search.calculate_twoF_cumulative(
-                    F0_j, F1=d['F1'], F2=d['F2'], Alpha=d['Alpha'],
-                    Delta=d['Delta'], tstart=ts, tend=te)
-            ax.plot(ts+taus, twoFs)
+                    F0_j,
+                    F1=d["F1"],
+                    F2=d["F2"],
+                    Alpha=d["Alpha"],
+                    Delta=d["Delta"],
+                    tstart=ts,
+                    tend=te,
+                )
+            ax.plot(ts + taus, twoFs)
 
-        ax.set_xlabel('GPS time')
-        fig.savefig('{}/{}_twoFcumulative.png'.format(self.outdir, self.label))
+        ax.set_xlabel("GPS time")
+        fig.savefig("{}/{}_twoFcumulative.png".format(self.outdir, self.label))
 
 
 class MCMCSemiCoherentSearch(MCMCSearch):
@@ -1969,14 +2252,31 @@ class MCMCSemiCoherentSearch(MCMCSearch):
 
     """
 
-    def __init__(self, theta_prior, tref, label, outdir='data',
-                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
-                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
-                 log10beta_min=-5, theta_initial=None,
-                 rhohatmax=1000, binary=False, BSGL=False,
-                 SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
-                 injectSources=None, assumeSqrtSX=None,
-                 nsegs=None):
+    def __init__(
+        self,
+        theta_prior,
+        tref,
+        label,
+        outdir="data",
+        minStartTime=None,
+        maxStartTime=None,
+        sftfilepattern=None,
+        detectors=None,
+        nsteps=[100, 100],
+        nwalkers=100,
+        ntemps=1,
+        log10beta_min=-5,
+        theta_initial=None,
+        rhohatmax=1000,
+        binary=False,
+        BSGL=False,
+        SSBprec=None,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        injectSources=None,
+        assumeSqrtSX=None,
+        nsegs=None,
+    ):
 
         self.theta_prior = theta_prior
         self.tref = tref
@@ -2004,10 +2304,12 @@ class MCMCSemiCoherentSearch(MCMCSearch):
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
         self._add_log_file()
-        logging.info(('Set-up MCMC semi-coherent search for model {} on data'
-                      '{}').format(
-            self.label, self.sftfilepattern))
-        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
+        logging.info(
+            ("Set-up MCMC semi-coherent search for model {} on data" "{}").format(
+                self.label, self.sftfilepattern
+            )
+        )
+        self.pickle_path = "{}/{}_saved_data.p".format(self.outdir, self.label)
         self._unpack_input_theta()
         self.ndim = len(self.theta_keys)
         if self.log10beta_min:
@@ -2015,53 +2317,68 @@ class MCMCSemiCoherentSearch(MCMCSearch):
         else:
             self.betas = None
         if args.clean and os.path.isfile(self.pickle_path):
-            os.rename(self.pickle_path, self.pickle_path+".old")
+            os.rename(self.pickle_path, self.pickle_path + ".old")
 
         self._log_input()
 
         if self.nsegs:
             self._set_likelihoodcoef()
         else:
-            logging.info('Value `nsegs` not yet provided')
+            logging.info("Value `nsegs` not yet provided")
 
     def _set_likelihoodcoef(self):
-        self.likelihoodcoef = self.nsegs * np.log(70./self.rhohatmax**4)
+        self.likelihoodcoef = self.nsegs * np.log(70.0 / self.rhohatmax ** 4)
 
     def _get_data_dictionary_to_save(self):
-        d = dict(nsteps=self.nsteps, nwalkers=self.nwalkers,
-                 ntemps=self.ntemps, theta_keys=self.theta_keys,
-                 theta_prior=self.theta_prior,
-                 log10beta_min=self.log10beta_min,
-                 BSGL=self.BSGL, nsegs=self.nsegs,
-                 minStartTime=self.minStartTime,
-                 maxStartTime=self.maxStartTime)
+        d = dict(
+            nsteps=self.nsteps,
+            nwalkers=self.nwalkers,
+            ntemps=self.ntemps,
+            theta_keys=self.theta_keys,
+            theta_prior=self.theta_prior,
+            log10beta_min=self.log10beta_min,
+            BSGL=self.BSGL,
+            nsegs=self.nsegs,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+        )
         return d
 
     def _initiate_search_object(self):
-        logging.info('Setting up search object')
+        logging.info("Setting up search object")
         self.search = core.SemiCoherentSearch(
-            label=self.label, outdir=self.outdir, tref=self.tref,
-            nsegs=self.nsegs, sftfilepattern=self.sftfilepattern,
-            binary=self.binary, BSGL=self.BSGL, minStartTime=self.minStartTime,
-            maxStartTime=self.maxStartTime, minCoverFreq=self.minCoverFreq,
-            maxCoverFreq=self.maxCoverFreq, detectors=self.detectors,
-            injectSources=self.injectSources, assumeSqrtSX=self.assumeSqrtSX)
+            label=self.label,
+            outdir=self.outdir,
+            tref=self.tref,
+            nsegs=self.nsegs,
+            sftfilepattern=self.sftfilepattern,
+            binary=self.binary,
+            BSGL=self.BSGL,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+            minCoverFreq=self.minCoverFreq,
+            maxCoverFreq=self.maxCoverFreq,
+            detectors=self.detectors,
+            injectSources=self.injectSources,
+            assumeSqrtSX=self.assumeSqrtSX,
+        )
         if self.minStartTime is None:
             self.minStartTime = self.search.minStartTime
         if self.maxStartTime is None:
             self.maxStartTime = self.search.maxStartTime
 
     def logp(self, theta_vals, theta_prior, theta_keys, search):
-        H = [self._generic_lnprior(**theta_prior[key])(p) for p, key in
-             zip(theta_vals, theta_keys)]
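+        # total log-prior: sum of the independent per-parameter log-priors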
+        H = [
+            self._generic_lnprior(**theta_prior[key])(p)
+            for p, key in zip(theta_vals, theta_keys)
+        ]
         return np.sum(H)
 
     def logl(self, theta, search):
         for j, theta_i in enumerate(self.theta_idxs):
             self.fixed_theta[theta_i] = theta[j]
-        twoF = search.get_semicoherent_twoF(
-            *self.fixed_theta)
-        return twoF/2.0 + self.likelihoodcoef
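+        # log-likelihood: half the semi-coherent 2F plus the normalisation term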
+        twoF = search.get_semicoherent_twoF(*self.fixed_theta)
+        return twoF / 2.0 + self.likelihoodcoef
 
 
 class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
@@ -2135,13 +2452,30 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
 
     """
 
-    def __init__(self, theta_prior, tref, label, outdir='data',
-                 minStartTime=None, maxStartTime=None, sftfilepattern=None,
-                 detectors=None, nsteps=[100, 100], nwalkers=100, ntemps=1,
-                 log10beta_min=-5, theta_initial=None,
-                 rhohatmax=1000, binary=False, BSGL=False,
-                 SSBprec=None, minCoverFreq=None, maxCoverFreq=None,
-                 injectSources=None, assumeSqrtSX=None):
+    def __init__(
+        self,
+        theta_prior,
+        tref,
+        label,
+        outdir="data",
+        minStartTime=None,
+        maxStartTime=None,
+        sftfilepattern=None,
+        detectors=None,
+        nsteps=[100, 100],
+        nwalkers=100,
+        ntemps=1,
+        log10beta_min=-5,
+        theta_initial=None,
+        rhohatmax=1000,
+        binary=False,
+        BSGL=False,
+        SSBprec=None,
+        minCoverFreq=None,
+        maxCoverFreq=None,
+        injectSources=None,
+        assumeSqrtSX=None,
+    ):
 
         self.theta_prior = theta_prior
         self.tref = tref
@@ -2169,10 +2503,12 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         if os.path.isdir(outdir) is False:
             os.mkdir(outdir)
         self._add_log_file()
-        logging.info(('Set-up MCMC semi-coherent search for model {} on data'
-                      '{}').format(
-            self.label, self.sftfilepattern))
-        self.pickle_path = '{}/{}_saved_data.p'.format(self.outdir, self.label)
+        logging.info(
+            ("Set-up MCMC semi-coherent search for model {} on data" "{}").format(
+                self.label, self.sftfilepattern
+            )
+        )
+        self.pickle_path = "{}/{}_saved_data.p".format(self.outdir, self.label)
         self._unpack_input_theta()
         self.ndim = len(self.theta_keys)
         if self.log10beta_min:
@@ -2180,45 +2516,63 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         else:
             self.betas = None
         if args.clean and os.path.isfile(self.pickle_path):
-            os.rename(self.pickle_path, self.pickle_path+".old")
+            os.rename(self.pickle_path, self.pickle_path + ".old")
 
         self._log_input()
 
         if self.nsegs:
             self._set_likelihoodcoef()
         else:
-            logging.info('Value `nsegs` not yet provided')
+            logging.info("Value `nsegs` not yet provided")
 
     def _get_data_dictionary_to_save(self):
-        d = dict(nwalkers=self.nwalkers, ntemps=self.ntemps,
-                 theta_keys=self.theta_keys, theta_prior=self.theta_prior,
-                 log10beta_min=self.log10beta_min,
-                 BSGL=self.BSGL, minStartTime=self.minStartTime,
-                 maxStartTime=self.maxStartTime, run_setup=self.run_setup)
+        d = dict(
+            nwalkers=self.nwalkers,
+            ntemps=self.ntemps,
+            theta_keys=self.theta_keys,
+            theta_prior=self.theta_prior,
+            log10beta_min=self.log10beta_min,
+            BSGL=self.BSGL,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+            run_setup=self.run_setup,
+        )
         return d
 
     def update_search_object(self):
-        logging.info('Update search object')
+        logging.info("Update search object")
         self.search.init_computefstatistic_single_point()
 
     def get_width_from_prior(self, prior, key):
-        if prior[key]['type'] == 'unif':
-            return prior[key]['upper'] - prior[key]['lower']
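+        # only uniform priors have a well-defined width; other types fall
+        # through and return None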
+        if prior[key]["type"] == "unif":
+            return prior[key]["upper"] - prior[key]["lower"]
 
     def get_mid_from_prior(self, prior, key):
-        if prior[key]['type'] == 'unif':
-            return .5*(prior[key]['upper'] + prior[key]['lower'])
+        if prior[key]["type"] == "unif":
+            return 0.5 * (prior[key]["upper"] + prior[key]["lower"])
 
     def read_setup_input_file(self, run_setup_input_file):
-        with open(run_setup_input_file, 'r+') as f:
+        with open(run_setup_input_file, "r+") as f:
             d = pickle.load(f)
         return d
 
-    def write_setup_input_file(self, run_setup_input_file, NstarMax, Nsegs0,
-                               nsegs_vals, Nstar_vals, theta_prior):
-        d = dict(NstarMax=NstarMax, Nsegs0=Nsegs0, nsegs_vals=nsegs_vals,
-                 theta_prior=theta_prior, Nstar_vals=Nstar_vals)
-        with open(run_setup_input_file, 'w+') as f:
+    def write_setup_input_file(
+        self,
+        run_setup_input_file,
+        NstarMax,
+        Nsegs0,
+        nsegs_vals,
+        Nstar_vals,
+        theta_prior,
+    ):
+        d = dict(
+            NstarMax=NstarMax,
+            Nsegs0=Nsegs0,
+            nsegs_vals=nsegs_vals,
+            theta_prior=theta_prior,
+            Nstar_vals=Nstar_vals,
+        )
+        with open(run_setup_input_file, "w+") as f:
             pickle.dump(d, f)
 
     def check_old_run_setup(self, old_setup, **kwargs):
@@ -2227,38 +2581,48 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             if all(truths):
                 return True
             else:
-                logging.info(
-                    "Old setup doesn't match one of NstarMax, Nsegs0 or prior")
+                logging.info("Old setup doesn't match one of NstarMax, Nsegs0 or prior")
         except KeyError as e:
-            logging.info(
-                'Error found when comparing with old setup: {}'.format(e))
+            logging.info("Error found when comparing with old setup: {}".format(e))
             return False
 
-    def init_run_setup(self, run_setup=None, NstarMax=1000, Nsegs0=None,
-                       log_table=True, gen_tex_table=True):
+    def init_run_setup(
+        self,
+        run_setup=None,
+        NstarMax=1000,
+        Nsegs0=None,
+        log_table=True,
+        gen_tex_table=True,
+    ):
 
         if run_setup is None and Nsegs0 is None:
             raise ValueError(
-                'You must either specify the run_setup, or Nsegs0 and NStarMax'
-                ' from which the optimal run_setup can be estimated')
+                "You must either specify the run_setup, or Nsegs0 and NStarMax"
+                " from which the optimal run_setup can be estimated"
+            )
         if run_setup is None:
-            logging.info('No run_setup provided')
+            logging.info("No run_setup provided")
 
-            run_setup_input_file = '{}/{}_run_setup.p'.format(
-                self.outdir, self.label)
+            run_setup_input_file = "{}/{}_run_setup.p".format(self.outdir, self.label)
 
             if os.path.isfile(run_setup_input_file):
-                logging.info('Checking old setup input file {}'.format(
-                    run_setup_input_file))
+                logging.info(
+                    "Checking old setup input file {}".format(run_setup_input_file)
+                )
                 old_setup = self.read_setup_input_file(run_setup_input_file)
-                if self.check_old_run_setup(old_setup, NstarMax=NstarMax,
-                                            Nsegs0=Nsegs0,
-                                            theta_prior=self.theta_prior):
+                if self.check_old_run_setup(
+                    old_setup,
+                    NstarMax=NstarMax,
+                    Nsegs0=Nsegs0,
+                    theta_prior=self.theta_prior,
+                ):
                     logging.info(
-                        'Using old setup with NstarMax={}, Nsegs0={}'.format(
-                            NstarMax, Nsegs0))
-                    nsegs_vals = old_setup['nsegs_vals']
-                    Nstar_vals = old_setup['Nstar_vals']
+                        "Using old setup with NstarMax={}, Nsegs0={}".format(
+                            NstarMax, Nsegs0
+                        )
+                    )
+                    nsegs_vals = old_setup["nsegs_vals"]
+                    Nstar_vals = old_setup["Nstar_vals"]
                     generate_setup = False
                 else:
                     generate_setup = True
@@ -2266,22 +2630,31 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
                 generate_setup = True
 
             if generate_setup:
-                nsegs_vals, Nstar_vals = (
-                        optimal_setup_functions.get_optimal_setup(
-                            NstarMax, Nsegs0, self.tref, self.minStartTime,
-                            self.maxStartTime, self.theta_prior,
-                            self.search.detector_names))
-                self.write_setup_input_file(run_setup_input_file, NstarMax,
-                                            Nsegs0, nsegs_vals, Nstar_vals,
-                                            self.theta_prior)
-
-            run_setup = [((self.nsteps[0], 0),  nsegs, False)
-                         for nsegs in nsegs_vals[:-1]]
-            run_setup.append(
-                ((self.nsteps[0], self.nsteps[1]), nsegs_vals[-1], False))
+                nsegs_vals, Nstar_vals = optimal_setup_functions.get_optimal_setup(
+                    NstarMax,
+                    Nsegs0,
+                    self.tref,
+                    self.minStartTime,
+                    self.maxStartTime,
+                    self.theta_prior,
+                    self.search.detector_names,
+                )
+                self.write_setup_input_file(
+                    run_setup_input_file,
+                    NstarMax,
+                    Nsegs0,
+                    nsegs_vals,
+                    Nstar_vals,
+                    self.theta_prior,
+                )
+
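+            # each stage is ((nburn, nprod), nsegs, resetp0); only the final
+            # stage runs production steps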
+            run_setup = [
+                ((self.nsteps[0], 0), nsegs, False) for nsegs in nsegs_vals[:-1]
+            ]
+            run_setup.append(((self.nsteps[0], self.nsteps[1]), nsegs_vals[-1], False))
 
         else:
-            logging.info('Calculating the number of templates for this setup')
+            logging.info("Calculating the number of templates for this setup")
             Nstar_vals = []
             for i, rs in enumerate(run_setup):
                 rs = list(rs)
@@ -2295,46 +2668,61 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
                     Nstar_vals.append([1, 1, 1])
                 else:
                     Nstar = optimal_setup_functions.get_Nstar_estimate(
-                        rs[1], self.tref, self.minStartTime, self.maxStartTime,
-                        self.theta_prior, self.search.detector_names)
+                        rs[1],
+                        self.tref,
+                        self.minStartTime,
+                        self.maxStartTime,
+                        self.theta_prior,
+                        self.search.detector_names,
+                    )
                     Nstar_vals.append(Nstar)
 
         if log_table:
-            logging.info('Using run-setup as follows:')
-            logging.info(
-                'Stage | nburn | nprod | nsegs | Tcoh d | resetp0 | Nstar')
+            logging.info("Using run-setup as follows:")
+            logging.info("Stage | nburn | nprod | nsegs | Tcoh d | resetp0 | Nstar")
             for i, rs in enumerate(run_setup):
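+                # per-segment coherence time in days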
                 Tcoh = (self.maxStartTime - self.minStartTime) / rs[1] / 86400
                 if Nstar_vals[i] is None:
-                    vtext = 'N/A'
+                    vtext = "N/A"
                 else:
-                    vtext = '{:0.3e}'.format(int(Nstar_vals[i]))
-                logging.info('{} | {} | {} | {} | {} | {} | {}'.format(
-                    str(i).ljust(5), str(rs[0][0]).ljust(5),
-                    str(rs[0][1]).ljust(5), str(rs[1]).ljust(5),
-                    '{:6.1f}'.format(Tcoh), str(rs[2]).ljust(7),
-                    vtext))
+                    vtext = "{:0.3e}".format(int(Nstar_vals[i]))
+                logging.info(
+                    "{} | {} | {} | {} | {} | {} | {}".format(
+                        str(i).ljust(5),
+                        str(rs[0][0]).ljust(5),
+                        str(rs[0][1]).ljust(5),
+                        str(rs[1]).ljust(5),
+                        "{:6.1f}".format(Tcoh),
+                        str(rs[2]).ljust(7),
+                        vtext,
+                    )
+                )
 
         if gen_tex_table:
-            filename = '{}/{}_run_setup.tex'.format(self.outdir, self.label)
-            with open(filename, 'w+') as f:
-                f.write(r'\begin{tabular}{c|ccc}' + '\n')
-                f.write(r'Stage & $N_\mathrm{seg}$ &'
-                        r'$T_\mathrm{coh}^{\rm days}$ &'
-                        r'$\mathcal{N}^*(\Nseg^{(\ell)}, \Delta\mathbf{\lambda}^{(0)})$ \\ \hline'
-                        '\n')
+            filename = "{}/{}_run_setup.tex".format(self.outdir, self.label)
+            with open(filename, "w+") as f:
+                f.write(r"\begin{tabular}{c|ccc}" + "\n")
+                f.write(
+                    r"Stage & $N_\mathrm{seg}$ &"
+                    r"$T_\mathrm{coh}^{\rm days}$ &"
+                    r"$\mathcal{N}^*(\Nseg^{(\ell)}, \Delta\mathbf{\lambda}^{(0)})$ \\ \hline"
+                    "\n"
+                )
                 for i, rs in enumerate(run_setup):
-                    Tcoh = float(
-                        self.maxStartTime - self.minStartTime)/rs[1]/86400
-                    line = r'{} & {} & {} & {} \\' + '\n'
+                    Tcoh = float(self.maxStartTime - self.minStartTime) / rs[1] / 86400
+                    line = r"{} & {} & {} & {} \\" + "\n"
                     if Nstar_vals[i] is None:
-                        Nstar = 'N/A'
+                        Nstar = "N/A"
                     else:
                         Nstar = Nstar_vals[i]
-                    line = line.format(i, rs[1], '{:1.1f}'.format(Tcoh),
-                                       helper_functions.texify_float(Nstar))
+                    line = line.format(
+                        i,
+                        rs[1],
+                        "{:1.1f}".format(Tcoh),
+                        helper_functions.texify_float(Nstar),
+                    )
                     f.write(line)
-                f.write(r'\end{tabular}' + '\n')
+                f.write(r"\end{tabular}" + "\n")
 
         if args.setup_only:
             logging.info("Exit as requested by setup_only flag")
@@ -2342,9 +2730,21 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         else:
             return run_setup
 
-    def run(self, run_setup=None, proposal_scale_factor=2, NstarMax=10,
-            Nsegs0=None, create_plots=True, log_table=True, gen_tex_table=True,
-            fig=None, axes=None, return_fig=False, window=50, **kwargs):
+    def run(
+        self,
+        run_setup=None,
+        proposal_scale_factor=2,
+        NstarMax=10,
+        Nsegs0=None,
+        create_plots=True,
+        log_table=True,
+        gen_tex_table=True,
+        fig=None,
+        axes=None,
+        return_fig=False,
+        window=50,
+        **kwargs
+    ):
         """ Run the follow-up with the given run_setup
 
         Parameters
@@ -2370,21 +2770,24 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         self._set_likelihoodcoef()
         self._initiate_search_object()
         run_setup = self.init_run_setup(
-            run_setup, NstarMax=NstarMax, Nsegs0=Nsegs0, log_table=log_table,
-            gen_tex_table=gen_tex_table)
+            run_setup,
+            NstarMax=NstarMax,
+            Nsegs0=Nsegs0,
+            log_table=log_table,
+            gen_tex_table=gen_tex_table,
+        )
         self.run_setup = run_setup
         self._estimate_run_time()
 
         self.old_data_is_okay_to_use = self._check_old_data_is_okay_to_use()
         if self.old_data_is_okay_to_use is True:
-            logging.warning('Using saved data from {}'.format(
-                self.pickle_path))
+            logging.warning("Using saved data from {}".format(self.pickle_path))
             d = self.get_saved_data_dictionary()
-            self.samples = d['samples']
-            self.lnprobs = d['lnprobs']
-            self.lnlikes = d['lnlikes']
-            self.all_lnlikelihood = d['all_lnlikelihood']
-            self.chain = d['chain']
+            self.samples = d["samples"]
+            self.lnprobs = d["lnprobs"]
+            self.lnlikes = d["lnlikes"]
+            self.all_lnlikelihood = d["all_lnlikelihood"]
+            self.chain = d["chain"]
             self.nsegs = run_setup[-1][1]
             return
 
@@ -2406,39 +2809,56 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
             self.update_search_object()
             self.search.init_semicoherent_parameters()
             sampler = PTSampler(
-                ntemps=self.ntemps, nwalkers=self.nwalkers, dim=self.ndim,
-                logl=self.logl, logp=self.logp,
+                ntemps=self.ntemps,
+                nwalkers=self.nwalkers,
+                dim=self.ndim,
+                logl=self.logl,
+                logp=self.logp,
                 logpargs=(self.theta_prior, self.theta_keys, self.search),
-                loglargs=(self.search,), betas=self.betas,
-                a=proposal_scale_factor)
-
-            Tcoh = (self.maxStartTime-self.minStartTime)/nseg/86400.
-            logging.info(('Running {}/{} with {} steps and {} nsegs '
-                          '(Tcoh={:1.2f} days)').format(
-                j+1, len(run_setup), (nburn, nprod), nseg, Tcoh))
-            sampler = self._run_sampler(sampler, p0, nburn=nburn, nprod=nprod,
-                                        window=window)
-            logging.info('Max detection statistic of run was {}'.format(
-                np.max(sampler.loglikelihood)))
+                loglargs=(self.search,),
+                betas=self.betas,
+                a=proposal_scale_factor,
+            )
+
+            Tcoh = (self.maxStartTime - self.minStartTime) / nseg / 86400.0
+            logging.info(
+                (
+                    "Running {}/{} with {} steps and {} nsegs " "(Tcoh={:1.2f} days)"
+                ).format(j + 1, len(run_setup), (nburn, nprod), nseg, Tcoh)
+            )
+            sampler = self._run_sampler(
+                sampler, p0, nburn=nburn, nprod=nprod, window=window
+            )
+            logging.info(
+                "Max detection statistic of run was {}".format(
+                    np.max(sampler.loglikelihood)
+                )
+            )
 
             if create_plots:
                 fig, axes = self._plot_walkers(
-                    sampler, fig=fig, axes=axes,
-                    nprod=nprod, xoffset=nsteps_total, **kwargs)
-                for ax in axes[:self.ndim]:
-                    ax.axvline(nsteps_total, color='k', ls='--', lw=0.25)
-
-            nsteps_total += nburn+nprod
+                    sampler,
+                    fig=fig,
+                    axes=axes,
+                    nprod=nprod,
+                    xoffset=nsteps_total,
+                    **kwargs
+                )
+                for ax in axes[: self.ndim]:
+                    ax.axvline(nsteps_total, color="k", ls="--", lw=0.25)
+
+            nsteps_total += nburn + nprod
 
         if create_plots:
             nstep_list = np.array(
-                [el[0][0] for el in run_setup] + [run_setup[-1][0][1]])
-            mids = np.cumsum(nstep_list) - nstep_list/2
-            mid_labels = ['{:1.0f}'.format(i) for i in np.arange(0, len(mids)-1)]
-            mid_labels += ['Production']
-            for ax in axes[:self.ndim]:
+                [el[0][0] for el in run_setup] + [run_setup[-1][0][1]]
+            )
+            mids = np.cumsum(nstep_list) - nstep_list / 2
+            mid_labels = ["{:1.0f}".format(i) for i in np.arange(0, len(mids) - 1)]
+            mid_labels += ["Production"]
+            for ax in axes[: self.ndim]:
                 axy = ax.twiny()
-                axy.tick_params(pad=-10, direction='in', axis='x', which='major')
+                axy.tick_params(pad=-10, direction="in", axis="x", which="major")
                 axy.minorticks_off()
                 axy.set_xlim(ax.get_xlim())
                 axy.set_xticks(mids)
@@ -2452,19 +2872,19 @@ class MCMCFollowUpSearch(MCMCSemiCoherentSearch):
         self.lnprobs = lnprobs
         self.lnlikes = lnlikes
         self.all_lnlikelihood = all_lnlikelihood
-        self._save_data(sampler, samples, lnprobs, lnlikes, all_lnlikelihood,
-                        sampler.chain)
+        self._save_data(
+            sampler, samples, lnprobs, lnlikes, all_lnlikelihood, sampler.chain
+        )
 
         if create_plots:
             try:
                 fig.tight_layout()
             except (ValueError, RuntimeError) as e:
-                logging.warning('Tight layout encountered {}'.format(e))
+                logging.warning("Tight layout encountered {}".format(e))
             if return_fig:
                 return fig, axes
             else:
-                fig.savefig('{}/{}_walkers.png'.format(
-                    self.outdir, self.label))
+                fig.savefig("{}/{}_walkers.png".format(self.outdir, self.label))
 
 
 class MCMCTransientSearch(MCMCSearch):
@@ -2476,37 +2896,56 @@ class MCMCTransientSearch(MCMCSearch):
     """
 
     symbol_dictionary = dict(
-        F0='$f$', F1='$\dot{f}$', F2='$\ddot{f}$',
-        Alpha=r'$\alpha$', Delta='$\delta$',
-        transient_tstart='$t_\mathrm{start}$', transient_duration='$\Delta T$')
+        F0="$f$",
+        F1="$\dot{f}$",
+        F2="$\ddot{f}$",
+        Alpha=r"$\alpha$",
+        Delta="$\delta$",
+        transient_tstart="$t_\mathrm{start}$",
+        transient_duration="$\Delta T$",
+    )
     unit_dictionary = dict(
-        F0='Hz', F1='Hz/s', F2='Hz/s$^2$', Alpha=r'rad', Delta='rad',
-        transient_tstart='s', transient_duration='s')
+        F0="Hz",
+        F1="Hz/s",
+        F2="Hz/s$^2$",
+        Alpha=r"rad",
+        Delta="rad",
+        transient_tstart="s",
+        transient_duration="s",
+    )
 
     transform_dictionary = dict(
-        transient_duration={'multiplier': 1/86400.,
-                            'unit': 'day',
-                            'symbol': 'Transient duration'},
+        transient_duration={
+            "multiplier": 1 / 86400.0,
+            "unit": "day",
+            "symbol": "Transient duration",
+        },
         transient_tstart={
-            'multiplier': 1/86400.,
-            'subtractor': 'minStartTime',
-            'unit': 'day',
-            'label': 'Transient start-time \n days after minStartTime'}
-            )
+            "multiplier": 1 / 86400.0,
+            "subtractor": "minStartTime",
+            "unit": "day",
+            "label": "Transient start-time \n days after minStartTime",
+        },
+    )
 
     def _initiate_search_object(self):
-        logging.info('Setting up search object')
+        logging.info("Setting up search object")
         if not self.transientWindowType:
-            self.transientWindowType = 'rect'
+            self.transientWindowType = "rect"
         self.search = core.ComputeFstat(
-            tref=self.tref, sftfilepattern=self.sftfilepattern,
-            minCoverFreq=self.minCoverFreq, maxCoverFreq=self.maxCoverFreq,
+            tref=self.tref,
+            sftfilepattern=self.sftfilepattern,
+            minCoverFreq=self.minCoverFreq,
+            maxCoverFreq=self.maxCoverFreq,
             detectors=self.detectors,
             transientWindowType=self.transientWindowType,
-            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime,
-            BSGL=self.BSGL, binary=self.binary,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+            BSGL=self.BSGL,
+            binary=self.binary,
             injectSources=self.injectSources,
-            tCWFstatMapVersion=self.tCWFstatMapVersion)
+            tCWFstatMapVersion=self.tCWFstatMapVersion,
+        )
         if self.minStartTime is None:
             self.minStartTime = self.search.minStartTime
         if self.maxStartTime is None:
@@ -2520,23 +2959,33 @@ class MCMCTransientSearch(MCMCSearch):
         if in_theta[1] > self.maxStartTime:
             return -np.inf
         twoF = search.get_fullycoherent_twoF(*in_theta)
-        return twoF/2.0 + self.likelihoodcoef
+        return twoF / 2.0 + self.likelihoodcoef
 
     def _unpack_input_theta(self):
-        full_theta_keys = ['transient_tstart',
-                           'transient_duration', 'F0', 'F1', 'F2', 'Alpha',
-                           'Delta']
+        full_theta_keys = [
+            "transient_tstart",
+            "transient_duration",
+            "F0",
+            "F1",
+            "F2",
+            "Alpha",
+            "Delta",
+        ]
         if self.binary:
-            full_theta_keys += [
-                'asini', 'period', 'ecc', 'tp', 'argp']
+            full_theta_keys += ["asini", "period", "ecc", "tp", "argp"]
         full_theta_keys_copy = copy.copy(full_theta_keys)
 
-        full_theta_symbols = [r'$t_{\rm start}$', r'$\Delta T$',
-                              '$f$', '$\dot{f}$', '$\ddot{f}$',
-                              r'$\alpha$', r'$\delta$']
+        full_theta_symbols = [
+            r"$t_{\rm start}$",
+            r"$\Delta T$",
+            "$f$",
+            "$\dot{f}$",
+            "$\ddot{f}$",
+            r"$\alpha$",
+            r"$\delta$",
+        ]
         if self.binary:
-            full_theta_symbols += [
-                'asini', 'period', 'period', 'ecc', 'tp', 'argp']
+            full_theta_symbols += ["asini", "period", "ecc", "tp", "argp"]
 
         self.theta_keys = []
         fixed_theta_dict = {}
@@ -2548,14 +2997,16 @@ class MCMCTransientSearch(MCMCSearch):
                 fixed_theta_dict[key] = val
             else:
                 raise ValueError(
-                    'Type {} of {} in theta not recognised'.format(
-                        type(val), key))
+                    "Type {} of {} in theta not recognised".format(type(val), key)
+                )
             full_theta_keys_copy.pop(full_theta_keys_copy.index(key))
 
         if len(full_theta_keys_copy) > 0:
-            raise ValueError(('Input dictionary `theta` is missing the'
-                              'following keys: {}').format(
-                                  full_theta_keys_copy))
+            raise ValueError(
+                ("Input dictionary `theta` is missing the" "following keys: {}").format(
+                    full_theta_keys_copy
+                )
+            )
 
         self.fixed_theta = [fixed_theta_dict[key] for key in full_theta_keys]
         self.theta_idxs = [full_theta_keys.index(k) for k in self.theta_keys]
diff --git a/pyfstat/optimal_setup_functions.py b/pyfstat/optimal_setup_functions.py
index 5481bf6aff2570c5581c14697723822f200efff5..43fd89ba0fe669f45206de09f774a23d87ddee59 100644
--- a/pyfstat/optimal_setup_functions.py
+++ b/pyfstat/optimal_setup_functions.py
@@ -14,8 +14,8 @@ import pyfstat.helper_functions as helper_functions
 
 
 def get_optimal_setup(
-        NstarMax, Nsegs0, tref, minStartTime, maxStartTime, prior,
-        detector_names):
+    NstarMax, Nsegs0, tref, minStartTime, maxStartTime, prior, detector_names
+):
     """ Using an optimisation step, calculate the optimal setup ladder
 
     Parameters
@@ -37,14 +37,14 @@ def get_optimal_setup(
 
     """
 
-    logging.info('Calculating optimal setup for NstarMax={}, Nsegs0={}'.format(
-        NstarMax, Nsegs0))
+    logging.info(
+        "Calculating optimal setup for NstarMax={}, Nsegs0={}".format(NstarMax, Nsegs0)
+    )
 
     Nstar_0 = get_Nstar_estimate(
-        Nsegs0, tref, minStartTime, maxStartTime, prior,
-        detector_names)
-    logging.info(
-        'Stage {}, nsegs={}, Nstar={}'.format(0, Nsegs0, int(Nstar_0)))
+        Nsegs0, tref, minStartTime, maxStartTime, prior, detector_names
+    )
+    logging.info("Stage {}, nsegs={}, Nstar={}".format(0, Nsegs0, int(Nstar_0)))
 
     nsegs_vals = [Nsegs0]
     Nstar_vals = [Nstar_0]
@@ -53,25 +53,27 @@ def get_optimal_setup(
     nsegs_i = Nsegs0
     while nsegs_i > 1:
         nsegs_i, Nstar_i = _get_nsegs_ip1(
-            nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior,
-            detector_names)
+            nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior, detector_names
+        )
         nsegs_vals.append(nsegs_i)
         Nstar_vals.append(Nstar_i)
         i += 1
-        logging.info(
-            'Stage {}, nsegs={}, Nstar={}'.format(i, nsegs_i, int(Nstar_i)))
+        logging.info("Stage {}, nsegs={}, Nstar={}".format(i, nsegs_i, int(Nstar_i)))
 
     return nsegs_vals, Nstar_vals
 
 
-def _get_nsegs_ip1(nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior,
-                   detector_names):
+def _get_nsegs_ip1(
+    nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior, detector_names
+):
     """ Calculate Nsegs_{i+1} given Nsegs_{i} """
 
     log10NstarMax = np.log10(NstarMax)
-    log10Nstari = np.log10(get_Nstar_estimate(
-        nsegs_i, tref, minStartTime, maxStartTime, prior,
-        detector_names))
+    log10Nstari = np.log10(
+        get_Nstar_estimate(
+            nsegs_i, tref, minStartTime, maxStartTime, prior, detector_names
+        )
+    )
 
     def f(nsegs_ip1):
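+        # objective: distance of log10(Nstar_{i+1}) from the target
+        # log10(Nstar_i) + log10(NstarMax)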
         if nsegs_ip1[0] > nsegs_i:
@@ -82,24 +84,30 @@ def _get_nsegs_ip1(nsegs_i, NstarMax, tref, minStartTime, maxStartTime, prior,
         if nsegs_ip1 == 0:
             nsegs_ip1 = 1
         Nstarip1 = get_Nstar_estimate(
-            nsegs_ip1, tref, minStartTime, maxStartTime, prior, detector_names)
+            nsegs_ip1, tref, minStartTime, maxStartTime, prior, detector_names
+        )
         if Nstarip1 is None:
             return 1e6
         else:
             log10Nstarip1 = np.log10(Nstarip1)
             return np.abs(log10Nstari + log10NstarMax - log10Nstarip1)
-    res = scipy.optimize.minimize(f, .4*nsegs_i, method='Powell', tol=1,
-                                  options={'maxiter': 10})
-    logging.info('{} with {} evaluations'.format(res['message'], res['nfev']))
+
+    res = scipy.optimize.minimize(
+        f, 0.4 * nsegs_i, method="Powell", tol=1, options={"maxiter": 10}
+    )
+    logging.info("{} with {} evaluations".format(res["message"], res["nfev"]))
     nsegs_ip1 = int(res.x)
     if nsegs_ip1 == 0:
         nsegs_ip1 = 1
     if res.success:
-        return nsegs_ip1, get_Nstar_estimate(
-            nsegs_ip1, tref, minStartTime, maxStartTime, prior,
-            detector_names)
+        return (
+            nsegs_ip1,
+            get_Nstar_estimate(
+                nsegs_ip1, tref, minStartTime, maxStartTime, prior, detector_names
+            ),
+        )
     else:
-        raise ValueError('Optimisation unsuccesful')
+        raise ValueError("Optimisation unsuccesful")
 
 
 def _extract_data_from_prior(prior):
@@ -121,7 +129,7 @@ def _extract_data_from_prior(prior):
         Fiducial frequency
 
     """
-    keys = ['Alpha', 'Delta', 'F0', 'F1', 'F2']
+    keys = ["Alpha", "Delta", "F0", "F1", "F2"]
     spindown_keys = keys[3:]
     sky_keys = keys[:2]
     lims = []
@@ -129,14 +137,14 @@ def _extract_data_from_prior(prior):
     lims_idxs = []
     for i, key in enumerate(keys):
         if type(prior[key]) == dict:
-            if prior[key]['type'] == 'unif':
-                lims.append([prior[key]['lower'], prior[key]['upper']])
+            if prior[key]["type"] == "unif":
+                lims.append([prior[key]["lower"], prior[key]["upper"]])
                 lims_keys.append(key)
                 lims_idxs.append(i)
             else:
                 raise ValueError(
-                    "Prior type {} not yet supported".format(
-                        prior[key]['type']))
+                    "Prior type {} not yet supported".format(prior[key]["type"])
+                )
         elif key not in spindown_keys:
             lims.append([prior[key], 0])
     lims = np.array(lims)
@@ -149,16 +157,15 @@ def _extract_data_from_prior(prior):
         p.append(basex)
     spindowns = np.sum([np.sum(lims_keys == k) for k in spindown_keys])
     sky = any([key in lims_keys for key in sky_keys])
-    if type(prior['F0']) == dict:
-        fiducial_freq = prior['F0']['upper']
+    if type(prior["F0"]) == dict:
+        fiducial_freq = prior["F0"]["upper"]
     else:
-        fiducial_freq = prior['F0']
+        fiducial_freq = prior["F0"]
 
     return np.array(p).T, spindowns, sky, fiducial_freq
 
 
-def get_Nstar_estimate(
-        nsegs, tref, minStartTime, maxStartTime, prior, detector_names):
+def get_Nstar_estimate(nsegs, tref, minStartTime, maxStartTime, prior, detector_names):
     """ Returns N* estimated from the super-sky metric
 
     Parameters
@@ -189,29 +196,35 @@ def get_Nstar_estimate(
     in_phys = helper_functions.convert_array_to_gsl_matrix(in_phys)
     out_rssky = helper_functions.convert_array_to_gsl_matrix(out_rssky)
 
-    tboundaries = np.linspace(minStartTime, maxStartTime, nsegs+1)
+    tboundaries = np.linspace(minStartTime, maxStartTime, nsegs + 1)
 
     ref_time = lal.LIGOTimeGPS(tref)
     segments = lal.SegListCreate()
-    for j in range(len(tboundaries)-1):
-        seg = lal.SegCreate(lal.LIGOTimeGPS(tboundaries[j]),
-                            lal.LIGOTimeGPS(tboundaries[j+1]),
-                            j)
+    for j in range(len(tboundaries) - 1):
+        seg = lal.SegCreate(
+            lal.LIGOTimeGPS(tboundaries[j]), lal.LIGOTimeGPS(tboundaries[j + 1]), j
+        )
         lal.SegListAppend(segments, seg)
     detNames = lal.CreateStringVector(*detector_names)
     detectors = lalpulsar.MultiLALDetector()
     lalpulsar.ParseMultiLALDetector(detectors, detNames)
     detector_weights = None
-    detector_motion = (lalpulsar.DETMOTION_SPIN
-                       + lalpulsar.DETMOTION_ORBIT)
+    detector_motion = lalpulsar.DETMOTION_SPIN + lalpulsar.DETMOTION_ORBIT
     ephemeris = lalpulsar.InitBarycenter(earth_ephem, sun_ephem)
     try:
         SSkyMetric = lalpulsar.ComputeSuperskyMetrics(
-            lalpulsar.SUPERSKY_METRIC_TYPE, spindowns, ref_time, segments,
-            fiducial_freq, detectors, detector_weights, detector_motion,
-            ephemeris)
+            lalpulsar.SUPERSKY_METRIC_TYPE,
+            spindowns,
+            ref_time,
+            segments,
+            fiducial_freq,
+            detectors,
+            detector_weights,
+            detector_motion,
+            ephemeris,
+        )
     except RuntimeError as e:
-        logging.warning('Encountered run-time error {}'.format(e))
+        logging.warning("Encountered run-time error {}".format(e))
         raise RuntimeError("Calculation of the SSkyMetric failed")
 
     if sky:
@@ -220,7 +233,8 @@ def get_Nstar_estimate(
         i = 2
 
     lalpulsar.ConvertPhysicalToSuperskyPoints(
-        out_rssky, in_phys, SSkyMetric.semi_rssky_transf)
+        out_rssky, in_phys, SSkyMetric.semi_rssky_transf
+    )
 
     d = out_rssky.data
 
@@ -230,10 +244,13 @@ def get_Nstar_estimate(
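+    # edge vectors of the prior box in super-sky coordinates, measured from
+    # the first mapped corner point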
     parallelepiped = (d[i:, 1:].T - d[i:, 0]).T
 
     Nstars = []
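+    # per-dimension template-count estimate: metric volume sqrt(|det g|)
+    # times the parallelepiped volume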
-    for j in range(1, len(g)+1):
+    for j in range(1, len(g) + 1):
         dV = np.abs(np.linalg.det(parallelepiped[:j, :j]))
         sqrtdetG = np.sqrt(np.abs(np.linalg.det(g[:j, :j])))
         Nstars.append(sqrtdetG * dV)
-    logging.debug('Nstar for each dimension = {}'.format(
-        ', '.join(["{:1.1e}".format(n) for n in Nstars])))
+    logging.debug(
+        "Nstar for each dimension = {}".format(
+            ", ".join(["{:1.1e}".format(n) for n in Nstars])
+        )
+    )
     return np.max(Nstars)
diff --git a/pyfstat/tcw_fstat_map_funcs.py b/pyfstat/tcw_fstat_map_funcs.py
index 2676c1a00e7ff5f038eaf6f4226595fadad85b1b..e3ef1f34c2e6e3a4149cc5a0a66c14403bfad2d3 100644
--- a/pyfstat/tcw_fstat_map_funcs.py
+++ b/pyfstat/tcw_fstat_map_funcs.py
@@ -9,8 +9,8 @@ from time import time
 import importlib as imp
 
 
-def _optional_import ( modulename, shorthand=None ):
-    '''
+def _optional_import(modulename, shorthand=None):
+    """
     Import a module/submodule only if it's available.
 
     using importlib instead of __import__
@@ -18,22 +18,21 @@ def _optional_import ( modulename, shorthand=None ):
 
     Also including a special check to fail more gracefully
     when CUDA_DEVICE is set to too high a number.
-    '''
+    """
 
     if shorthand is None:
-        shorthand    = modulename
-        shorthandbit = ''
+        shorthand = modulename
+        shorthandbit = ""
     else:
-        shorthandbit = ' as '+shorthand
+        shorthandbit = " as " + shorthand
 
     try:
         globals()[shorthand] = imp.import_module(modulename)
-        logging.debug('Successfully imported module %s%s.'
-                      % (modulename, shorthandbit))
+        logging.debug("Successfully imported module %s%s." % (modulename, shorthandbit))
         success = True
     except ImportError as e:
-        if e.message == 'No module named '+modulename:
-            logging.debug('No module {:s} found.'.format(modulename))
+        if e.message == "No module named " + modulename:
+            logging.debug("No module {:s} found.".format(modulename))
             success = False
         else:
             raise
@@ -42,7 +41,7 @@ def _optional_import ( modulename, shorthand=None ):
 
 
 class pyTransientFstatMap(object):
-    '''
+    """
     simplified object class for an F(t0,tau) F-stat map (not 2F!)
     based on LALSuite's transientFstatMap_t type
     replacing the gsl matrix with a numpy array
@@ -51,119 +50,142 @@ class pyTransientFstatMap(object):
     maxF:   maximum of F (not 2F!)
     t0_ML:  maximum likelihood transient start time t0 estimate
     tau_ML: maximum likelihood transient duration tau estimate
-    '''
+    """
 
     def __init__(self, N_t0Range, N_tauRange):
-        self.F_mn   = np.zeros((N_t0Range, N_tauRange), dtype=np.float32)
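+        # map of F values (not 2F) over the (t0, tau) grid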
+        self.F_mn = np.zeros((N_t0Range, N_tauRange), dtype=np.float32)
         # Initializing maxF to a negative value ensures
         # that we always update at least once and hence return
         # sane t0_d_ML, tau_d_ML
         # even if there is only a single bin where F=0 happens.
-        self.maxF   = float(-1.0)
-        self.t0_ML  = float(0.0)
+        self.maxF = float(-1.0)
+        self.t0_ML = float(0.0)
         self.tau_ML = float(0.0)
 
 
 # dictionary of the actual callable F-stat map functions we support,
 # if the corresponding modules are available.
 fstatmap_versions = {
-                     'lal':    lambda multiFstatAtoms, windowRange:
-                               getattr(lalpulsar,'ComputeTransientFstatMap')
-                                ( multiFstatAtoms, windowRange, False ),
-                     'pycuda': lambda multiFstatAtoms, windowRange:
-                               pycuda_compute_transient_fstat_map
-                                ( multiFstatAtoms, windowRange )
-                    }
+    "lal": lambda multiFstatAtoms, windowRange: getattr(
+        lalpulsar, "ComputeTransientFstatMap"
+    )(multiFstatAtoms, windowRange, False),
+    "pycuda": lambda multiFstatAtoms, windowRange: pycuda_compute_transient_fstat_map(
+        multiFstatAtoms, windowRange
+    ),
+}
 
 
-def init_transient_fstat_map_features ( wantCuda=False, cudaDeviceName=None ):
-    '''
+def init_transient_fstat_map_features(wantCuda=False, cudaDeviceName=None):
+    """
     Initialization of available modules (or "features") for F-stat maps.
 
     Returns a dictionary of method names, to match fstatmap_versions
     each key's value set to True only if
     all required modules are importable on this system.
-    '''
+    """
 
     features = {}
 
-    have_lal           = _optional_import('lal')
-    have_lalpulsar     = _optional_import('lalpulsar')
-    features['lal']    = have_lal and have_lalpulsar
+    have_lal = _optional_import("lal")
+    have_lalpulsar = _optional_import("lalpulsar")
+    features["lal"] = have_lal and have_lalpulsar
 
     # import GPU features
-    have_pycuda          = _optional_import('pycuda')
-    have_pycuda_drv      = _optional_import('pycuda.driver', 'drv')
-    have_pycuda_gpuarray = _optional_import('pycuda.gpuarray', 'gpuarray')
-    have_pycuda_tools    = _optional_import('pycuda.tools', 'cudatools')
-    have_pycuda_compiler = _optional_import('pycuda.compiler', 'cudacomp')
-    features['pycuda']   = ( have_pycuda_drv and have_pycuda_gpuarray and
-                            have_pycuda_tools and have_pycuda_compiler )
-
-    logging.debug('Got the following features for transient F-stat maps:')
+    have_pycuda = _optional_import("pycuda")
+    have_pycuda_drv = _optional_import("pycuda.driver", "drv")
+    have_pycuda_gpuarray = _optional_import("pycuda.gpuarray", "gpuarray")
+    have_pycuda_tools = _optional_import("pycuda.tools", "cudatools")
+    have_pycuda_compiler = _optional_import("pycuda.compiler", "cudacomp")
+    features["pycuda"] = (
+        have_pycuda_drv
+        and have_pycuda_gpuarray
+        and have_pycuda_tools
+        and have_pycuda_compiler
+    )
+
+    logging.debug("Got the following features for transient F-stat maps:")
     logging.debug(features)
 
-    if wantCuda and features['pycuda']:
-        logging.debug('CUDA version: '+'.'.join(map(str,drv.get_version())))
+    if wantCuda and features["pycuda"]:
+        logging.debug("CUDA version: " + ".".join(map(str, drv.get_version())))
 
         drv.init()
-        logging.debug('Starting with default pyCUDA context,' \
-                      ' then checking all available devices...')
+        logging.debug(
+            "Starting with default pyCUDA context,"
+            " then checking all available devices..."
+        )
         try:
             context0 = pycuda.tools.make_default_context()
         except pycuda._driver.LogicError as e:
-            if e.message == 'cuDeviceGet failed: invalid device ordinal':
-                devn = int(os.environ['CUDA_DEVICE'])
-                raise RuntimeError('Requested CUDA device number {} exceeds' \
-                                   ' number of available devices!' \
-                                   ' Please change through environment' \
-                                   ' variable $CUDA_DEVICE.'.format(devn))
+            if e.message == "cuDeviceGet failed: invalid device ordinal":
+                devn = int(os.environ["CUDA_DEVICE"])
+                raise RuntimeError(
+                    "Requested CUDA device number {} exceeds"
+                    " number of available devices!"
+                    " Please change through environment"
+                    " variable $CUDA_DEVICE.".format(devn)
+                )
             else:
                 raise
 
         num_gpus = drv.Device.count()
-        logging.debug('Found {} CUDA device(s).'.format(num_gpus))
+        logging.debug("Found {} CUDA device(s).".format(num_gpus))
 
         devices = []
-        devnames = np.empty(num_gpus,dtype='S32')
+        devnames = np.empty(num_gpus, dtype="S32")
         for n in range(num_gpus):
             devn = drv.Device(n)
             devices.append(devn)
-            devnames[n] = devn.name().replace(' ','-').replace('_','-')
-            logging.debug('device {}: model: {}, RAM: {}MB'.format(
-                n, devnames[n], devn.total_memory()/(2.**20) ))
-
-        if 'CUDA_DEVICE' in os.environ:
-            devnum0 = int(os.environ['CUDA_DEVICE'])
+            devnames[n] = devn.name().replace(" ", "-").replace("_", "-")
+            logging.debug(
+                "device {}: model: {}, RAM: {}MB".format(
+                    n, devnames[n], devn.total_memory() / (2.0 ** 20)
+                )
+            )
+
+        if "CUDA_DEVICE" in os.environ:
+            devnum0 = int(os.environ["CUDA_DEVICE"])
         else:
             devnum0 = 0
 
-        matchbit = ''
+        matchbit = ""
         if cudaDeviceName:
             # allow partial matches in device names
-            devmatches = [devidx for devidx, devname in enumerate(devnames)
-                          if cudaDeviceName in devname]
+            devmatches = [
+                devidx
+                for devidx, devname in enumerate(devnames)
+                if cudaDeviceName in devname
+            ]
             if len(devmatches) == 0:
                 context0.detach()
-                raise RuntimeError('Requested CUDA device "{}" not found.' \
-                                   ' Available devices: [{}]'.format(
-                                      cudaDeviceName,','.join(devnames)))
+                raise RuntimeError(
+                    'Requested CUDA device "{}" not found.'
+                    " Available devices: [{}]".format(
+                        cudaDeviceName, ",".join(devnames)
+                    )
+                )
             else:
                 devnum = devmatches[0]
                 if len(devmatches) > 1:
-                    logging.warning('Found {} CUDA devices matching name "{}".' \
-                                    ' Choosing first one with index {}.'.format(
-                                        len(devmatches),cudaDeviceName,devnum))
-            os.environ['CUDA_DEVICE'] = str(devnum)
-            matchbit =  '(matched to user request "{}")'.format(cudaDeviceName)
-        elif 'CUDA_DEVICE' in os.environ:
-            devnum = int(os.environ['CUDA_DEVICE'])
+                    logging.warning(
+                        'Found {} CUDA devices matching name "{}".'
+                        " Choosing first one with index {}.".format(
+                            len(devmatches), cudaDeviceName, devnum
+                        )
+                    )
+            os.environ["CUDA_DEVICE"] = str(devnum)
+            matchbit = '(matched to user request "{}")'.format(cudaDeviceName)
+        elif "CUDA_DEVICE" in os.environ:
+            devnum = int(os.environ["CUDA_DEVICE"])
         else:
             devnum = 0
         devn = devices[devnum]
-        logging.info('Choosing CUDA device {},' \
-                     ' of {} devices present: {}{}...'.format(
-                         devnum, num_gpus, devn.name(), matchbit))
+        logging.info(
+            "Choosing CUDA device {},"
+            " of {} devices present: {}{}...".format(
+                devnum, num_gpus, devn.name(), matchbit
+            )
+        )
         if devnum == devnum0:
             gpu_context = context0
         else:
@@ -171,79 +193,87 @@ def init_transient_fstat_map_features ( wantCuda=False, cudaDeviceName=None ):
             gpu_context = pycuda.tools.make_default_context()
             gpu_context.push()
 
-        _print_GPU_memory_MB('Available')
+        _print_GPU_memory_MB("Available")
     else:
         gpu_context = None
 
     return features, gpu_context
 
 
-def call_compute_transient_fstat_map ( version,
-                                       features,
-                                       multiFstatAtoms=None,
-                                       windowRange=None ):
-    '''Choose which version of the ComputeTransientFstatMap function to call.'''
+def call_compute_transient_fstat_map(
+    version, features, multiFstatAtoms=None, windowRange=None
+):
+    """Choose which version of the ComputeTransientFstatMap function to call."""
 
     if version in fstatmap_versions:
         if features[version]:
             time0 = time()
             FstatMap = fstatmap_versions[version](multiFstatAtoms, windowRange)
-            timingFstatMap = time()-time0
+            timingFstatMap = time() - time0
         else:
-            raise Exception('Required module(s) for transient F-stat map' \
-                            ' method "{}" not available!'.format(version))
+            raise Exception(
+                "Required module(s) for transient F-stat map"
+                ' method "{}" not available!'.format(version)
+            )
     else:
-        raise Exception('Transient F-stat map method "{}"' \
-                        ' not implemented!'.format(version))
+        raise Exception(
+            'Transient F-stat map method "{}" not implemented!'.format(version)
+        )
     return FstatMap, timingFstatMap
 
 
-def reshape_FstatAtomsVector ( atomsVector ):
-    '''
+def reshape_FstatAtomsVector(atomsVector):
+    """
     Make a dictionary of ndarrays out of a atoms "vector" structure.
 
     The input is a "vector"-like structure with times as the higher hierarchical
     level and a set of "atoms" quantities defined at each timestamp.
     The output is a dictionary with an entry for each quantity,
     which is a 1D ndarray over timestamps for that one quantity.
-    '''
+    """
 
     numAtoms = atomsVector.length
     atomsDict = {}
-    atom_fieldnames = ['timestamp', 'Fa_alpha', 'Fb_alpha',
-                       'a2_alpha', 'ab_alpha', 'b2_alpha']
-    atom_dtypes     = [np.uint32, complex, complex,
-                       np.float32, np.float32, np.float32]
+    atom_fieldnames = [
+        "timestamp",
+        "Fa_alpha",
+        "Fb_alpha",
+        "a2_alpha",
+        "ab_alpha",
+        "b2_alpha",
+    ]
+    atom_dtypes = [np.uint32, complex, complex, np.float32, np.float32, np.float32]
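+    # the complex Fa/Fb atoms are split into re/im float32 parts below,
+    # ready for column-stacking into the GPU input matrix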
     for f, field in enumerate(atom_fieldnames):
-        atomsDict[field] = np.ndarray(numAtoms,dtype=atom_dtypes[f])
+        atomsDict[field] = np.ndarray(numAtoms, dtype=atom_dtypes[f])
 
-    for n,atom in enumerate(atomsVector.data):
+    for n, atom in enumerate(atomsVector.data):
         for field in atom_fieldnames:
             atomsDict[field][n] = atom.__getattribute__(field)
 
-    atomsDict['Fa_alpha_re'] = np.float32(atomsDict['Fa_alpha'].real)
-    atomsDict['Fa_alpha_im'] = np.float32(atomsDict['Fa_alpha'].imag)
-    atomsDict['Fb_alpha_re'] = np.float32(atomsDict['Fb_alpha'].real)
-    atomsDict['Fb_alpha_im'] = np.float32(atomsDict['Fb_alpha'].imag)
+    atomsDict["Fa_alpha_re"] = np.float32(atomsDict["Fa_alpha"].real)
+    atomsDict["Fa_alpha_im"] = np.float32(atomsDict["Fa_alpha"].imag)
+    atomsDict["Fb_alpha_re"] = np.float32(atomsDict["Fb_alpha"].real)
+    atomsDict["Fb_alpha_im"] = np.float32(atomsDict["Fb_alpha"].imag)
 
     return atomsDict
 
 
-def _get_absolute_kernel_path ( kernel ):
+def _get_absolute_kernel_path(kernel):
     pyfstatdir = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))
-    kernelfile = kernel + '.cu'
-    return os.path.join(pyfstatdir,'pyCUDAkernels',kernelfile)
+    kernelfile = kernel + ".cu"
+    return os.path.join(pyfstatdir, "pyCUDAkernels", kernelfile)
 
 
-def _print_GPU_memory_MB ( key ):
-    mem_used_MB  = drv.mem_get_info()[0]/(2.**20)
-    mem_total_MB = drv.mem_get_info()[1]/(2.**20)
-    logging.debug('{} GPU memory: {:.4f} / {:.4f} MB free'.format(
-                      key, mem_used_MB, mem_total_MB))
+def _print_GPU_memory_MB(key):
+    mem_free_MB = drv.mem_get_info()[0] / (2.0 ** 20)
+    mem_total_MB = drv.mem_get_info()[1] / (2.0 ** 20)
+    logging.debug(
+        "{} GPU memory: {:.4f} / {:.4f} MB free".format(key, mem_free_MB, mem_total_MB)
+    )
 
 
-def pycuda_compute_transient_fstat_map ( multiFstatAtoms, windowRange ):
-    '''
+def pycuda_compute_transient_fstat_map(multiFstatAtoms, windowRange):
+    """
     GPU version of the function to compute transient-window "F-statistic map"
     over start-time and timescale {t0, tau}.
     Based on XLALComputeTransientFstatMap from LALSuite,
@@ -255,59 +285,67 @@ def pycuda_compute_transient_fstat_map ( multiFstatAtoms, windowRange ):
     in steps of dt0  in [t0,  t0+t0Band],
     and         dtau in [tau, tau+tauBand]
     as defined in windowRange input.
-    '''
+    """
 
-    if ( windowRange.type >= lalpulsar.TRANSIENT_LAST ):
-        raise ValueError ('Unknown window-type ({}) passed as input.' \
-                          ' Allowed are [0,{}].'.format(
-                              windowRange.type, lalpulsar.TRANSIENT_LAST-1))
+    if windowRange.type >= lalpulsar.TRANSIENT_LAST:
+        raise ValueError(
+            "Unknown window-type ({}) passed as input."
+            " Allowed are [0,{}].".format(
+                windowRange.type, lalpulsar.TRANSIENT_LAST - 1
+            )
+        )
 
     # internal dict for search/setup parameters
     tCWparams = {}
 
     # first combine all multi-atoms
     # into a single atoms-vector with *unique* timestamps
-    tCWparams['TAtom'] = multiFstatAtoms.data[0].TAtom
-    TAtomHalf          = int(tCWparams['TAtom']/2) # integer division
-    atoms = lalpulsar.mergeMultiFstatAtomsBinned ( multiFstatAtoms,
-                                                   tCWparams['TAtom'] )
+    tCWparams["TAtom"] = multiFstatAtoms.data[0].TAtom
+    TAtomHalf = int(tCWparams["TAtom"] / 2)  # integer division
+    atoms = lalpulsar.mergeMultiFstatAtomsBinned(multiFstatAtoms, tCWparams["TAtom"])
 
     # make a combined input matrix of all atoms vectors, for transfer to GPU
-    tCWparams['numAtoms'] = atoms.length
+    tCWparams["numAtoms"] = atoms.length
     atomsDict = reshape_FstatAtomsVector(atoms)
-    atomsInputMatrix = np.column_stack ( (atomsDict['a2_alpha'],
-                                          atomsDict['b2_alpha'],
-                                          atomsDict['ab_alpha'],
-                                          atomsDict['Fa_alpha_re'],
-                                          atomsDict['Fa_alpha_im'],
-                                          atomsDict['Fb_alpha_re'],
-                                          atomsDict['Fb_alpha_im'])
-                                       )
+    atomsInputMatrix = np.column_stack(
+        (
+            atomsDict["a2_alpha"],
+            atomsDict["b2_alpha"],
+            atomsDict["ab_alpha"],
+            atomsDict["Fa_alpha_re"],
+            atomsDict["Fa_alpha_im"],
+            atomsDict["Fb_alpha_re"],
+            atomsDict["Fb_alpha_im"],
+        )
+    )
 
     # actual data spans [t0_data, t0_data + tCWparams['numAtoms'] * TAtom]
     # in steps of TAtom
-    tCWparams['t0_data'] = int(atoms.data[0].timestamp)
-    tCWparams['t1_data'] = int(atoms.data[tCWparams['numAtoms']-1].timestamp
-                               + tCWparams['TAtom'])
-
-    logging.debug('Transient F-stat map:' \
-                  ' t0_data={:d}, t1_data={:d}'.format(
-                      tCWparams['t0_data'], tCWparams['t1_data']))
-    logging.debug('Transient F-stat map:' \
-                  ' numAtoms={:d}, TAtom={:d},' \
-                  ' TAtomHalf={:d}'.format(
-                      tCWparams['numAtoms'], tCWparams['TAtom'], TAtomHalf))
+    tCWparams["t0_data"] = int(atoms.data[0].timestamp)
+    tCWparams["t1_data"] = int(
+        atoms.data[tCWparams["numAtoms"] - 1].timestamp + tCWparams["TAtom"]
+    )
+
+    logging.debug(
+        "Transient F-stat map:"
+        " t0_data={:d}, t1_data={:d}".format(tCWparams["t0_data"], tCWparams["t1_data"])
+    )
+    logging.debug(
+        "Transient F-stat map:"
+        " numAtoms={:d}, TAtom={:d},"
+        " TAtomHalf={:d}".format(tCWparams["numAtoms"], tCWparams["TAtom"], TAtomHalf)
+    )
 
     # special treatment of window_type = none
     # ==> replace by rectangular window spanning all the data
-    if ( windowRange.type == lalpulsar.TRANSIENT_NONE ):
-        windowRange.type    = lalpulsar.TRANSIENT_RECTANGULAR
-        windowRange.t0      = tCWparams['t0_data']
-        windowRange.t0Band  = 0
-        windowRange.dt0     = tCWparams['TAtom'] # irrelevant
-        windowRange.tau     = tCWparams['numAtoms'] * tCWparams['TAtom']
-        windowRange.tauBand = 0;
-        windowRange.dtau    = tCWparams['TAtom'] # irrelevant
+    if windowRange.type == lalpulsar.TRANSIENT_NONE:
+        windowRange.type = lalpulsar.TRANSIENT_RECTANGULAR
+        windowRange.t0 = tCWparams["t0_data"]
+        windowRange.t0Band = 0
+        windowRange.dt0 = tCWparams["TAtom"]  # irrelevant
+        windowRange.tau = tCWparams["numAtoms"] * tCWparams["TAtom"]
+        windowRange.tauBand = 0
+        windowRange.dtau = tCWparams["TAtom"]  # irrelevant
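+        # with t0Band = tauBand = 0 the grid below collapses to a single
+        # (m, n) = (0, 0) point: one rectangular window over all the data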
 
     """ NOTE: indices {i,j} enumerate *actual* atoms and their timestamps t_i,
     * while the indices {m,n} enumerate the full grid of values
@@ -333,147 +371,177 @@ def pycuda_compute_transient_fstat_map ( multiFstatAtoms, windowRange ):
 
     # We allocate a matrix  {m x n} = t0Range * TcohRange elements
     # covering the full transient window-range [t0,t0+t0Band]x[tau,tau+tauBand]
-    tCWparams['N_t0Range']  = int(np.floor( 1.0*windowRange.t0Band /
-                                            windowRange.dt0 ) + 1)
-    tCWparams['N_tauRange'] = int(np.floor( 1.0*windowRange.tauBand /
-                                            windowRange.dtau ) + 1)
-    FstatMap = pyTransientFstatMap ( tCWparams['N_t0Range'],
-                                     tCWparams['N_tauRange'] )
-
-    logging.debug('Transient F-stat map:' \
-                  ' N_t0Range={:d}, N_tauRange={:d},' \
-                  ' total grid points: {:d}'.format(
-                      tCWparams['N_t0Range'], tCWparams['N_tauRange'],
-                      tCWparams['N_t0Range']*tCWparams['N_tauRange']))
-
-    if ( windowRange.type == lalpulsar.TRANSIENT_RECTANGULAR ):
-        FstatMap.F_mn = pycuda_compute_transient_fstat_map_rect (
-                           atomsInputMatrix, windowRange, tCWparams )
-    elif ( windowRange.type == lalpulsar.TRANSIENT_EXPONENTIAL ):
-        FstatMap.F_mn = pycuda_compute_transient_fstat_map_exp (
-                           atomsInputMatrix, windowRange, tCWparams )
+    tCWparams["N_t0Range"] = int(
+        np.floor(1.0 * windowRange.t0Band / windowRange.dt0) + 1
+    )
+    tCWparams["N_tauRange"] = int(
+        np.floor(1.0 * windowRange.tauBand / windowRange.dtau) + 1
+    )
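+    # e.g. a hypothetical t0Band = 86400 s sampled at dt0 = 1800 s gives
+    # floor(48.0) + 1 = 49 start times; the "+1" includes both endpoints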
+    FstatMap = pyTransientFstatMap(tCWparams["N_t0Range"], tCWparams["N_tauRange"])
+
+    logging.debug(
+        "Transient F-stat map:"
+        " N_t0Range={:d}, N_tauRange={:d},"
+        " total grid points: {:d}".format(
+            tCWparams["N_t0Range"],
+            tCWparams["N_tauRange"],
+            tCWparams["N_t0Range"] * tCWparams["N_tauRange"],
+        )
+    )
+
+    if windowRange.type == lalpulsar.TRANSIENT_RECTANGULAR:
+        FstatMap.F_mn = pycuda_compute_transient_fstat_map_rect(
+            atomsInputMatrix, windowRange, tCWparams
+        )
+    elif windowRange.type == lalpulsar.TRANSIENT_EXPONENTIAL:
+        FstatMap.F_mn = pycuda_compute_transient_fstat_map_exp(
+            atomsInputMatrix, windowRange, tCWparams
+        )
     else:
-        raise ValueError('Invalid transient window type {}' \
-                         ' not in [{}, {}].'.format(
-                            windowRange.type, lalpulsar.TRANSIENT_NONE,
-                            lalpulsar.TRANSIENT_LAST-1))
+        raise ValueError(
+            "Invalid transient window type {}"
+            " not in [{}, {}].".format(
+                windowRange.type, lalpulsar.TRANSIENT_NONE, lalpulsar.TRANSIENT_LAST - 1
+            )
+        )
 
     # out of loop: get max2F and ML estimates over the m x n matrix
     FstatMap.maxF = FstatMap.F_mn.max()
-    maxidx = np.unravel_index ( FstatMap.F_mn.argmax(),
-                               (tCWparams['N_t0Range'],
-                                tCWparams['N_tauRange']))
-    FstatMap.t0_ML  = windowRange.t0  + maxidx[0] * windowRange.dt0
+    maxidx = np.unravel_index(
+        FstatMap.F_mn.argmax(), (tCWparams["N_t0Range"], tCWparams["N_tauRange"])
+    )
+    FstatMap.t0_ML = windowRange.t0 + maxidx[0] * windowRange.dt0
     FstatMap.tau_ML = windowRange.tau + maxidx[1] * windowRange.dtau
 
-    logging.debug('Done computing transient F-stat map.' \
-                  ' maxF={:.4f}, t0_ML={}, tau_ML={}'.format(
-                      FstatMap.maxF , FstatMap.t0_ML, FstatMap.tau_ML))
+    logging.debug(
+        "Done computing transient F-stat map."
+        " maxF={:.4f}, t0_ML={}, tau_ML={}".format(
+            FstatMap.maxF, FstatMap.t0_ML, FstatMap.tau_ML
+        )
+    )
 
     return FstatMap
 
 
-def pycuda_compute_transient_fstat_map_rect ( atomsInputMatrix,
-                                              windowRange,
-                                              tCWparams ):
-    '''
+def pycuda_compute_transient_fstat_map_rect(atomsInputMatrix, windowRange, tCWparams):
+    """
-    only GPU-parallizing outer loop,
-    keeping partial sums with memory in kernel
+    GPU-parallelizing only the outer loop,
+    keeping partial sums in per-thread memory inside the kernel
+    """
 
     # gpu data setup and transfer
-    _print_GPU_memory_MB('Initial')
-    input_gpu = gpuarray.to_gpu ( atomsInputMatrix )
-    Fmn_gpu   = gpuarray.GPUArray ( (tCWparams['N_t0Range'],
-                                     tCWparams['N_tauRange']),
-                                    dtype=np.float32 )
-    _print_GPU_memory_MB('After input+output allocation:')
+    _print_GPU_memory_MB("Initial")
+    input_gpu = gpuarray.to_gpu(atomsInputMatrix)
+    Fmn_gpu = gpuarray.GPUArray(
+        (tCWparams["N_t0Range"], tCWparams["N_tauRange"]), dtype=np.float32
+    )
+    _print_GPU_memory_MB("After input+output allocation:")
 
     # GPU kernel
-    kernel = 'cudaTransientFstatRectWindow'
-    kernelfile = _get_absolute_kernel_path ( kernel )
-    partial_Fstat_cuda_code = cudacomp.SourceModule(open(kernelfile,'r').read())
+    kernel = "cudaTransientFstatRectWindow"
+    kernelfile = _get_absolute_kernel_path(kernel)
+    with open(kernelfile, "r") as kernelsource:
+        partial_Fstat_cuda_code = cudacomp.SourceModule(kernelsource.read())
     partial_Fstat_cuda = partial_Fstat_cuda_code.get_function(kernel)
-    partial_Fstat_cuda.prepare('PIIIIIIIIP')
+    partial_Fstat_cuda.prepare("PIIIIIIIIP")
 
     # GPU grid setup
-    blockRows = min(1024,tCWparams['N_t0Range'])
+    blockRows = min(1024, tCWparams["N_t0Range"])
     blockCols = 1
-    gridRows  = int(np.ceil(1.0*tCWparams['N_t0Range']/blockRows))
-    gridCols  = 1
+    gridRows = int(np.ceil(1.0 * tCWparams["N_t0Range"] / blockRows))
+    gridCols = 1
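+    # e.g. a hypothetical N_t0Range = 2500 yields blockRows = 1024 and
+    # gridRows = ceil(2500 / 1024) = 3; surplus threads in the last block are
+    # expected to be masked out inside the kernel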
 
     # running the kernel
-    logging.debug('Calling pyCUDA kernel with a grid of {}*{}={} blocks' \
-                  ' of {}*{}={} threads each: {} total threads...'.format(
-                      gridRows, gridCols, gridRows*gridCols,
-                      blockRows, blockCols, blockRows*blockCols,
-                      gridRows*gridCols*blockRows*blockCols))
-    partial_Fstat_cuda.prepared_call ( (gridRows,gridCols),
-                                       (blockRows,blockCols,1),
-                                       input_gpu.gpudata,
-                                       tCWparams['numAtoms'],
-                                       tCWparams['TAtom'],
-                                       tCWparams['t0_data'],
-                                       windowRange.t0, windowRange.dt0,
-                                       windowRange.tau, windowRange.dtau,
-                                       tCWparams['N_tauRange'],
-                                       Fmn_gpu.gpudata )
+    logging.debug(
+        "Calling pyCUDA kernel with a grid of {}*{}={} blocks"
+        " of {}*{}={} threads each: {} total threads...".format(
+            gridRows,
+            gridCols,
+            gridRows * gridCols,
+            blockRows,
+            blockCols,
+            blockRows * blockCols,
+            gridRows * gridCols * blockRows * blockCols,
+        )
+    )
+    partial_Fstat_cuda.prepared_call(
+        (gridRows, gridCols),
+        (blockRows, blockCols, 1),
+        input_gpu.gpudata,
+        tCWparams["numAtoms"],
+        tCWparams["TAtom"],
+        tCWparams["t0_data"],
+        windowRange.t0,
+        windowRange.dt0,
+        windowRange.tau,
+        windowRange.dtau,
+        tCWparams["N_tauRange"],
+        Fmn_gpu.gpudata,
+    )
 
     # return results to host
     F_mn = Fmn_gpu.get()
 
-    _print_GPU_memory_MB('Final')
+    _print_GPU_memory_MB("Final")
 
     return F_mn
 
 
-def pycuda_compute_transient_fstat_map_exp ( atomsInputMatrix,
-                                             windowRange,
-                                             tCWparams ):
-    '''exponential window, inner and outer loop GPU-parallelized'''
+def pycuda_compute_transient_fstat_map_exp(atomsInputMatrix, windowRange, tCWparams):
+    """exponential window, inner and outer loop GPU-parallelized"""
 
     # gpu data setup and transfer
-    _print_GPU_memory_MB('Initial')
-    input_gpu = gpuarray.to_gpu ( atomsInputMatrix )
-    Fmn_gpu   = gpuarray.GPUArray ( (tCWparams['N_t0Range'],
-                                     tCWparams['N_tauRange']),
-                                    dtype=np.float32 )
-    _print_GPU_memory_MB('After input+output allocation:')
+    _print_GPU_memory_MB("Initial")
+    input_gpu = gpuarray.to_gpu(atomsInputMatrix)
+    Fmn_gpu = gpuarray.GPUArray(
+        (tCWparams["N_t0Range"], tCWparams["N_tauRange"]), dtype=np.float32
+    )
+    _print_GPU_memory_MB("After input+output allocation:")
 
     # GPU kernel
-    kernel = 'cudaTransientFstatExpWindow'
-    kernelfile = _get_absolute_kernel_path ( kernel )
-    partial_Fstat_cuda_code = cudacomp.SourceModule(open(kernelfile,'r').read())
+    kernel = "cudaTransientFstatExpWindow"
+    kernelfile = _get_absolute_kernel_path(kernel)
+    with open(kernelfile, "r") as kernelsource:
+        partial_Fstat_cuda_code = cudacomp.SourceModule(kernelsource.read())
     partial_Fstat_cuda = partial_Fstat_cuda_code.get_function(kernel)
-    partial_Fstat_cuda.prepare('PIIIIIIIIIP')
+    partial_Fstat_cuda.prepare("PIIIIIIIIIP")
 
     # GPU grid setup
-    blockRows = min(32,tCWparams['N_t0Range'])
-    blockCols = min(32,tCWparams['N_tauRange'])
-    gridRows  = int(np.ceil(1.0*tCWparams['N_t0Range']/blockRows))
-    gridCols  = int(np.ceil(1.0*tCWparams['N_tauRange']/blockCols))
+    blockRows = min(32, tCWparams["N_t0Range"])
+    blockCols = min(32, tCWparams["N_tauRange"])
+    gridRows = int(np.ceil(1.0 * tCWparams["N_t0Range"] / blockRows))
+    gridCols = int(np.ceil(1.0 * tCWparams["N_tauRange"] / blockCols))
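+    # e.g. hypothetical N_t0Range = 100, N_tauRange = 50 give 32x32-thread
+    # blocks on a ceil(100/32) x ceil(50/32) = 4 x 2 grid of 8 blocks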
 
     # running the kernel
-    logging.debug('Calling kernel with a grid of {}*{}={} blocks' \
-                  ' of {}*{}={} threads each: {} total threads...'.format(
-                      gridRows, gridCols, gridRows*gridCols,
-                      blockRows, blockCols, blockRows*blockCols,
-                      gridRows*gridCols*blockRows*blockCols))
-    partial_Fstat_cuda.prepared_call ( (gridRows,gridCols),
-                                       (blockRows,blockCols,1),
-                                       input_gpu.gpudata,
-                                       tCWparams['numAtoms'],
-                                       tCWparams['TAtom'],
-                                       tCWparams['t0_data'],
-                                       windowRange.t0, windowRange.dt0,
-                                       windowRange.tau, windowRange.dtau,
-                                       tCWparams['N_t0Range'],
-                                       tCWparams['N_tauRange'],
-                                       Fmn_gpu.gpudata )
+    logging.debug(
+        "Calling kernel with a grid of {}*{}={} blocks"
+        " of {}*{}={} threads each: {} total threads...".format(
+            gridRows,
+            gridCols,
+            gridRows * gridCols,
+            blockRows,
+            blockCols,
+            blockRows * blockCols,
+            gridRows * gridCols * blockRows * blockCols,
+        )
+    )
+    partial_Fstat_cuda.prepared_call(
+        (gridRows, gridCols),
+        (blockRows, blockCols, 1),
+        input_gpu.gpudata,
+        tCWparams["numAtoms"],
+        tCWparams["TAtom"],
+        tCWparams["t0_data"],
+        windowRange.t0,
+        windowRange.dt0,
+        windowRange.tau,
+        windowRange.dtau,
+        tCWparams["N_t0Range"],
+        tCWparams["N_tauRange"],
+        Fmn_gpu.gpudata,
+    )
 
     # return results to host
     F_mn = Fmn_gpu.get()
 
-    _print_GPU_memory_MB('Final')
+    _print_GPU_memory_MB("Final")
 
     return F_mn
diff --git a/requirements.txt b/requirements.txt
index eee946d288fb372c955809cc3507ab6a33dbc667..a840f31b2b5d3ed6d1c4914138b9acb7d7f1fc85 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ tqdm
 bashplotlib
 peakutils
 pathos
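+# NOTE: pycuda needs a working NVIDIA CUDA toolkit available at build time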
+pycuda
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..bb7f7ce70e4ed5ed556eeb153b8a9119c9304dde
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,8 @@
+[metadata]
+license_file = LICENSE.txt
+
+[flake8]
+exclude = .git,docs,build,dist,tests,*__init__.py
+max-line-length = 80
+select = C,E,F,W,B,B950
+ignore = E501,W503,E203
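+# E203 and W503 conflict with black's formatting; E501 is replaced by
+# bugbear's B950, which allows max-line-length plus 10% (i.e. 88 chars)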
diff --git a/setup.py b/setup.py
index d9557b966801542f44a36fd45a2f27d1fbb07ed1..7e8c8637bdaae3bb74b9eafae0db449007d51ec3 100644
--- a/setup.py
+++ b/setup.py
@@ -5,27 +5,32 @@ from os import path
 
 here = path.abspath(path.dirname(__file__))
 # Get the long description from the README file
-with open(path.join(here, 'README.md'), encoding='utf-8') as f:
+with open(path.join(here, "README.md"), encoding="utf-8") as f:
     long_description = f.read()
 
-setup(name='PyFstat',
-      version='0.2',
-      author='Gregory Ashton',
-      author_email='gregory.ashton@ligo.org',
-      packages=find_packages(where="pyfstat"),
-      include_package_data=True,
-      package_data={'pyfstat': ['pyCUDAkernels/cudaTransientFstatExpWindow.cu',
-                                'pyCUDAkernels/cudaTransientFstatRectWindow.cu']},
-      install_requires=[
-          'matplotlib',
-          'scipy',
-          'ptemcee',
-          'corner',
-          'dill',
-          'tqdm',
-          'bashplotlib',
-          'peakutils',
-          'pathos',
-          'pycuda',
-      ],
+setup(
+    name="PyFstat",
+    version="0.2",
+    author="Gregory Ashton",
+    author_email="gregory.ashton@ligo.org",
+    long_description=long_description,
+    packages=find_packages(where="pyfstat"),
+    include_package_data=True,
+    package_data={
+        "pyfstat": [
+            "pyCUDAkernels/cudaTransientFstatExpWindow.cu",
+            "pyCUDAkernels/cudaTransientFstatRectWindow.cu",
+        ]
+    },
+    install_requires=[
+        "matplotlib",
+        "scipy",
+        "ptemcee",
+        "corner",
+        "dill",
+        "tqdm",
+        "bashplotlib",
+        "peakutils",
+        "pathos",
+        "pycuda",
+    ],
 )
diff --git a/tests.py b/tests.py
index c67e40dbb48cf4a212bfb2bef627794e4849a16f..a9822733e84bf182ac65c0ab23f37375fd567b4a 100644
--- a/tests.py
+++ b/tests.py
@@ -8,7 +8,7 @@ import logging
 
 
 class Test(unittest.TestCase):
-    outdir = 'TestData'
+    outdir = "TestData"
 
     @classmethod
     def setUpClass(self):
@@ -16,8 +16,7 @@ class Test(unittest.TestCase):
             try:
                 shutil.rmtree(self.outdir)
             except OSError:
-                logging.warning(
-                    "{} not removed prior to tests".format(self.outdir))
+                logging.warning("{} not removed prior to tests".format(self.outdir))
         h0 = 1
         sqrtSX = 1
         F0 = 30
@@ -28,12 +27,21 @@ class Test(unittest.TestCase):
         Alpha = 5e-3
         Delta = 1.2
         tref = minStartTime
-        Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label='test',
-                                h0=h0, sqrtSX=sqrtSX,
-                                outdir=self.outdir, tstart=minStartTime,
-                                Alpha=Alpha, Delta=Delta, tref=tref,
-                                duration=duration,
-                                Band=4)
+        Writer = pyfstat.Writer(
+            F0=F0,
+            F1=F1,
+            F2=F2,
+            label="test",
+            h0=h0,
+            sqrtSX=sqrtSX,
+            outdir=self.outdir,
+            tstart=minStartTime,
+            Alpha=Alpha,
+            Delta=Delta,
+            tref=tref,
+            duration=duration,
+            Band=4,
+        )
         Writer.make_data()
         self.sftfilepath = Writer.sftfilepath
         self.minStartTime = minStartTime
@@ -46,8 +54,7 @@ class Test(unittest.TestCase):
             try:
                 shutil.rmtree(self.outdir)
             except OSError:
-                logging.warning(
-                    "{} not removed prior to tests".format(self.outdir))
+                logging.warning("{} not removed after tests".format(self.outdir))
 
 
 class Writer(Test):
@@ -56,16 +63,17 @@ class Writer(Test):
     def test_make_cff(self):
         Writer = pyfstat.Writer(self.label, outdir=self.outdir)
         Writer.make_cff()
-        self.assertTrue(os.path.isfile(
-            './{}/{}.cff'.format(self.outdir, self.label)))
+        self.assertTrue(os.path.isfile("./{}/{}.cff".format(self.outdir, self.label)))
 
     def test_run_makefakedata(self):
         Writer = pyfstat.Writer(self.label, outdir=self.outdir, duration=3600)
         Writer.make_cff()
         Writer.run_makefakedata()
-        self.assertTrue(os.path.isfile(
-            './{}/H-2_H1_1800SFT_TestWriter-700000000-3600.sft'
-            .format(self.outdir)))
+        self.assertTrue(
+            os.path.isfile(
+                "./{}/H-2_H1_1800SFT_TestWriter-700000000-3600.sft".format(self.outdir)
+            )
+        )
 
     def test_makefakedata_usecached(self):
         Writer = pyfstat.Writer(self.label, outdir=self.outdir, duration=3600)
@@ -77,7 +85,7 @@ class Writer(Test):
         Writer.run_makefakedata()
         time_second = os.path.getmtime(Writer.sftfilepath)
         self.assertTrue(time_first == time_second)
-        os.system('touch {}'.format(Writer.config_file_name))
+        os.system("touch {}".format(Writer.config_file_name))
         Writer.run_makefakedata()
         time_third = os.path.getmtime(Writer.sftfilepath)
         self.assertFalse(time_first == time_third)
@@ -90,22 +98,23 @@ class Bunch(Test):
 
 
 class par(Test):
-    label = 'TestPar'
+    label = "TestPar"
 
     def test(self):
-        os.system(
-            'echo "x=100\ny=10" > {}/{}.par'.format(self.outdir, self.label))
+        os.system('echo "x=100\ny=10" > {}/{}.par'.format(self.outdir, self.label))
 
         par = pyfstat.core.read_par(
-            '{}/{}.par'.format(self.outdir, self.label), return_type='Bunch')
+            "{}/{}.par".format(self.outdir, self.label), return_type="Bunch"
+        )
         self.assertTrue(par.x == 100)
         self.assertTrue(par.y == 10)
 
-        par = pyfstat.core.read_par(outdir=self.outdir, label=self.label,
-                                    return_type='dict')
-        self.assertTrue(par['x'] == 100)
-        self.assertTrue(par['y'] == 10)
-        os.system('rm -r {}'.format(self.outdir))
+        par = pyfstat.core.read_par(
+            outdir=self.outdir, label=self.label, return_type="dict"
+        )
+        self.assertTrue(par["x"] == 100)
+        self.assertTrue(par["y"] == 10)
+        os.system("rm -r {}".format(self.outdir))
 
 
 class BaseSearchClass(Test):
@@ -113,115 +122,185 @@ class BaseSearchClass(Test):
         BSC = pyfstat.BaseSearchClass()
         dT = 10
         a = BSC._shift_matrix(4, dT)
-        b = np.array([[1, 2*np.pi*dT, 2*np.pi*dT**2/2.0, 2*np.pi*dT**3/6.0],
-                      [0, 1, dT, dT**2/2.0],
-                      [0, 0, 1, dT],
-                      [0, 0, 0, 1]])
+        b = np.array(
+            [
+                [
+                    1,
+                    2 * np.pi * dT,
+                    2 * np.pi * dT ** 2 / 2.0,
+                    2 * np.pi * dT ** 3 / 6.0,
+                ],
+                [0, 1, dT, dT ** 2 / 2.0],
+                [0, 0, 1, dT],
+                [0, 0, 0, 1],
+            ]
+        )
         self.assertTrue(np.array_equal(a, b))
 
     def test_shift_coefficients(self):
         BSC = pyfstat.BaseSearchClass()
-        thetaA = np.array([10., 1e2, 10., 1e2])
+        thetaA = np.array([10.0, 1e2, 10.0, 1e2])
         dT = 100
 
         # Calculate the 'long' way
         thetaB = np.zeros(len(thetaA))
         thetaB[3] = thetaA[3]
-        thetaB[2] = thetaA[2] + thetaA[3]*dT
-        thetaB[1] = thetaA[1] + thetaA[2]*dT + .5*thetaA[3]*dT**2
-        thetaB[0] = thetaA[0] + 2*np.pi*(thetaA[1]*dT + .5*thetaA[2]*dT**2
-                                         + thetaA[3]*dT**3 / 6.0)
+        thetaB[2] = thetaA[2] + thetaA[3] * dT
+        thetaB[1] = thetaA[1] + thetaA[2] * dT + 0.5 * thetaA[3] * dT ** 2
+        thetaB[0] = thetaA[0] + 2 * np.pi * (
+            thetaA[1] * dT + 0.5 * thetaA[2] * dT ** 2 + thetaA[3] * dT ** 3 / 6.0
+        )
 
-        self.assertTrue(
-            np.array_equal(
-                thetaB, BSC._shift_coefficients(thetaA, dT)))
+        self.assertTrue(np.array_equal(thetaB, BSC._shift_coefficients(thetaA, dT)))
 
     def test_shift_coefficients_loop(self):
         BSC = pyfstat.BaseSearchClass()
-        thetaA = np.array([10., 1e2, 10., 1e2])
+        thetaA = np.array([10.0, 1e2, 10.0, 1e2])
         dT = 1e1
         thetaB = BSC._shift_coefficients(thetaA, dT)
         self.assertTrue(
             np.allclose(
-                thetaA, BSC._shift_coefficients(thetaB, -dT),
-                rtol=1e-9, atol=1e-9))
+                thetaA, BSC._shift_coefficients(thetaB, -dT), rtol=1e-9, atol=1e-9
+            )
+        )
 
 
 class ComputeFstat(Test):
     label = "TestComputeFstat"
 
     def test_run_computefstatistic_single_point(self):
-        Writer = pyfstat.Writer(self.label, outdir=self.outdir, duration=86400,
-                                h0=1, sqrtSX=1, detectors='H1')
+        Writer = pyfstat.Writer(
+            self.label,
+            outdir=self.outdir,
+            duration=86400,
+            h0=1,
+            sqrtSX=1,
+            detectors="H1",
+        )
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
 
         search_H1L1 = pyfstat.ComputeFstat(
             tref=Writer.tref,
-            sftfilepattern='{}/*{}*sft'.format(Writer.outdir, Writer.label))
+            sftfilepattern="{}/*{}*sft".format(Writer.outdir, Writer.label),
+        )
         FS = search_H1L1.get_fullycoherent_twoF(
-            Writer.tstart, Writer.tend, Writer.F0, Writer.F1, Writer.F2,
-            Writer.Alpha, Writer.Delta)
-        self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.3)
-
-        Writer.detectors = 'H1'
+            Writer.tstart,
+            Writer.tend,
+            Writer.F0,
+            Writer.F1,
+            Writer.F2,
+            Writer.Alpha,
+            Writer.Delta,
+        )
+        self.assertTrue(np.abs(predicted_FS - FS) / FS < 0.3)
+
+        Writer.detectors = "H1"
         predicted_FS = Writer.predict_fstat()
         search_H1 = pyfstat.ComputeFstat(
-            tref=Writer.tref, detectors='H1',
-            sftfilepattern='{}/*{}*sft'.format(Writer.outdir, Writer.label),
-            SSBprec=lalpulsar.SSBPREC_RELATIVISTIC)
+            tref=Writer.tref,
+            detectors="H1",
+            sftfilepattern="{}/*{}*sft".format(Writer.outdir, Writer.label),
+            SSBprec=lalpulsar.SSBPREC_RELATIVISTIC,
+        )
         FS = search_H1.get_fullycoherent_twoF(
-            Writer.tstart, Writer.tend, Writer.F0, Writer.F1, Writer.F2,
-            Writer.Alpha, Writer.Delta)
-        self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.3)
+            Writer.tstart,
+            Writer.tend,
+            Writer.F0,
+            Writer.F1,
+            Writer.F2,
+            Writer.Alpha,
+            Writer.Delta,
+        )
+        self.assertTrue(np.abs(predicted_FS - FS) / FS < 0.3)
 
-    def run_computefstatistic_single_point_no_noise(self):
+    def test_run_computefstatistic_single_point_no_noise(self):
         Writer = pyfstat.Writer(
-            self.label, outdir=self.outdir, add_noise=False, duration=86400,
-            h0=1, sqrtSX=1)
+            self.label,
+            outdir=self.outdir,
+            add_noise=False,
+            duration=86400,
+            h0=1,
+            sqrtSX=1,
+        )
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
 
         search = pyfstat.ComputeFstat(
-            tref=Writer.tref, assumeSqrtSX=1,
-            sftfilepattern='{}/*{}*sft'.format(Writer.outdir, Writer.label))
+            tref=Writer.tref,
+            assumeSqrtSX=1,
+            sftfilepattern="{}/*{}*sft".format(Writer.outdir, Writer.label),
+        )
         FS = search.get_fullycoherent_twoF(
-            Writer.tstart, Writer.tend, Writer.F0, Writer.F1, Writer.F2,
-            Writer.Alpha, Writer.Delta)
-        self.assertTrue(np.abs(predicted_FS-FS)/FS < 0.3)
+            Writer.tstart,
+            Writer.tend,
+            Writer.F0,
+            Writer.F1,
+            Writer.F2,
+            Writer.Alpha,
+            Writer.Delta,
+        )
+        self.assertTrue(np.abs(predicted_FS - FS) / FS < 0.3)
 
     def test_injectSources(self):
         # This seems to be writing with a signal...
         Writer = pyfstat.Writer(
-            self.label, outdir=self.outdir, add_noise=False, duration=86400,
-            h0=1, sqrtSX=1)
+            self.label,
+            outdir=self.outdir,
+            add_noise=False,
+            duration=86400,
+            h0=1,
+            sqrtSX=1,
+        )
         Writer.make_cff()
         injectSources = Writer.config_file_name
 
         search = pyfstat.ComputeFstat(
-            tref=Writer.tref, assumeSqrtSX=1, injectSources=injectSources,
-            minCoverFreq=28, maxCoverFreq=32, minStartTime=Writer.tstart,
-            maxStartTime=Writer.tstart+Writer.duration,
-            detectors=Writer.detectors)
+            tref=Writer.tref,
+            assumeSqrtSX=1,
+            injectSources=injectSources,
+            minCoverFreq=28,
+            maxCoverFreq=32,
+            minStartTime=Writer.tstart,
+            maxStartTime=Writer.tstart + Writer.duration,
+            detectors=Writer.detectors,
+        )
         FS_from_file = search.get_fullycoherent_twoF(
-            Writer.tstart, Writer.tend, Writer.F0, Writer.F1, Writer.F2,
-            Writer.Alpha, Writer.Delta)
+            Writer.tstart,
+            Writer.tend,
+            Writer.F0,
+            Writer.F1,
+            Writer.F2,
+            Writer.Alpha,
+            Writer.Delta,
+        )
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
-        self.assertTrue(np.abs(predicted_FS-FS_from_file)/FS_from_file < 0.3)
+        self.assertTrue(np.abs(predicted_FS - FS_from_file) / FS_from_file < 0.3)
 
         injectSourcesdict = pyfstat.core.read_par(Writer.config_file_name)
-        injectSourcesdict['F0'] = injectSourcesdict['Freq']
-        injectSourcesdict['F1'] = injectSourcesdict['f1dot']
-        injectSourcesdict['F2'] = injectSourcesdict['f2dot']
+        injectSourcesdict["F0"] = injectSourcesdict["Freq"]
+        injectSourcesdict["F1"] = injectSourcesdict["f1dot"]
+        injectSourcesdict["F2"] = injectSourcesdict["f2dot"]
         search = pyfstat.ComputeFstat(
-            tref=Writer.tref, assumeSqrtSX=1, injectSources=injectSourcesdict,
-            minCoverFreq=28, maxCoverFreq=32, minStartTime=Writer.tstart,
-            maxStartTime=Writer.tstart+Writer.duration,
-            detectors=Writer.detectors)
+            tref=Writer.tref,
+            assumeSqrtSX=1,
+            injectSources=injectSourcesdict,
+            minCoverFreq=28,
+            maxCoverFreq=32,
+            minStartTime=Writer.tstart,
+            maxStartTime=Writer.tstart + Writer.duration,
+            detectors=Writer.detectors,
+        )
         FS_from_dict = search.get_fullycoherent_twoF(
-            Writer.tstart, Writer.tend, Writer.F0, Writer.F1, Writer.F2,
-            Writer.Alpha, Writer.Delta)
+            Writer.tstart,
+            Writer.tend,
+            Writer.F0,
+            Writer.F1,
+            Writer.F2,
+            Writer.Alpha,
+            Writer.Delta,
+        )
         self.assertTrue(FS_from_dict == FS_from_file)
 
 
@@ -229,20 +308,30 @@ class SemiCoherentSearch(Test):
     label = "TestSemiCoherentSearch"
 
     def test_get_semicoherent_twoF(self):
-        duration = 10*86400
+        duration = 10 * 86400
         Writer = pyfstat.Writer(
-            self.label, outdir=self.outdir, duration=duration, h0=1, sqrtSX=1)
+            self.label, outdir=self.outdir, duration=duration, h0=1, sqrtSX=1
+        )
         Writer.make_data()
 
         search = pyfstat.SemiCoherentSearch(
-            label=self.label, outdir=self.outdir, nsegs=2,
-            sftfilepattern='{}/*{}*sft'.format(Writer.outdir, Writer.label),
-            tref=Writer.tref, minStartTime=Writer.tstart,
-            maxStartTime=Writer.tend)
+            label=self.label,
+            outdir=self.outdir,
+            nsegs=2,
+            sftfilepattern="{}/*{}*sft".format(Writer.outdir, Writer.label),
+            tref=Writer.tref,
+            minStartTime=Writer.tstart,
+            maxStartTime=Writer.tend,
+        )
 
         search.get_semicoherent_twoF(
-            Writer.F0, Writer.F1, Writer.F2, Writer.Alpha, Writer.Delta,
-            record_segments=True)
+            Writer.F0,
+            Writer.F1,
+            Writer.F2,
+            Writer.Alpha,
+            Writer.Delta,
+            record_segments=True,
+        )
 
         # Compute the predicted semi-coherent Fstat
         minStartTime = Writer.tstart
@@ -260,21 +349,31 @@ class SemiCoherentSearch(Test):
         self.assertTrue(np.all(diffs < 0.3))
 
     def test_get_semicoherent_BSGL(self):
-        duration = 10*86400
+        duration = 10 * 86400
         Writer = pyfstat.Writer(
-            self.label, outdir=self.outdir, duration=duration,
-            detectors='H1,L1')
+            self.label, outdir=self.outdir, duration=duration, detectors="H1,L1"
+        )
         Writer.make_data()
 
         search = pyfstat.SemiCoherentSearch(
-            label=self.label, outdir=self.outdir, nsegs=2,
-            sftfilepattern='{}/*{}*sft'.format(Writer.outdir, Writer.label),
-            tref=Writer.tref, minStartTime=Writer.tstart,
-            maxStartTime=Writer.tend, BSGL=True)
+            label=self.label,
+            outdir=self.outdir,
+            nsegs=2,
+            sftfilepattern="{}/*{}*sft".format(Writer.outdir, Writer.label),
+            tref=Writer.tref,
+            minStartTime=Writer.tstart,
+            maxStartTime=Writer.tend,
+            BSGL=True,
+        )
 
         BSGL = search.get_semicoherent_twoF(
-            Writer.F0, Writer.F1, Writer.F2, Writer.Alpha, Writer.Delta,
-            record_segments=True)
+            Writer.F0,
+            Writer.F1,
+            Writer.F2,
+            Writer.Alpha,
+            Writer.Delta,
+            record_segments=True,
+        )
         self.assertTrue(BSGL > 0)
 
 
@@ -282,26 +381,43 @@ class SemiCoherentGlitchSearch(Test):
     label = "TestSemiCoherentGlitchSearch"
 
     def test_get_semicoherent_nglitch_twoF(self):
-        duration = 10*86400
-        dtglitch = .5*duration
+        duration = 10 * 86400
+        dtglitch = 0.5 * duration
         delta_F0 = 0
         h0 = 1
         sqrtSX = 1
         Writer = pyfstat.GlitchWriter(
-            self.label, outdir=self.outdir, duration=duration, dtglitch=dtglitch,
-            delta_F0=delta_F0, sqrtSX=sqrtSX, h0=h0)
+            self.label,
+            outdir=self.outdir,
+            duration=duration,
+            dtglitch=dtglitch,
+            delta_F0=delta_F0,
+            sqrtSX=sqrtSX,
+            h0=h0,
+        )
 
         Writer.make_data()
 
         search = pyfstat.SemiCoherentGlitchSearch(
-            label=self.label, outdir=self.outdir,
-            sftfilepattern='{}/*{}*sft'.format(Writer.outdir, Writer.label),
-            tref=Writer.tref, minStartTime=Writer.tstart,
-            maxStartTime=Writer.tend, nglitch=1)
+            label=self.label,
+            outdir=self.outdir,
+            sftfilepattern="{}/*{}*sft".format(Writer.outdir, Writer.label),
+            tref=Writer.tref,
+            minStartTime=Writer.tstart,
+            maxStartTime=Writer.tend,
+            nglitch=1,
+        )
 
         FS = search.get_semicoherent_nglitch_twoF(
-            Writer.F0, Writer.F1, Writer.F2, Writer.Alpha, Writer.Delta,
-            Writer.delta_F0, Writer.delta_F1, search.minStartTime+dtglitch)
+            Writer.F0,
+            Writer.F1,
+            Writer.F2,
+            Writer.Alpha,
+            Writer.Delta,
+            Writer.delta_F0,
+            Writer.delta_F1,
+            search.minStartTime + dtglitch,
+        )
 
         # Compute the predicted semi-coherent glitch Fstat
         minStartTime = Writer.tstart
@@ -315,10 +431,10 @@ class SemiCoherentGlitchSearch(Test):
         FSB = Writer.predict_fstat()
 
         print(FSA, FSB)
-        predicted_FS = (FSA + FSB)
+        predicted_FS = FSA + FSB
 
         print((predicted_FS, FS))
-        self.assertTrue(np.abs((FS - predicted_FS))/predicted_FS < 0.3)
+        self.assertTrue(np.abs(FS - predicted_FS) / predicted_FS < 0.3)
 
 
 class MCMCSearch(Test):
@@ -336,32 +452,53 @@ class MCMCSearch(Test):
         Alpha = 5e-3
         Delta = 1.2
         tref = minStartTime
-        Writer = pyfstat.Writer(F0=F0, F1=F1, F2=F2, label=self.label,
-                                h0=h0, sqrtSX=sqrtSX,
-                                outdir=self.outdir, tstart=minStartTime,
-                                Alpha=Alpha, Delta=Delta, tref=tref,
-                                duration=duration,
-                                Band=4)
+        Writer = pyfstat.Writer(
+            F0=F0,
+            F1=F1,
+            F2=F2,
+            label=self.label,
+            h0=h0,
+            sqrtSX=sqrtSX,
+            outdir=self.outdir,
+            tstart=minStartTime,
+            Alpha=Alpha,
+            Delta=Delta,
+            tref=tref,
+            duration=duration,
+            Band=4,
+        )
 
         Writer.make_data()
         predicted_FS = Writer.predict_fstat()
 
-        theta = {'F0': {'type': 'norm', 'loc': F0, 'scale': np.abs(1e-10*F0)},
-                 'F1': {'type': 'norm', 'loc': F1, 'scale': np.abs(1e-10*F1)},
-                 'F2': F2, 'Alpha': Alpha, 'Delta': Delta}
+        theta = {
+            "F0": {"type": "norm", "loc": F0, "scale": np.abs(1e-10 * F0)},
+            "F1": {"type": "norm", "loc": F1, "scale": np.abs(1e-10 * F1)},
+            "F2": F2,
+            "Alpha": Alpha,
+            "Delta": Delta,
+        }
 
         search = pyfstat.MCMCSearch(
-            label=self.label, outdir=self.outdir, theta_prior=theta, tref=tref,
-            sftfilepattern='{}/*{}*sft'.format(Writer.outdir, Writer.label),
-            minStartTime=minStartTime, maxStartTime=maxStartTime,
-            nsteps=[100, 100], nwalkers=100, ntemps=2, log10beta_min=-1)
+            label=self.label,
+            outdir=self.outdir,
+            theta_prior=theta,
+            tref=tref,
+            sftfilepattern="{}/*{}*sft".format(Writer.outdir, Writer.label),
+            minStartTime=minStartTime,
+            maxStartTime=maxStartTime,
+            nsteps=[100, 100],
+            nwalkers=100,
+            ntemps=2,
+            log10beta_min=-1,
+        )
         search.run(create_plots=False)
         _, FS = search.get_max_twoF()
 
-        print(('Predicted twoF is {} while recovered is {}'.format(
-                predicted_FS, FS)))
+        print(("Predicted twoF is {} while recovered is {}".format(predicted_FS, FS)))
         self.assertTrue(
-            FS > predicted_FS or np.abs((FS-predicted_FS))/predicted_FS < 0.3)
+            FS > predicted_FS or np.abs(FS - predicted_FS) / predicted_FS < 0.3
+        )
 
 
 class GridSearch(Test):
@@ -371,42 +508,88 @@ class GridSearch(Test):
 
     def test_grid_search(self):
         search = pyfstat.GridSearch(
-            'grid_search', self.outdir, self.sftfilepath, F0s=self.F0s,
-            F1s=[0], F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref)
+            "grid_search",
+            self.outdir,
+            self.sftfilepath,
+            F0s=self.F0s,
+            F1s=[0],
+            F2s=[0],
+            Alphas=[0],
+            Deltas=[0],
+            tref=self.tref,
+        )
         search.run()
         self.assertTrue(os.path.isfile(search.out_file))
 
     def test_semicoherent_grid_search(self):
         search = pyfstat.GridSearch(
-            'sc_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s,
-            F1s=[0], F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref, nsegs=2)
+            "sc_grid_search",
+            self.outdir,
+            self.sftfilepath,
+            F0s=self.F0s,
+            F1s=[0],
+            F2s=[0],
+            Alphas=[0],
+            Deltas=[0],
+            tref=self.tref,
+            nsegs=2,
+        )
         search.run()
         self.assertTrue(os.path.isfile(search.out_file))
 
     def test_slice_grid_search(self):
         search = pyfstat.SliceGridSearch(
-            'slice_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s,
-            F1s=self.F1s, F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref,
-            Lambda0=[30, 0, 0, 0])
+            "slice_grid_search",
+            self.outdir,
+            self.sftfilepath,
+            F0s=self.F0s,
+            F1s=self.F1s,
+            F2s=[0],
+            Alphas=[0],
+            Deltas=[0],
+            tref=self.tref,
+            Lambda0=[30, 0, 0, 0],
+        )
         search.run()
-        self.assertTrue(os.path.isfile('{}/{}_slice_projection.png'
-                                       .format(search.outdir, search.label)))
+        self.assertTrue(
+            os.path.isfile(
+                "{}/{}_slice_projection.png".format(search.outdir, search.label)
+            )
+        )
 
     def test_glitch_grid_search(self):
         search = pyfstat.GridGlitchSearch(
-            'grid_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s,
-            F1s=self.F1s, F2s=[0], Alphas=[0], Deltas=[0], tref=self.tref,
-            tglitchs=[self.tref])
+            "grid_grid_search",
+            self.outdir,
+            self.sftfilepath,
+            F0s=self.F0s,
+            F1s=self.F1s,
+            F2s=[0],
+            Alphas=[0],
+            Deltas=[0],
+            tref=self.tref,
+            tglitchs=[self.tref],
+        )
         search.run()
         self.assertTrue(os.path.isfile(search.out_file))
 
     def test_sliding_window(self):
         search = pyfstat.FrequencySlidingWindow(
-            'grid_grid_search', self.outdir, self.sftfilepath, F0s=self.F0s,
-            F1=0, F2=0, Alpha=0, Delta=0, tref=self.tref,
-            minStartTime=self.minStartTime, maxStartTime=self.maxStartTime)
+            "grid_grid_search",
+            self.outdir,
+            self.sftfilepath,
+            F0s=self.F0s,
+            F1=0,
+            F2=0,
+            Alpha=0,
+            Delta=0,
+            tref=self.tref,
+            minStartTime=self.minStartTime,
+            maxStartTime=self.maxStartTime,
+        )
         search.run()
         self.assertTrue(os.path.isfile(search.out_file))
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     unittest.main()