diff --git a/prior_wrapper.py b/prior_wrapper.py index 55d2adae9ccf442240cb6302a8ec01f6525b82fa..6107d405e32443aa5f4eb6647235b311a2504f1f 100644 --- a/prior_wrapper.py +++ b/prior_wrapper.py @@ -361,7 +361,10 @@ class BoundedMvNormalPlHierarchicalPrior(object): amps = p[self._la_inds] gammas = p[self._g_inds] pag = np.vstack([amps, gammas]) - uag = sl.solve_triangular(L, pag - mu[:,None], trans=0, lower=True) + try: + uag = sl.solve_triangular(L, pag - mu[:,None], trans=0, lower=True) + except sl.LinAlgError as e: + return -np.inf quad = -0.5 * np.sum(uag**2, axis=0) norm = - np.sum(np.log(np.diag(L))) - np.log(2*np.pi) @@ -406,7 +409,10 @@ class BoundedMvNormalPlHierarchicalPrior(object): amps = qp[self._la_inds] gammas = qp[self._g_inds] pag = np.vstack([amps, gammas]) - uag = sl.solve_triangular(L, pag - mu[:,None], trans=0, lower=True) + try: + uag = sl.solve_triangular(L, pag - mu[:,None], trans=0, lower=True) + except sl.LinAlgError as e: + return x, 0 # Draw a random element from uag to update n_total = np.prod(uag.shape) @@ -518,14 +524,17 @@ class BoundedTwoComponentMvNormalPlHierarchicalPrior(BoundedMvNormalPlHierarchic pag = np.vstack([amps, gammas]) # Mode 1 & 2 Gaussian components - uag1 = sl.solve_triangular(L1, pag - mu1[:,None], trans=0, lower=True) - quad1 = -0.5 * np.sum(uag1**2, axis=0) - norm1 = - np.sum(np.log(np.diag(L1))) - np.log(2*np.pi) - uag2 = sl.solve_triangular(L2, pag - mu2[:,None], trans=0, lower=True) - quad2 = -0.5 * np.sum(uag2**2, axis=0) - norm2 = - np.sum(np.log(np.diag(L2))) - np.log(2*np.pi) - log_prior1 = np.sum(quad1 + norm1) - log_prior2 = np.sum(quad2 + norm2) + try: + uag1 = sl.solve_triangular(L1, pag - mu1[:,None], trans=0, lower=True) + quad1 = -0.5 * np.sum(uag1**2, axis=0) + norm1 = - np.sum(np.log(np.diag(L1))) - np.log(2*np.pi) + uag2 = sl.solve_triangular(L2, pag - mu2[:,None], trans=0, lower=True) + quad2 = -0.5 * np.sum(uag2**2, axis=0) + norm2 = - np.sum(np.log(np.diag(L2))) - np.log(2*np.pi) + log_prior1 = np.sum(quad1 + norm1) + log_prior2 = np.sum(quad2 + norm2) + except sl.LinAlgError as e: + return -np.inf log_prior = log_weighted_sum_exp(log_prior1, log_prior2, CF) log_jacobian = self.log_dpdx(x) @@ -569,6 +578,11 @@ class EnterpriseWrapper(object): self._ndim = self._ndim_level1 + self._ndim_level2 self._ptapar_to_array, self._array_to_ptapar = ptapar_mapping(self._pta) + # Initialize all the Enterprise prior distributions for efficiency + self._nohbm_indices = self.get_nohbm_indices() + nohbm_parameter_indices = list(set(self._array_to_ptapar[self._nohbm_indices])) + self._nohbm_parameters = [self._pta.params[pp] for pp in nohbm_parameter_indices] + @property def param_names(self): """All parameter names of whole HBM""" @@ -629,7 +643,8 @@ class EnterpriseWrapper(object): def log_prior(self, x): """Full hierarchical log-prior""" - logp = np.sum([self._pta.params[self._array_to_ptapar[ii]].get_logpdf(x[ii]) for ii in self.get_nohbm_indices()]) + params = self._pta.map_params(self.get_low_level_pars(x)) + logp = np.sum([p.get_logpdf(params=params) for p in self._nohbm_parameters]) for prior in self.hyper_priors: logp += prior.log_prior(x)