diff --git a/prior_wrapper.py b/prior_wrapper.py
index 6107d405e32443aa5f4eb6647235b311a2504f1f..41b0d6ce14605ed0f3f959ec2634dfebdc17f5a9 100644
--- a/prior_wrapper.py
+++ b/prior_wrapper.py
@@ -70,6 +70,7 @@ import scipy.linalg as sl
 import re
 from enterprise.signals.parameter import function
 import enterprise.constants as const
+import enterprise.signals.gp_priors as gp_priors
 
 class kumaraswamy_distribution(sstats.rv_continuous):
     """Kumaraswamy distribution like for scipy"""
@@ -211,9 +212,9 @@ def ptapar_mapping(pta):
 @function
 def powerlaw_flat_tail(f, log10_A=-16, gamma=5, log10_kappa=-7, components=2):
     df = np.diff(np.concatenate((np.array([0]), f[::components])))
-    return (
-        (10**log10_A) ** 2 / 12.0 / np.pi**2 * const.fyr ** (gamma - 3) * f ** (-gamma) * np.repeat(df, components) + 10 ** (2*log10_kappa)
-    )
+    pl = (10**log10_A) ** 2 / 12.0 / np.pi**2 * const.fyr ** (gamma - 3) * f ** (-gamma) * np.repeat(df, components)
+    flat = 10 ** (2*log10_kappa)
+    return np.maximum(pl, flat)
 
 class BoundedMvNormalPlHierarchicalPrior(object):
     """Class to represent a Bounded MvNormal hierarchical prior for Enterprise Powerlaw Signals"""