Skip to content
Snippets Groups Projects
Commit 18eeaf21 authored by Rutger van Haasteren's avatar Rutger van Haasteren
Browse files

Sped up the prior by 30x

parent 304d84ed
No related branches found
No related tags found
No related merge requests found
...@@ -361,7 +361,10 @@ class BoundedMvNormalPlHierarchicalPrior(object): ...@@ -361,7 +361,10 @@ class BoundedMvNormalPlHierarchicalPrior(object):
amps = p[self._la_inds] amps = p[self._la_inds]
gammas = p[self._g_inds] gammas = p[self._g_inds]
pag = np.vstack([amps, gammas]) pag = np.vstack([amps, gammas])
try:
uag = sl.solve_triangular(L, pag - mu[:,None], trans=0, lower=True) uag = sl.solve_triangular(L, pag - mu[:,None], trans=0, lower=True)
except sl.LinAlgError as e:
return -np.inf
quad = -0.5 * np.sum(uag**2, axis=0) quad = -0.5 * np.sum(uag**2, axis=0)
norm = - np.sum(np.log(np.diag(L))) - np.log(2*np.pi) norm = - np.sum(np.log(np.diag(L))) - np.log(2*np.pi)
...@@ -406,7 +409,10 @@ class BoundedMvNormalPlHierarchicalPrior(object): ...@@ -406,7 +409,10 @@ class BoundedMvNormalPlHierarchicalPrior(object):
amps = qp[self._la_inds] amps = qp[self._la_inds]
gammas = qp[self._g_inds] gammas = qp[self._g_inds]
pag = np.vstack([amps, gammas]) pag = np.vstack([amps, gammas])
try:
uag = sl.solve_triangular(L, pag - mu[:,None], trans=0, lower=True) uag = sl.solve_triangular(L, pag - mu[:,None], trans=0, lower=True)
except sl.LinAlgError as e:
return x, 0
# Draw a random element from uag to update # Draw a random element from uag to update
n_total = np.prod(uag.shape) n_total = np.prod(uag.shape)
...@@ -518,6 +524,7 @@ class BoundedTwoComponentMvNormalPlHierarchicalPrior(BoundedMvNormalPlHierarchic ...@@ -518,6 +524,7 @@ class BoundedTwoComponentMvNormalPlHierarchicalPrior(BoundedMvNormalPlHierarchic
pag = np.vstack([amps, gammas]) pag = np.vstack([amps, gammas])
# Mode 1 & 2 Gaussian components # Mode 1 & 2 Gaussian components
try:
uag1 = sl.solve_triangular(L1, pag - mu1[:,None], trans=0, lower=True) uag1 = sl.solve_triangular(L1, pag - mu1[:,None], trans=0, lower=True)
quad1 = -0.5 * np.sum(uag1**2, axis=0) quad1 = -0.5 * np.sum(uag1**2, axis=0)
norm1 = - np.sum(np.log(np.diag(L1))) - np.log(2*np.pi) norm1 = - np.sum(np.log(np.diag(L1))) - np.log(2*np.pi)
...@@ -526,6 +533,8 @@ class BoundedTwoComponentMvNormalPlHierarchicalPrior(BoundedMvNormalPlHierarchic ...@@ -526,6 +533,8 @@ class BoundedTwoComponentMvNormalPlHierarchicalPrior(BoundedMvNormalPlHierarchic
norm2 = - np.sum(np.log(np.diag(L2))) - np.log(2*np.pi) norm2 = - np.sum(np.log(np.diag(L2))) - np.log(2*np.pi)
log_prior1 = np.sum(quad1 + norm1) log_prior1 = np.sum(quad1 + norm1)
log_prior2 = np.sum(quad2 + norm2) log_prior2 = np.sum(quad2 + norm2)
except sl.LinAlgError as e:
return -np.inf
log_prior = log_weighted_sum_exp(log_prior1, log_prior2, CF) log_prior = log_weighted_sum_exp(log_prior1, log_prior2, CF)
log_jacobian = self.log_dpdx(x) log_jacobian = self.log_dpdx(x)
...@@ -569,6 +578,11 @@ class EnterpriseWrapper(object): ...@@ -569,6 +578,11 @@ class EnterpriseWrapper(object):
self._ndim = self._ndim_level1 + self._ndim_level2 self._ndim = self._ndim_level1 + self._ndim_level2
self._ptapar_to_array, self._array_to_ptapar = ptapar_mapping(self._pta) self._ptapar_to_array, self._array_to_ptapar = ptapar_mapping(self._pta)
# Initialize all the Enterprise prior distributions for efficiency
self._nohbm_indices = self.get_nohbm_indices()
nohbm_parameter_indices = list(set(self._array_to_ptapar[self._nohbm_indices]))
self._nohbm_parameters = [self._pta.params[pp] for pp in nohbm_parameter_indices]
@property @property
def param_names(self): def param_names(self):
"""All parameter names of whole HBM""" """All parameter names of whole HBM"""
...@@ -629,7 +643,8 @@ class EnterpriseWrapper(object): ...@@ -629,7 +643,8 @@ class EnterpriseWrapper(object):
def log_prior(self, x): def log_prior(self, x):
"""Full hierarchical log-prior""" """Full hierarchical log-prior"""
logp = np.sum([self._pta.params[self._array_to_ptapar[ii]].get_logpdf(x[ii]) for ii in self.get_nohbm_indices()]) params = self._pta.map_params(self.get_low_level_pars(x))
logp = np.sum([p.get_logpdf(params=params) for p in self._nohbm_parameters])
for prior in self.hyper_priors: for prior in self.hyper_priors:
logp += prior.log_prior(x) logp += prior.log_prior(x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment