import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from lal import ComputeDetAMResponseExtraModes, GreenwichMeanSiderealTime, LIGOTimeGPS
from scipy.linalg import inv, toeplitz, cholesky, solve_toeplitz
from tqdm import tqdm
def project(hs, hvx, hvy, hp, hc, detector, ra, dec, psi, tgps):
    """Project the polarisation components of a signal onto one detector.

    Parameters:
    -----------
    hs, hvx, hvy, hp, hc:
        Scalar, vector-x, vector-y, plus and cross polarisation components.
    detector:
        Object whose ``response`` attribute is passed to lal.
    ra, dec, psi:
        Sky position and polarisation angle forwarded to lal.
    tgps:
        GPS time used to compute the Greenwich mean sidereal time.

    Returns:
    --------
    fs*hs + fvx*hvx + fvy*hvy + fp*hp + fc*hc, with the antenna-pattern
    factors returned by ``ComputeDetAMResponseExtraModes``.
    """
    sidereal_time = GreenwichMeanSiderealTime(tgps)
    # Breathing and longitudinal modes act on an L-shaped detector in the same
    # way up to a constant amplitude, so only the scalar factor is used and the
    # breathing factor is discarded. See Isi-Weinstein, arXiv 1710.03794.
    (f_plus, f_cross, _f_breathing,
     f_scalar, f_vec_x, f_vec_y) = ComputeDetAMResponseExtraModes(
        detector.response, ra, dec, psi, sidereal_time)
    return (f_scalar*hs + f_vec_x*hvx + f_vec_y*hvy
            + f_plus*hp + f_cross*hc)
def inner_product(C, X):
    """Return the quadratic form X . (C . X) computed with ``np.dot``."""
    weighted = np.dot(C, X)
    return np.dot(X, weighted)
def loglikelihood_core(residuals, inverse_covariance, log_normalisation):
    """Gaussian log-likelihood: -0.5 * r . (C^-1 . r) + log_normalisation."""
    chi_squared = inner_product(inverse_covariance, residuals)
    return -0.5*chi_squared + log_normalisation
def local_loglikelihood(model,
                        x,  # dict of par (not read in this function)
                        waveform_model,
                        ra,
                        dec,
                        psi,
                        t_start,
                        time_delay,
                        ref_det,
                        truncate,
                        duration_n):
    """Sum the Gaussian log-likelihood of each detector's residuals.

    For every detector ``d`` in ``model.detectors``, the time axis is shifted
    so that 0 sits at ``model.tevent + dt`` (``dt`` read from ``time_delay``
    under the key ``'<ref_det>_<d>'``). When ``truncate`` is true, only samples
    with shifted time >= ``t_start`` are kept, capped at ``duration_n`` samples.
    If ``waveform_model`` is given, its projected waveform is subtracted from
    the data; otherwise the data itself is the residual.

    Returns:
    --------
    float: sum over detectors of ``loglikelihood_core``.
    """
    total = 0.0
    for det_name in model.detectors.keys():
        det = model.detectors[det_name]

        # The waveform starts at time 0, which must line up with tevent + dt.
        shift = time_delay['{0}_'.format(ref_det) + det_name]
        tref = LIGOTimeGPS(t_start + shift + model.tevent)
        relative_time = det.time - (model.tevent + shift)

        if truncate:
            keep = relative_time >= t_start
            time_array = relative_time[keep][:duration_n]
            data = det.time_series[keep][:duration_n]
        else:
            time_array = relative_time
            data = det.time_series

        if waveform_model is None:
            residuals = data
        else:
            wf_model = waveform_model.waveform(time_array)
            hs, hvx, hvy, hp, hc = (wf_model[k] for k in range(5))
            residuals = data - project(hs, hvx, hvy, hp, hc, det.lal_detector,
                                       ra, dec, psi, tref)

        total += loglikelihood_core(residuals, det.inverse_covariance,
                                    det.log_normalisation)
    return total


