# Copyright (C) 2021 Xisco Jimenez Forteza
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
# Module to run parameter estimation (PE) on ringdown (RD) data.
import random
from multiprocessing import Pool
import dynesty
import numpy as np
import rdown 

class Ringdown_PE:
    """Likelihood and unit-cube prior transform for ringdown parameter estimation.

    Parameters
    ----------
    rdown_fun : object
        Provides the waveform models ``rd_model_wtau``, ``rd_model_wq``,
        ``rd_model_wtau_fixed`` and ``rd_model_wtau_m_af``. Each is called with
        the parameter vector ``theta``; its result must expose ``.real`` /
        ``.imag`` arrays that line up element-wise with ``data[:, 1]``.
    data : ndarray
        Two-column array: column 0 is stored as ``times``; column 1 holds the
        complex samples whose real and imaginary parts are fitted.
    dim : int
        Stored as ``self.dim``.
    priors : ndarray, shape (prior_dim, 2)
        Lower (column 0) and upper (column 1) bounds of the uniform priors.
    errors2 : float or ndarray
        Baseline variance added to every sample in the likelihood.
    theta : sequence, optional
        Stored as ``self.theta`` (defaults to an empty list); not used by the
        methods of this class.
    model : str
        Key into ``self.dict`` selecting which waveform model is evaluated.
    norm_factor : float
        Constant subtracted (scaled by ``l_int``) from the log-likelihood.
    l_int : float
        Weight of the fitted-noise term; ``0`` disables it.
    """

    def __init__(self, rdown_fun, data, dim, priors, errors2=1, theta=None,
                 model='w-tau', norm_factor=0, l_int=0):
        self.dim = dim
        self.rdown_fun = rdown_fun
        self.times = data[:, 0]
        self.datare = data[:, 1].real
        self.dataim = data[:, 1].imag
        self.priors = priors
        self.priors_min = priors[:, 0]
        self.priors_max = priors[:, 1]
        self.prior_dim = len(priors)
        self.errors2 = errors2
        self.norm_factor = norm_factor
        self.model = model
        self.l_int = l_int
        # A fresh list per instance instead of a shared mutable default.
        self.theta = [] if theta is None else theta
        # Model name -> waveform function evaluated in log_likelihood.
        self.dict = {
            'w-tau': rdown_fun.rd_model_wtau,
            'w-q': rdown_fun.rd_model_wq,
            'w-tau-fixed': rdown_fun.rd_model_wtau_fixed,
            'w-tau-fixed-m-af': rdown_fun.rd_model_wtau_m_af,
        }

    def log_likelihood(self, theta, sigma=1):
        """Gaussian (chi^2) log-likelihood of the selected model at ``theta``.

        Real and imaginary residuals are weighted by the per-sample variance

            sigma2 = errors2 + l_int * (re(d)^2 + im(d)^2) * exp(2 * theta[-1])

        and, scaled by ``l_int``, the ``log(sigma2)`` normalisation and
        ``norm_factor`` are included. With ``l_int == 0`` this reduces to a plain
        chi^2 with variance ``errors2``.

        ``sigma`` is accepted for backward compatibility but is not used.

        Returns ``-inf`` when the result is NaN.
        """
        modelev = self.dict[self.model](theta)
        modelevre = modelev.real
        modelevim = modelev.imag

        sigma2 = self.errors2 + self.l_int * (self.datare ** 2 + self.dataim ** 2) * np.exp(2 * theta[-1])

        # Two real components per sample, each contributing 0.5*log(sigma2^2)
        # i.e. log(sigma2), hence the 2*log(sigma2) inside the -0.5*sum.
        result = -0.5 * np.sum(
            ((self.datare - modelevre) ** 2 + (self.dataim - modelevim) ** 2) / sigma2
            + self.l_int * (2 * np.log(sigma2))
        ) - self.l_int * self.norm_factor

        if np.isnan(result):
            return -np.inf
        return result

    def prior_transform(self, cube):
        """Map a point from the unit hypercube onto the uniform prior box.

        Each of the first ``prior_dim`` coordinates ``u`` becomes
        ``priors_min + u * (priors_max - priors_min)``. Coordinates beyond
        ``prior_dim`` (if any) are returned unchanged.

        A new float array is returned; the input is not modified. Samplers may
        keep references to the unit-cube point, and overwriting it in place
        would corrupt their bookkeeping.
        """
        params = np.array(cube, dtype=float)
        n = self.prior_dim
        params[:n] = self.priors_min + params[:n] * (self.priors_max - self.priors_min)
        return params

