diff --git a/pyfstat/grid_based_searches.py b/pyfstat/grid_based_searches.py
index 59e3b52b5dcee8a86bbaeed188a3008fa892d8d4..f0309940f4c8f3ee3eaf6d24d1697137b82fb41c 100644
--- a/pyfstat/grid_based_searches.py
+++ b/pyfstat/grid_based_searches.py
@@ -236,19 +236,20 @@ class GridSearch(BaseSearchClass):
 
 class GridUniformPriorSearch():
     def __init__(self, theta_prior, NF0, NF1, label, outdir, sftfilepath,
-                 tref, minStartTime, maxStartTime, BSGL=False, detectors=None,
-                 **kwargs):
+                 tref, minStartTime, maxStartTime, BSGL=False, detectors=None):
         dF0 = (theta_prior['F0']['upper'] - theta_prior['F0']['lower'])/NF0
         dF1 = (theta_prior['F1']['upper'] - theta_prior['F1']['lower'])/NF1
         F0s = [theta_prior['F0']['lower'], theta_prior['F0']['upper'], dF0]
         F1s = [theta_prior['F1']['lower'], theta_prior['F1']['upper'], dF1]
-        search = GridSearch(
+        self.search = GridSearch(
             label, outdir, sftfilepath, F0s=F0s, F1s=F1s, tref=tref,
             Alphas=[theta_prior['Alpha']], Deltas=[theta_prior['Delta']],
             minStartTime=minStartTime, maxStartTime=maxStartTime, BSGL=BSGL,
             detectors=detectors)
-        search.run()
-        search.plot_2D('F0', 'F1', **kwargs)
+
+    def run(self, **kwargs):
+        self.search.run()
+        return self.search.plot_2D('F0', 'F1', **kwargs)
 
 
 class GridGlitchSearch(GridSearch):