class wheel():
    '''Reproduce the PyRing log-likelihood and SNR computations.

    The constructor caches, per detector, the cropped data and one or more
    inverse covariance matrices. Note that some methods read attributes this
    class never assigns itself:

    - ``logL``, ``*_cholesky`` and ``*_solvetoe`` read ``self.hstrain``
      (dict detector -> waveform array, e.g. the second output of
      ``get_hstrain``).
    - ``*_cholesky`` also reads ``self.c`` (dict detector -> covariance matrix).

    Assign these on the instance before calling those methods.
    '''

    # Default location of the ACF/PSD noise files (formerly hard-coded).
    _DEFAULT_NOISE_DIR = ('/work/yifan.wang/ringdown/GW150914/pyring/'
                          'compare-pyring-ringdown-pycbc/GW150914_PROD1_Kerr_221_0M/Noise')

    def __init__(self, model, fit=None, acf_from_psd=None, noise_dir=None):
        '''A wheel to reproduce the PyRing loglikelihood and SNR computation

        Parameters:
        -----------
        model: pyRing.KerrModel
            The KerrModel for pyring. ``model.log_likelihood`` is evaluated once
            here because ``model.time_delay`` is only available afterwards.
        fit: optional
            Object exposing ``fit.acfs[d].values`` for every detector ``d``.
            When given, the inverse of the Toeplitz covariance built from that
            ACF is stored in ``self.cinv_acf_from_ringdown[d]``.
        acf_from_psd: bool, optional
            If truthy, load ``ACF_TD_cropped_*`` and ``PSD_*`` text files from
            ``noise_dir``. The loaded ACF goes into ``self.acf[d]``. The inverse
            of the Toeplitz covariance derived from the PSD goes into
            ``self.cinv_acf_from_psd[d]``.
        noise_dir: str, optional
            Directory holding the noise files; defaults to
            ``wheel._DEFAULT_NOISE_DIR``.
        '''
        # Bind the model first: the PSD branch below needs model.srate and
        # every method reads self.model.
        self.model = model
        self.data = {}
        self.cinv = {}
        self.acf = {}
        self.acf_from_psd = {}
        self.cinv_acf_from_psd = {}
        self.cinv_acf_from_ringdown = {}
        self.L = {}
        self.white_d = {}
        self.have_run_model = 0
        if noise_dir is None:
            noise_dir = self._DEFAULT_NOISE_DIR

        # To activate model.time_delay we need to run loglikelihood first.
        pyring_par = {'Mf': 72.84400244,
                      'af': 0.72848178,
                      'A2220': 5.595449455544791,
                      'A2221': 6.457081321435618,
                      'phi2220': -0.9737754582321143,
                      'phi2221': 1.652119726595113}
        _ = model.log_likelihood(pyring_par)
        t_start = model.fixed_params['t']
        duration_n = model.duration_n

        for d in model.detectors.keys():
            dt = model.time_delay['{0}_'.format(model.ref_det) + d]
            # Crop: keep samples at/after t_start relative to tevent + dt,
            # capped at duration_n samples.
            time_array_raw = model.detectors[d].time - (model.tevent + dt)
            keep = time_array_raw >= t_start
            self.data[d] = model.detectors[d].time_series[keep][:duration_n]
            self.cinv[d] = model.detectors[d].inverse_covariance

            if acf_from_psd:
                self.acf[d] = np.loadtxt(os.path.join(
                    noise_dir,
                    'ACF_TD_cropped_' + str(d) + '_1126257414_4096_4.0_2048_0.2.txt'))
                psd = np.loadtxt(os.path.join(
                    noise_dir, 'PSD_' + str(d) + '_1126257414_4096_4.0_2048.txt'))
                # ACF from PSD column 1 (column 0 presumably frequency - confirm)
                # via an inverse real FFT, scaled by 0.5*srate.
                acf_psd = 0.5*np.fft.irfft(psd[:, 1]) * model.srate
                self.cinv_acf_from_psd[d] = inv(toeplitz(acf_psd[:duration_n]))

            if fit is not None:
                acf = fit.acfs[d].values[:duration_n]
                self.cinv_acf_from_ringdown[d] = inv(toeplitz(acf))

    def get_data(self, detector_time, detector_rawdata):
        '''Get the data corresponding to the waveform after correcting the labeling issue

        Parameters:
        -----------
        detector_time: dict
            Per-detector time stamps (by default, 4s duration)
        detector_rawdata: dict
            Per-detector strain samples aligned with detector_time (by default, 4s data)

        Return:
        -----------
        cropped_time: dict
            cropped time stamps corresponding to the waveform
        cropped_data: dict
            cropped data corresponding to the waveform
        '''
        cropped_time = {}
        cropped_data = {}

        t_start = self.model.fixed_params['t']
        duration_n = self.model.duration_n

        for d in self.model.detectors.keys():
            dt = self.model.time_delay['{0}_'.format(self.model.ref_det) + d]
            # Same comparison as get_hstrain / __init__: (t - (tevent+dt)) >= t_start.
            # The algebraically equal `t >= tevent + dt + t_start` can round
            # differently and shift the crop by one sample relative to the waveform.
            relative_time = detector_time[d] - (self.model.tevent + dt)
            lcrop = relative_time >= t_start
            cropped_time[d] = detector_time[d][lcrop][:duration_n]
            cropped_data[d] = detector_rawdata[d][lcrop][:duration_n]
        return cropped_time, cropped_data

    def get_hstrain(self, pyring_par, detector_time):
        '''Compute the projected waveform in every detector for given parameters

        Parameters:
        -----------
        pyring_par: dict
            waveform parameters
        detector_time: dict
            updated detector time after correcting the labeling issue for pyring

        Return:
        -----------
        waveform_time: dict
            time stamps of each projected waveform (shifted back by tevent + dt)
        h: dict
            projected waveform per detector
        '''
        h = {}
        waveform_time = {}

        # To activate model.time_delay we need to run loglikelihood first.
        _ = self.model.log_likelihood(pyring_par)
        waveform_model = self.model.get_waveform(pyring_par)

        t_start = self.model.fixed_params['t']
        ra = self.model.fixed_params['ra']
        dec = self.model.fixed_params['dec']
        psi = self.model.fixed_params['psi']
        duration_n = self.model.duration_n

        for d in self.model.detectors.keys():
            dt = self.model.time_delay['{0}_'.format(self.model.ref_det) + d]
            tref = LIGOTimeGPS(t_start + dt + self.model.tevent)

            # Crop the time axis exactly as get_data crops the data.
            time_array_raw = detector_time[d] - (self.model.tevent + dt)
            time_array = time_array_raw[time_array_raw >= t_start][:duration_n]

            wf_model = waveform_model.waveform(time_array)
            hs, hvx, hvy, hp, hc = wf_model[0], wf_model[1], wf_model[2], wf_model[3], wf_model[4]
            h[d] = project(hs, hvx, hvy, hp, hc, self.model.detectors[d].lal_detector,
                           ra, dec, psi, tref)
            waveform_time[d] = time_array + self.model.tevent + dt
        return waveform_time, h

    def _select_cinv(self, d, acf_from_psd, acf_from_ringdown):
        '''Return the inverse covariance for detector d chosen by the two flags.'''
        if acf_from_psd:
            return self.cinv_acf_from_psd[d]
        if acf_from_ringdown:
            return self.cinv_acf_from_ringdown[d]
        return self.cinv[d]

    def _combine_snr(self, snr_sq, network):
        '''Turn per-detector squared SNRs into the network SNR or a dict of SNRs.'''
        if network:
            total = 0
            for d in self.model.detectors.keys():
                total += snr_sq[d]
            return np.sqrt(total)
        return {d: np.sqrt(snr_sq[d]) for d in self.model.detectors.keys()}

    def logL(self):
        '''Log-likelihood of self.data given self.hstrain (must be set by the caller).'''
        loglikelihood = 0
        for d in self.model.detectors.keys():
            residuals = self.data[d] - self.hstrain[d]
            loglikelihood += loglikelihood_core(residuals, self.cinv[d],
                                                self.model.detectors[d].log_normalisation)
        return loglikelihood

    def optsnr(self, pyring_par, rawtime, network=True, acf_from_psd=False, acf_from_ringdown=False):
        '''
        Optimal SNR: sqrt(h . C^-1 . h), per detector or summed over the network.

        Parameters:
        -----------
        pyring_par: dict
            Source parameters
        rawtime: dict
            Per-detector time stamps passed to get_hstrain
        network: bool
            If true, return the network SNR; otherwise a dict of per-detector SNRs
        acf_from_psd: bool
            If true, use the inverse covariance built from the PSD
        acf_from_ringdown: bool
            If true (and acf_from_psd is false), use the inverse covariance
            built from the ringdown fit's ACF
        '''
        _, h = self.get_hstrain(pyring_par, rawtime)
        snr_sq = {}
        for d in self.model.detectors.keys():
            cinv = self._select_cinv(d, acf_from_psd, acf_from_ringdown)
            snr_sq[d] = np.dot(h[d], np.dot(cinv, h[d]))
        return self._combine_snr(snr_sq, network)

    def mfsnr(self, pyring_par, rawtime, rawdata, network=True, acf_from_psd=False, acf_from_ringdown=False):
        '''Matched-filter SNR: (d . C^-1 . h)^2 / (h . C^-1 . h), square-rooted

        Parameters:
        -----------
        pyring_par: dict
            Source parameters
        rawtime: dict
            The time duration for one chunk of signal
        rawdata: dict
            The data for one chunk of signal
        network, acf_from_psd, acf_from_ringdown:
            Same meaning as in optsnr
        '''
        _, h = self.get_hstrain(pyring_par, rawtime)
        _, data = self.get_data(rawtime, rawdata)

        snr_sq = {}
        for d in self.model.detectors.keys():
            cinv = self._select_cinv(d, acf_from_psd, acf_from_ringdown)
            snr_sq[d] = np.dot(data[d], np.dot(cinv, h[d]))**2 \
                / np.dot(h[d], np.dot(cinv, h[d]))
        return self._combine_snr(snr_sq, network)

    def optsnr_cholesky(self):
        '''
        Compute the network optimal SNR by whitening with the Cholesky factor of self.c.

        Reads self.c, self.data and self.hstrain; stores the Cholesky factors
        in self.L and the whitened data in self.white_d.
        '''
        snr = 0
        for d in self.model.detectors.keys():
            self.L[d] = np.linalg.cholesky(self.c[d])
            self.white_d[d] = np.linalg.solve(self.L[d], self.data[d])
            white_h = np.linalg.solve(self.L[d], self.hstrain[d])
            snr += np.sum(white_h**2)
        return np.sqrt(snr)

    def mfsnr_cholesky(self):
        '''
        Compute the network matched-filter SNR by whitening with the Cholesky factor of self.c.

        Reads self.c, self.data and self.hstrain; stores the Cholesky factors
        in self.L and the whitened data in self.white_d.
        '''
        snr = 0
        for d in self.model.detectors.keys():
            self.L[d] = np.linalg.cholesky(self.c[d])
            self.white_d[d] = np.linalg.solve(self.L[d], self.data[d])
            white_h = np.linalg.solve(self.L[d], self.hstrain[d])
            snr += np.sum(self.white_d[d]*white_h)**2 / np.sum(white_h*white_h)
        return np.sqrt(snr)

    def optsnr_solvetoe(self):
        '''
        Compute the network optimal SNR with scipy's solve_toeplitz on self.acf.

        Reads self.acf and self.hstrain.
        '''
        snr = 0
        for d in self.model.detectors.keys():
            snr += np.dot(self.hstrain[d], solve_toeplitz(self.acf[d], self.hstrain[d]))
        return np.sqrt(snr)

    def mfsnr_solvetoe(self):
        '''
        Compute the network matched-filter SNR with scipy's solve_toeplitz on self.acf.

        Reads self.acf, self.data and self.hstrain.
        '''
        snr = 0
        for d in self.model.detectors.keys():
            # C^-1 h is needed in both numerator and denominator; solve once.
            cinv_h = solve_toeplitz(self.acf[d], self.hstrain[d])
            snr += np.dot(self.data[d], cinv_h)**2 / np.dot(self.hstrain[d], cinv_h)
        return np.sqrt(snr)

