# Copyright (C) 2021 Xisco Jimenez Forteza
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
# Module with auxiliary functions to run parameter estimation (PE) on ringdown (RD) data
import numpy as np
from numpy.fft import fft, ifft
from dynesty import utils as dyfunc
import os
import csv
import pandas as pd
import pickle

def posterior_samples(sampler):
    """
    Returns equally weighted posterior samples from the nested samples
    and weights given by the dynesty sampler.
    """

    dynesty_samples = sampler.results['samples']
    wt = np.exp(sampler.results['logwt'] -
                       sampler.results['logz'][-1])
    # Make sure that the weights sum to 1
    weights = wt/np.sum(wt)
    posterior_dynesty = dyfunc.resample_equal(dynesty_samples, weights)
    return posterior_dynesty
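
# Usage sketch (hedged; assumes a dynesty NestedSampler named `sampler` that has
# already been run, e.g. via sampler.run_nested()):
#
#   post = posterior_samples(sampler)           # array of shape (n_samples, n_dim)
#   medians = np.quantile(post, 0.5, axis=0)    # per-parameter medians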

def FFT_FreqBins(times):
    """ It returns the FFT frequency bins associated with a uniformly sampled time array.
    """
    Len = len(times)
    DeltaT = times[-1]- times[0]
    dt = DeltaT/(Len-1)
    dnu = 1/(Len*dt)
    maxfreq = 1/(2*dt)
    add = dnu/4

    p = np.arange(0.0,maxfreq+add,dnu)
    m = np.arange(p[-1]-(2*maxfreq)+dnu,-dnu/2+add,dnu)
    res=np.concatenate((p,m))
    
    return res
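
# FFT_FreqBins reproduces the usual discrete-Fourier-transform frequency ordering
# (non-negative frequencies first, then negative ones), analogous to
# np.fft.fftfreq(len(times), d=dt). A consistency check, assuming an odd number of
# uniformly spaced samples (for even lengths the Nyquist bin carries the opposite sign):
#
#   times = np.linspace(0.0, 10.0, 101)
#   np.allclose(FFT_FreqBins(times), np.fft.fftfreq(101, d=times[1]-times[0]))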

def hFromPsi4FFI(tpsi4,f0):
    """ It computes the strain h from psi4 data via fixed-frequency integration (FFI),
    flooring all frequencies below the cutoff f0.
    """
    timecheck1=tpsi4[-1,0]-tpsi4[-2,0]
    timecheck2=tpsi4[1,0]-tpsi4[0,0]

    if np.abs(timecheck1-timecheck2)>=0.0001:
        print("The data might not be equally sampled!!")

    times,data= tpsi4[:,0],tpsi4[:,1]

    freqs = FFT_FreqBins(times.real).real
    position = np.argmax(freqs >= f0)
    freqs[:position]=f0*np.ones(len(freqs[:position]))
    freqs=2*np.pi*freqs

    # Double time integration in the frequency domain: h(f) = -psi4(f)/omega^2
    fdata=fft(data)
    hdata=ifft(-fdata/freqs**2)

    return np.stack((times,hdata)).T
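
# Note on the fixed-frequency integration (FFI) used above: the strain is obtained
# from psi4 in the frequency domain through h(f) = -psi4(f)/omega^2 with
# omega = 2*pi*f, and frequencies below the cutoff f0 are floored to f0 to suppress
# spurious low-frequency drifts (the usual FFI prescription of Reisswig & Pollney).
# f0 is expected in the same frequency units as 1/dt of the input time array.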

    
def twopoint_autocovariance(t,n):
    """ It computes the two-point autocovariance function.
    """  
    dt=t[1]-t[0]
    res = np.zeros(len(n))
    taus = np.zeros(len(n))
    for tau in range(0,int(len(n)/2)):
        ntau=np.roll(n, tau)
        taus[tau] = t[tau]
        res[tau]=np.sum(n*ntau).real
    return (taus[:int(len(n)/2)],res[:int(len(n)/2)])
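
# Usage sketch (hedged; `t` and `n` are a hypothetical, uniformly sampled,
# real-valued noise segment): the returned pair (taus, res) contains the
# unnormalised circular autocovariance at lags up to half the segment length,
# with res[0] = sum(n**2), so res/res[0] gives a normalised autocorrelation estimate:
#
#   taus, res = twopoint_autocovariance(t, n)
#   acf = res/res[0]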

def save_object(obj, filename):
    with open(filename, 'wb') as output:  # Overwrites any existing file.
        pickle.dump(obj, output, pickle.HIGHEST_PROTOCOL)


def EasyMatchT(t,h1,h2,tmin,tmax):
    """ It computes the time-domain match for (h1|h2) complex waveforms.
    Note that only tmin is used: the overlap is taken from tmin to the end of the array.
    """
    pos = np.argmax(t >= (tmin))

    h1red=h1[pos:]
    h2red=h2[pos:]

    norm1=np.sum(np.abs(h1red)**2)
    norm2=np.sum(np.abs(h2red)**2)

    myTable=h1red*np.conjugate(h2red)
    res=((np.sum(myTable)/np.sqrt(norm1*norm2))).real

    return res
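
# The quantity returned above is the standard time-domain overlap
# Re[sum(h1*conj(h2))]/sqrt(sum|h1|^2 * sum|h2|^2) restricted to t >= tmin, which
# equals 1 for identical waveforms. A quick check with synthetic data (illustrative only):
#
#   t = np.linspace(0, 10, 1001)
#   h = np.exp(-t/5)*np.exp(2j*np.pi*t)
#   EasyMatchT(t, h, h, 0, 10)    # -> 1.0 up to floating-point error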

def EasySNRT(t,h1,h2,tmin,tmax):
    """ It computes the time-domain snr for (h1|h2) complex waveforms.
    As in EasyMatchT, only tmin is used.
    """
    pos = np.argmax(t >= (tmin))

    h1red=h1[pos:]
    h2red=h2[pos:]

    myTable=h1red*np.conjugate(h2red)
    res=2*np.sqrt((np.sum(myTable)).real)

    return res
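
# For identical inputs, EasySNRT(t, h, h, tmin, tmax) reduces to
# 2*sqrt(sum_{t >= tmin} |h(t)|^2), i.e. twice the discrete norm of the truncated
# waveform; as in EasyMatchT, tmax is accepted but not used.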
    
def FindTmaximum(y):
    """ It determines the time at which the amplitude of the complex waveform peaks.
    """
    absval = np.sqrt(y[:,1]*y[:,1]+y[:,2]*y[:,2])
    index = np.argmax(absval)
    timemax=y[index,0]

    return timemax

def export_logz_files(output_file,pars):
    """
    Export the evidence results to the logz .csv file given by output_file,
    appending to it if it already exists and writing the header row otherwise.
    """
    sim_num, nmax, tshift, evidence, evidence_error  = pars

    summary_titles=['n','id','t_shift','dlogz','dlogz_err']
    if os.path.exists(output_file):
        outvalues = np.array([[nmax, sim_num, tshift, evidence,evidence_error]])
    else:
        outvalues = np.array([summary_titles,[nmax, sim_num, tshift, evidence,evidence_error]])

    with open(output_file, 'a') as file:
        writer = csv.writer(file)
        if (outvalues.shape)[0]>1 :
            writer.writerows(outvalues)
        else:
            writer.writerow(outvalues[0])

    return
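
# Example of the resulting logz .csv layout (hypothetical values); the header row
# is written only when the file is first created:
#
#   n,id,t_shift,dlogz,dlogz_err
#   1,2,0.0,1234.5,0.02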

def export_bestvals_files(best_data_file,postsamps,pars):

    tshift, lenpriors, labels = pars

    sigma_vars_m = np.empty(lenpriors)
    sigma_vars_p = np.empty(lenpriors)
    sigma_vars = np.empty(lenpriors)
    sigma_vars_ml = np.empty(lenpriors)
    for i in range(lenpriors):
        amps_aux = postsamps[:,i]
        sigma_vars_m[i] = np.quantile(amps_aux, 0.05)
        sigma_vars[i] = np.quantile(amps_aux, 0.5)
        sigma_vars_ml[i] = postsamps[-1,i]
        sigma_vars_p[i] = np.quantile(amps_aux, 0.95)

    sigma_vars_all = np.stack([sigma_vars,sigma_vars_ml,sigma_vars_m,sigma_vars_p], axis=0)

    key =['max val','max val ml','lower bound','higher bound']
    dfslist = [pd.DataFrame(np.concatenate(([tshift],sigma_vars_all[i])).reshape((-1,lenpriors+1)), columns=np.concatenate((['tshift'],labels)), index = [key[i]]) for i in range(4)]
    df2 = pd.concat(dfslist)
    if os.path.exists(best_data_file):
        df2.to_csv(best_data_file, mode='a', header=False,index = True)
    else:
        df2.to_csv(best_data_file, index = True)
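
# Each call to export_bestvals_files appends four rows to best_data_file: the median
# ('max val'), the last posterior sample ('max val ml'), and the 5% / 95% quantiles
# ('lower bound' / 'higher bound'), each prefixed with the time shift. A hedged usage
# sketch, assuming `postsamps` comes from posterior_samples() and `labels` from
# define_labels():
#
#   export_bestvals_files('best_values.csv', postsamps,
#                         [tshift, postsamps.shape[1], labels])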


            
def define_labels(dim,model,fitnoise):
    wstr = r'$\omega_'

    if model == 'w-tau':
        taustr = r'$\tau_'
    elif model == 'w-q':
        taustr = r'$q_'
    elif model == 'w-tau-fixed':
        taustr = r'$dumb_var}'  # placeholder; tau labels are not used for this model
    elif model == 'w-tau-fixed-m-af':
        taustr = r'$\tau_'

    ampstr = r'$A_'
    phasestr =  r'$\phi_'

    w_lab = [None] * dim
    tau_lab = [None] * dim
    amp_lab =  [None] * dim
    pha_lab =  [None] * dim
    mass_lab =  ['mass']
    spin_lab  =  ['spin']

    for i in range(dim):
        w_lab[i] = wstr+str(i)+'$'
        tau_lab[i] = taustr+str(i)+'$'
        amp_lab[i] = ampstr+str(i)+'$'
        pha_lab[i] = phasestr+str(i)+'$'


    labels = np.concatenate((w_lab,tau_lab,amp_lab,pha_lab))

    if model=='w-tau-fixed':
        labels = np.concatenate((amp_lab,pha_lab))

    if model=='w-tau-fixed-m-af':
        labels = np.concatenate((amp_lab,pha_lab,mass_lab,spin_lab))

    if fitnoise:
        noise_lab = ['noise']
        labels = np.concatenate((labels,noise_lab))
    
    return labels
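
# Example (for illustration): define_labels(2, 'w-tau', False) returns the label array
#   ['$\omega_0$', '$\omega_1$', '$\tau_0$', '$\tau_1$',
#    '$A_0$', '$A_1$', '$\phi_0$', '$\phi_1$'],
# i.e. frequencies, damping times, amplitudes and phases for modes 0 and 1.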

def get_truths(model,pars,fitnoise):
    w, tau, mf, af , npamps = pars
    if model == 'w-q':
        tau_val = np.pi*w*tau
        truths = np.concatenate((w,tau_val,npamps))
    elif model == 'w-tau':
        tau_val = tau
        truths = np.concatenate((w,tau_val,npamps))
    elif model == 'w-tau-fixed':
        truths = npamps
    elif model == 'w-tau-fixed-m-af':
        truths = np.concatenate((npamps,[mf],[af]))

    if fitnoise:
        truths = np.concatenate((truths,[1]))
        
    return truths

def get_best_amps(pars,parser=None,nr_code=None):
    nmax,model,samps_tr,half_points = pars

    if model=='w-tau-fixed':
        rg = (nmax+1)
    elif model=='w-tau-fixed-m-af':
        rg = (nmax+1)+2
    else:
        rg = (nmax+1)*2

    if model=='w-tau-fixed-m-af':
        npamps = np.empty((nmax+1))
        for i in range(0,(nmax+1)):
            amps_aux = samps_tr[i+rg][half_points:-1]
            npamps[i] = np.quantile(amps_aux, 0.5)
    else :
        npamps = np.empty((nmax+1)*2)
        for i in range(0,(nmax+1)*2):
            amps_aux = samps_tr[i][half_points:-1]
            npamps[i] = np.quantile(amps_aux, 0.5)

    if nr_code == 'Mock-data':
        nm_mock = parser.get('rd-mock-parameters','nm_mock')
        nm_mock = int(nm_mock)
        amp_mock=np.empty(nm_mock+1)
        ph_mock=np.empty(nm_mock+1)
        for i in range(nm_mock+1):
            amp_mockp = parser.get('rd-mock-parameters','amp'+str(i))
            amp_mock[i] = float(amp_mockp)
            ph_mockp=parser.get('rd-mock-parameters','phase'+str(i))
            ph_mock[i] = float(ph_mockp)

        npamps = np.concatenate((amp_mock,ph_mock))
    return npamps

def convert_m_af_2_w_tau_post(res,fitnoise=False):

    samps=res.samples

    if fitnoise:
        fmass_spin=(samps.T)[-3:-1].T
    else:
        fmass_spin=(samps.T)[-2:].T
    fmass_spin_dist=[None]*len(fmass_spin)
    weight=np.exp(res.logwt - res.logz[-1])
    for i in range(len(fmass_spin)):
        fmass_spin_dist[i]=np.concatenate(dict_omega[qnm_model](fmass_spin[i,0],fmass_spin[i,1],2,2))

    fmass_spin_dist_v2=np.asarray(fmass_spin_dist)
    new_samples = dyfunc.resample_equal(fmass_spin_dist_v2, weight)

    return new_samples
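
# Hedged usage note: given a dynesty results object `res` from a 'w-tau-fixed-m-af'
# run, convert_m_af_2_w_tau_post(res) resamples the final-mass/final-spin posterior
# into (omega, tau) posteriors. It relies on a module-level mapping `dict_omega` and
# model key `qnm_model` (expected to be provided elsewhere in the package) that return
# the (2,2) QNM frequencies and damping times for a given remnant mass and spin.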

def rm_files(files):
    """ Remove all the listed files, if they exist. """
    for i in files:
        if os.path.exists(i):
            os.remove(i)