def _read_indexed_bounds(config_parser, name, nmax, scale=1.0):
    """Read per-mode bounds from sections ``prior-<name><i>`` for i = 0..nmax.

    Each section must provide ``<name><i>_min`` and ``<name><i>_max``; values
    are converted with ``float`` and multiplied by ``scale``.
    Returns ``(mins, maxs)`` as float arrays of length ``nmax + 1``.
    """
    mins = np.empty(nmax + 1)
    maxs = np.empty(nmax + 1)
    for i in range(nmax + 1):
        tag = name + str(i)
        section = 'prior-' + tag
        mins[i] = float(config_parser.get(section, tag + '_min')) * scale
        maxs[i] = float(config_parser.get(section, tag + '_max')) * scale
    return mins, maxs


def _read_scalar_bounds(config_parser, name):
    """Read ``<name>_min`` / ``<name>_max`` from section ``prior-<name>`` as 1-element arrays."""
    section = 'prior-' + name
    lo = float(config_parser.get(section, name + '_min'))
    hi = float(config_parser.get(section, name + '_max'))
    return np.array([lo]), np.array([hi])


def load_priors(model, config_parser, nmax, fitnoise=True):
    """Build the uniform prior bounds for a ringdown model from a config parser.

    Parameters
    ----------
    model : str
        One of ``'w-tau'``, ``'w-q'``, ``'w-tau-fixed'``, ``'w-tau-fixed-m-af'``.
    config_parser : object
        Anything with a ``get(section, option)`` method returning a value
        accepted by ``float`` (e.g. ``configparser.ConfigParser``).
    nmax : int
        Per-mode entries with indices ``0..nmax`` are read.
    fitnoise : bool
        If True, the ``prior-noise`` bounds are appended as the last parameter.

    Returns
    -------
    ndarray, shape (n_params, 2)
        Column 0 holds lower bounds, column 1 upper bounds, in the order:
        ``w_i`` and ``tau_i`` (or ``q_i`` for ``'w-q'``) for the free
        models only, then ``amp_i``, then ``phase_i`` (multiplied by 2*pi),
        then ``mass`` and ``spin`` (``'w-tau-fixed-m-af'`` only), then
        ``noise`` (if ``fitnoise``).

    Raises
    ------
    ValueError
        If ``model`` is not one of the supported names.
    """
    if model in ('w-tau', 'w-q'):
        # The 'w-q' model parametrises the damping by q instead of tau.
        tau_var_str = 'q' if model == 'w-q' else 'tau'
        blocks = [
            _read_indexed_bounds(config_parser, 'w', nmax),
            _read_indexed_bounds(config_parser, tau_var_str, nmax),
        ]
    elif model in ('w-tau-fixed', 'w-tau-fixed-m-af'):
        blocks = []
    else:
        raise ValueError(
            "unknown model %r; expected one of 'w-tau', 'w-q', "
            "'w-tau-fixed', 'w-tau-fixed-m-af'" % (model,)
        )

    # Amplitudes and phases are sampled by every model.
    blocks.append(_read_indexed_bounds(config_parser, 'amp', nmax))
    blocks.append(_read_indexed_bounds(config_parser, 'phase', nmax, scale=2 * np.pi))

    if model == 'w-tau-fixed-m-af':
        blocks.append(_read_scalar_bounds(config_parser, 'mass'))
        blocks.append(_read_scalar_bounds(config_parser, 'spin'))

    if fitnoise:
        blocks.append(_read_scalar_bounds(config_parser, 'noise'))

    priors_min = np.concatenate([lo for lo, _ in blocks])
    priors_max = np.concatenate([hi for _, hi in blocks])
    return np.column_stack((priors_min, priors_max))