def compute_multiple_snr(model, pr_time, pr_data, M, chi, A, phi,
                         fit=None, network=True, acf_from_psd=False, acf_from_ringdown=False,
                         n_samples=4000):
    '''
    Loop compute the optimal SNR and matched-filter SNR given a PyRing Model
    and parameters from Ringdown

    Parameters:
    -----------
    model: pyRing.KerrModel
    pr_time, pr_data: dict
        Per-detector time stamps and strain passed to wheel.optsnr / wheel.mfsnr
    M, chi:
        Indexable by sample; M[i].values / chi[i].values feed 'Mf' / 'af'
    A, phi:
        A[k][i].values / phi[k][i].values for k=0,1 feed the 'A222k' / 'phi222k'
        entries (amplitudes rescaled by sqrt(16*pi/5)/1e-21, phases negated)
    fit: optional
        Forwarded to wheel; needed when acf_from_ringdown is true
    network: bool
        If true, collect network SNRs; otherwise per-detector SNRs
    acf_from_psd, acf_from_ringdown: bool
        Select which inverse covariance wheel uses
    n_samples: int
        Number of samples to process (default 4000)

    Output:
    -------
    optsnr, mfsnr: lists of SNRs (network=True) or dicts mapping each
    detector to a list of SNRs (network=False)
    '''
    if network:
        optsnr = []
        mfsnr = []
    else:
        optsnr = {d: [] for d in model.detectors.keys()}
        mfsnr = {d: [] for d in model.detectors.keys()}

    # Forward acf_from_psd so the PSD-based inverse covariance is actually built.
    result = wheel(model, fit, acf_from_psd)
    prefactor = np.sqrt(16*np.pi/5)
    for i in tqdm(range(n_samples)):
        pyring_par = {'Mf': M[i].values,
                      'af': chi[i].values,
                      'A2220': A[0][i].values/1e-21*prefactor,
                      'A2221': A[1][i].values/1e-21*prefactor,
                      'phi2220': -phi[0][i].values,
                      'phi2221': -phi[1][i].values}

        # Compute each SNR once per sample; both take the time stamps (pr_time)
        # as their rawtime argument.
        opt = result.optsnr(pyring_par, pr_time, network, acf_from_psd, acf_from_ringdown)
        mf = result.mfsnr(pyring_par, pr_time, pr_data, network, acf_from_psd, acf_from_ringdown)

        if network:
            optsnr.append(opt)
            mfsnr.append(mf)
        else:
            for d in model.detectors.keys():
                optsnr[d].append(opt[d])
                mfsnr[d].append(mf[d])

    return optsnr, mfsnr

def plotsnr(rdsnr, prsnr, snr='Optimal SNR'):
    '''Plot Ringdown and PyRing SNRs (top panel) and their difference (bottom panel).

    Parameters:
    -----------
    rdsnr, prsnr: array-like
        SNR per sample from Ringdown and PyRing; plain lists are accepted.
    snr: str
        Label used in the legend entries and the title.
    '''
    fig_width_pt = 3*246.0  # Get this from LaTeX using \showthe\columnwidth
    inches_per_pt = 1.0/72.27               # Convert pt to inch
    golden_mean = (np.sqrt(5)-1.0)/2.0         # Aesthetic ratio
    fig_width = fig_width_pt*inches_per_pt  # width in inches
    fig_height = fig_width*golden_mean      # height in inches
    fig_size = [fig_width, fig_height]
    params = {'axes.labelsize': 24,
              'font.family': 'serif',
              'font.serif': 'Computer Modern Roman',
              'font.size': 24,
              'legend.fontsize': 20,
              'xtick.labelsize': 24,
              'ytick.labelsize': 24,
              'axes.grid': True,
              'text.usetex': True,
              'savefig.dpi': 100,
              'legend.frameon': True,
              'legend.loc': 'best',
              'lines.markersize': 14,
              'figure.figsize': fig_size}
    mpl.rcParams.update(params)

    # Subtraction below is elementwise only for arrays; lists would raise TypeError.
    rdsnr = np.asarray(rdsnr)
    prsnr = np.asarray(prsnr)

    # NOTE: this explicit figsize overrides the 'figure.figsize' rcParam set above.
    plt.figure(figsize=[16,10])
    plt.subplot(211)
    plt.plot(rdsnr,label='Ringdown '+snr)
    plt.plot(prsnr,alpha=0.5,label='PyRing '+snr)
    plt.title(snr,fontsize=20)
    plt.legend(loc='best',ncol=2)
    plt.subplot(212)
    plt.plot(rdsnr - prsnr,color='grey')
    plt.ylabel('Ringdown SNR - PyRing SNR')
    plt.xlabel('Number of samples')