diff --git a/prior_wrapper.py b/prior_wrapper.py
index 06f40379d49e9a94419ad94ea1bfba3c835754d3..a2579823efcfa2c16f20e6d09e6e56fae5f00bad 100644
--- a/prior_wrapper.py
+++ b/prior_wrapper.py
@@ -1,12 +1,78 @@
+# prior_wrapper.py
+"""Classes and functions that can wrap an Enterprise pta object for HBMs
+
+Simple usage instructions
+-------------------------
+Build your Enterprise PTA object as you normally would. Make sure you know what
+the parameter names are for the signals that you want to place a hierarchical
+prior on. Those parameter names will be matched with a regular expression.
+For instance, for powerlaw red noise, often the parameters would be named:
+'J0030+0451_red_noise_log10_A'
+'J0030+0451_red_noise_gamma'
+'J1713+0747_red_noise_log10_A',
+...
+
+With the above naming convention, you could create an HBM like this:
+
+import prior_wrapper as pw
+
+wrapper = pw.EnterpriseWrapper(
+    pta=pta,
+    hyper_regexps = {
+        'red_noise': {
+            'log10_amp': '_red_noise_log10_A$',
+            'gamma': '_red_noise_gamma$',
+            'prior': pw.BoundedMvNormalPlHierarchicalPrior,
+        }
+    }
+)
+
+x0 = wrapper.sample()
+cov = np.diag(0.01 * np.ones_like(x0))
+
+sampler = ptmcmc(len(x0), wrapper.log_likelihood, wrapper.log_prior, cov, outDir='./chains', resume=False)
+
+# parameter names are in:
+# wrapper.param_names
+
+# Add the prior draws
+for draw_function in wrapper.get_draw_from_prior_functions():
+    sampler.addProposalToCycle(draw_function, 5)
+
+
+# If you also want to model DMGP with a Hierarchical Prior, you could do:
+
+wrapper = pw.EnterpriseWrapper(
+    pta=pta,
+    hyper_regexps = {
+        'red_noise': {
+            'log10_amp': '_red_noise_log10_A$',
+            'gamma': '_red_noise_gamma$',
+            'prior': pw.BoundedMvNormalPlHierarchicalPrior,
+        },
+        'dmgp': {
+            'log10_amp': '_dm_gp_log10_A$',
+            'gamma': '_dm_gp_gamma$',
+            'prior': pw.BoundedMvNormalPlHierarchicalPrior,
+        },
+    }
+)
+
+# This last one is largely untested. Since the prior ranges for the
+# hyperparameters cannot be set dynamically at the moment, we may need to update
+# the code a bit to make it work for more general priors
+
+"""
+
 import numpy as np
 import scipy.stats as sstats
 import scipy.linalg as sl
 import re
+from enterprise.signals.parameter import function
 
 class kumaraswamy_distribution(sstats.rv_continuous):
     """Kumaraswamy distribution like for scipy"""
     def _pdf(self, x, a, b):
-        # Adding a condition to ensure x is within the [0,1] interval
         return np.where((x >= 0) & (x <= 1), a * b * x**(a-1) * (1 - x**a)**(b-1), 0)
 
     def _cdf(self, x, a, b):
@@ -107,31 +173,6 @@ class IntervalTransform(object):
         num = 1 + np.exp(p)
         return np.log(self._upper-self._lower) + np.log(1/num - 1/(num**2))
 
-def ptmcmc_to_inferencedata(chain, params):
-    """Inferencedata, used by Arviz, can be created from a chain like this"""
-    import arviz as az
-    import xarray as xr
-    samples = chain.reshape(1, -1, chain.shape[1])
-
-    datasets = {}
-    for i, name in enumerate(params):
-        data_array = xr.DataArray(samples[:, :, i], dims=("chain", "draw"), name=name)
-        datasets[name] = data_array
-    dataset = xr.merge(datasets.values())
-    return az.convert_to_inference_data(dataset)
-
-def inferencedata_to_ptmcmc(idata):
-    """InferenceData to ptmcmc chain"""
-
-    # Get parameter names
-    param_names = list(idata.posterior.data_vars)
-
-    # Get MCMC chains
-    posterior = idata.posterior
-    chains = {var_name: posterior[var_name].values for var_name in param_names}
-
-    return chains, param_names
-
 def log_sum_exp(log_prior1, log_prior2):
     """Take the log of two exponential sums, stable numerically"""
     max_log_prior = np.maximum(log_prior1, log_prior2)
@@ -166,6 +207,12 @@ def ptapar_mapping(pta):
 
     return ptapar_to_array, array_to_ptapar
 
+@function
+def powerlaw_flat_tail(f, log10_A=-16, gamma=5, log10_kappa=-7, components=2):
+    df = np.diff(np.concatenate((np.array([0]), f[::components])))
+    return (
+        (10**log10_A) ** 2 / 12.0 / np.pi**2 * const.fyr ** (gamma - 3) * f ** (-gamma) * np.repeat(df, components) + 10 ** (2*log10_kappa)
+    )
 
 class BoundedMvNormalPlHierarchicalPrior(object):
     """Class to represent a Bounded MvNormal hierarchical prior for Enterprise Powerlaw Signals"""
@@ -219,8 +266,8 @@ class BoundedMvNormalPlHierarchicalPrior(object):
 
     def set_hyperpriors(self):
         """Set the hyper parameter priors"""
-        self._mu_amp = sstats.uniform(loc=-20, scale=7)
-        #self._mu_gamma = sstats.uniform(loc=0, scale=7)
+        # TODO: allow for more flexible setting of hyperparameter ranges
+        self._mu_amp = sstats.uniform(loc=-20, scale=10)
         self._mu_gamma = sstats.uniform(loc=-4, scale=8)
         self._L_amp = sstats.uniform(loc=0.03, scale=3.47)
         self._L_gamma = sstats.uniform(loc=0.03, scale=3.47)