Skip to content
Snippets Groups Projects
Commit 18eeaf21 authored by Rutger van Haasteren's avatar Rutger van Haasteren
Browse files

Sped up the prior by 30x

parent 304d84ed
No related branches found
No related tags found
No related merge requests found
......@@ -361,7 +361,10 @@ class BoundedMvNormalPlHierarchicalPrior(object):
amps = p[self._la_inds]
gammas = p[self._g_inds]
pag = np.vstack([amps, gammas])
try:
uag = sl.solve_triangular(L, pag - mu[:,None], trans=0, lower=True)
except sl.LinAlgError as e:
return -np.inf
quad = -0.5 * np.sum(uag**2, axis=0)
norm = - np.sum(np.log(np.diag(L))) - np.log(2*np.pi)
......@@ -406,7 +409,10 @@ class BoundedMvNormalPlHierarchicalPrior(object):
amps = qp[self._la_inds]
gammas = qp[self._g_inds]
pag = np.vstack([amps, gammas])
try:
uag = sl.solve_triangular(L, pag - mu[:,None], trans=0, lower=True)
except sl.LinAlgError as e:
return x, 0
# Draw a random element from uag to update
n_total = np.prod(uag.shape)
......@@ -518,6 +524,7 @@ class BoundedTwoComponentMvNormalPlHierarchicalPrior(BoundedMvNormalPlHierarchic
pag = np.vstack([amps, gammas])
# Mode 1 & 2 Gaussian components
try:
uag1 = sl.solve_triangular(L1, pag - mu1[:,None], trans=0, lower=True)
quad1 = -0.5 * np.sum(uag1**2, axis=0)
norm1 = - np.sum(np.log(np.diag(L1))) - np.log(2*np.pi)
......@@ -526,6 +533,8 @@ class BoundedTwoComponentMvNormalPlHierarchicalPrior(BoundedMvNormalPlHierarchic
norm2 = - np.sum(np.log(np.diag(L2))) - np.log(2*np.pi)
log_prior1 = np.sum(quad1 + norm1)
log_prior2 = np.sum(quad2 + norm2)
except sl.LinAlgError as e:
return -np.inf
log_prior = log_weighted_sum_exp(log_prior1, log_prior2, CF)
log_jacobian = self.log_dpdx(x)
......@@ -569,6 +578,11 @@ class EnterpriseWrapper(object):
self._ndim = self._ndim_level1 + self._ndim_level2
self._ptapar_to_array, self._array_to_ptapar = ptapar_mapping(self._pta)
# Initialize all the Enterprise prior distributions for efficiency
self._nohbm_indices = self.get_nohbm_indices()
nohbm_parameter_indices = list(set(self._array_to_ptapar[self._nohbm_indices]))
self._nohbm_parameters = [self._pta.params[pp] for pp in nohbm_parameter_indices]
@property
def param_names(self):
"""All parameter names of whole HBM"""
......@@ -629,7 +643,8 @@ class EnterpriseWrapper(object):
def log_prior(self, x):
"""Full hierarchical log-prior"""
logp = np.sum([self._pta.params[self._array_to_ptapar[ii]].get_logpdf(x[ii]) for ii in self.get_nohbm_indices()])
params = self._pta.map_params(self.get_low_level_pars(x))
logp = np.sum([p.get_logpdf(params=params) for p in self._nohbm_parameters])
for prior in self.hyper_priors:
logp += prior.log_prior(x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment