#!/usr/bin/env python
# coding: utf-8

# In[63]:


#Import relevant modules, import data and all that
import numpy as np
from scipy import interpolate
import corner
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import rc
from configparser import ConfigParser
import codecs
#plt.rcParams['font.family'] = 'DejaVu Sans'
#rc('text', usetex=True)
plt.rcParams.update({'font.size': 16.5})

import ptemcee
#from pycbc.pool import choose_pool
from multiprocessing import Pool
import h5py
import inspect
import pandas as pd
import json
import qnm
import random
import dynesty
from dynesty import plotting as dyplot
import os
import csv
import argparse
import scipy.optimize as optimization
from scipy.optimize import minimize


#Remember to change the following global variables
#rootpath: root path to nr data
#npoints: number of live points you're using for your sampling
#nmax: tone index --> nmax = 0 if fitting the fundamental tone
#tshift: time shift after the strain peak
#vary_fund: whether you vary the fundamental frequency. Works in the model_dv function.

# Resolve the configuration file: "-c <file>" on the command line, otherwise
# config.ini in the current working directory.
DEFAULT_CONFIG_FILE = 'config.ini'
try:
    arg_parser = argparse.ArgumentParser(description="Simple argument parser")
    # The default keeps config_file from being None when -c is omitted
    # (ConfigParser.read(None) raises TypeError).
    arg_parser.add_argument("-c", action="store", dest="config_file",
                            default=DEFAULT_CONFIG_FILE)
    result = arg_parser.parse_args()
    config_file = result.config_file
except SystemExit:
    # argparse exits on arguments it does not recognise; in that case fall
    # back to the default file instead of terminating the script.
    config_file = DEFAULT_CONFIG_FILE

parser = ConfigParser()
parser.read(config_file)


# In[36]:


# Paths and identifiers read from the [nr-paths] section
rootpath = parser.get('nr-paths', 'rootpath')

simulation_path_1 = parser.get('nr-paths', 'simulation_path_1')
simulation_path_2 = parser.get('nr-paths', 'simulation_path_2')
metadata_file = parser.get('nr-paths', 'metadata_json')
# np.int was only an alias of the builtin int and was removed in NumPy 1.24
simulation_number = int(parser.get('nr-paths', 'simulation_number'))

output_folder = parser.get('output-folder', 'output-folder')
overwrite = parser.get('overwrite', 'overwrite')


# In[37]:


# Create the output folder on first use. makedirs also creates missing parent
# directories, which os.mkdir does not.
if not os.path.exists(output_folder):
    os.makedirs(output_folder, exist_ok=True)
    print("Directory " , output_folder ,  " Created ")


# In[38]:


# time config, read from the [time-setup] section
# (np.float was an alias of the builtin float and was removed in NumPy 1.24)

# start of the fitting window
tshift = float(parser.get('time-setup', 'tshift'))

# end of the fitting window
tend = float(parser.get('time-setup', 'tend'))

# time at which the two waveforms are cut and phase-aligned
t_align = float(parser.get('time-setup', 't_align'))


# In[39]:


# n-tones & nlive
# (np.int was an alias of the builtin int and was removed in NumPy 1.24)

# highest tone index; nmax + 1 tones are fitted
nmax = int(parser.get('n-tones', 'nmax'))

# number of live points handed to the nested sampler
npoints = int(parser.get('n-live-points', 'npoints'))


# In[40]:


# model
model = parser.get('rd-model', 'model')

# Name of the damping variable whose prior sections are read from the config
# file for each supported model.
_TAU_VAR_BY_MODEL = {'w-tau': 'tau', 'w-q': 'q', 'w-tau-fixed': 'q'}
if model not in _TAU_VAR_BY_MODEL:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        "Unknown rd-model '%s'; expected one of %s"
        % (model, sorted(_TAU_VAR_BY_MODEL))
    )
tau_var_str = _TAU_VAR_BY_MODEL[model]

print('model:',model)
print('nmax',nmax)


# In[41]:


# Per-run sub-folder: <output_folder>/<model>-nmax<nmax>.
# makedirs also creates output_folder itself if it is still missing.
output_folder_1=output_folder+'/'+model+'-nmax'+str(nmax)
if not os.path.exists(output_folder_1):
    os.makedirs(output_folder_1, exist_ok=True)
    print("Directory " , output_folder_1 ,  " Created ")


# In[43]:


# loading priors
def _prior_bound(name, index, bound):
    """Return option '<name><index>_<bound>' of section 'prior-<name><index>' as a float."""
    key = name + str(index)
    # builtin float replaces np.float, which was removed in NumPy 1.24
    return float(parser.get('prior-' + key, key + '_' + bound))


# One lower/upper bound per tone for each parameter family
w_mins=np.empty(nmax+1)
w_maxs=np.empty(nmax+1)
tau_mins=np.empty(nmax+1)
tau_maxs=np.empty(nmax+1)
a_mins=np.empty(nmax+1)
a_maxs=np.empty(nmax+1)
ph_mins=np.empty(nmax+1)
ph_maxs=np.empty(nmax+1)

for i in range(nmax+1):
    w_mins[i] = _prior_bound('w', i, 'min')
    w_maxs[i] = _prior_bound('w', i, 'max')

    # tau_var_str selects the 'tau' or the 'q' prior sections
    tau_mins[i] = _prior_bound(tau_var_str, i, 'min')
    tau_maxs[i] = _prior_bound(tau_var_str, i, 'max')

    a_mins[i] = _prior_bound('amp', i, 'min')
    a_maxs[i] = _prior_bound('amp', i, 'max')

    # phase bounds are given in the config file in units of 2*pi
    ph_mins[i] = _prior_bound('phase', i, 'min')*2*np.pi
    ph_maxs[i] = _prior_bound('phase', i, 'max')*2*np.pi

if model == 'w-tau-fixed':
    # frequencies and damping times are not sampled: only amplitudes and phases
    priors_min = np.concatenate((a_mins,ph_mins))
    priors_max = np.concatenate((a_maxs,ph_maxs))
else:
    priors_min = np.concatenate((w_mins,tau_mins,a_mins,ph_mins))
    priors_max = np.concatenate((w_maxs,tau_maxs,a_maxs,ph_maxs))
prior_dim = len(priors_min)


# In[44]:


vary_fund = True

# sampler parameters
dim = nmax + 1      # number of tones
ndim = 4 * dim      # four parameters per tone

# plotting parameters
numbins = 32            # corner plot parameter - how many bins you want
datacolor = '#105670'   # alternative: '#4fa3a7'
pkcolor = '#f2c977'     # alternative: '#ffb45f'
mediancolor = '#f7695c' # alternative: '#9b2814'

#Import data and necessary functions
#TimeOfMaximum
def FindTmaximum(y):
    """Return the time at which the complex waveform has its largest modulus.

    y is indexed as a 2-D table: column 0 is the time, columns 1 and 2 are
    the two components whose quadrature sum is maximised. When the maximum
    occurs more than once, the first occurrence is returned.
    """
    first_comp = y[:, 1]
    second_comp = y[:, 2]
    modulus = np.sqrt(first_comp*first_comp + second_comp*second_comp)
    # argmax on a boolean mask gives the first row that reaches the maximum
    peak_row = np.argmax(modulus == np.max(modulus))
    return y[peak_row, 0]


def EasyMatchT(t,h1,h2,tmin,tmax):
    """Return the normalised match of two complex waveforms on [tmin, tmax].

    The match is Re( sum(h1 * conj(h2)) / sqrt(sum|h1|^2 * sum|h2|^2) ),
    where the sums run over the samples with tmin <= t <= tmax.

    Parameters
    ----------
    t : array of sample times shared by h1 and h2.
    h1, h2 : complex waveforms sampled on t.
    tmin, tmax : inclusive bounds of the comparison window.

    Raises
    ------
    ValueError
        If no sample of t lies inside [tmin, tmax].
    """
    t = np.asarray(t)
    # Honour both bounds; tmax used to be accepted but ignored.
    window = (t >= tmin) & (t <= tmax)
    if not np.any(window):
        raise ValueError(
            "no samples in the window [%s, %s]" % (tmin, tmax)
        )

    h1red = np.asarray(h1)[window]
    h2red = np.asarray(h2)[window]

    norm1 = np.sum(np.abs(h1red)**2)
    norm2 = np.sum(np.abs(h2red)**2)

    overlap = np.sum(h1red*np.conjugate(h2red))
    return (overlap/np.sqrt(norm1*norm2)).real

def wRD_to_f_Phys(f,M):
    """Scale f by c**3 / (2*pi*G*M*MS).

    c, G and MS are the constants hard-coded below; M multiplies MS.
    NOTE(review): presumably converts a dimensionless angular frequency to Hz
    for a mass of M solar masses - confirm against the callers.
    """
    light_speed = 2.99792458*10**8
    grav_const = 6.67259*10**(-11)
    solar_mass = 1.9885*10**30
    scale = light_speed**3/(M*solar_mass*grav_const*2*np.pi)
    return scale*f

def tauRD_to_t_Phys(tau,M):
    """Scale tau by G*M*MS / c**3.

    c, G and MS are the constants hard-coded below; M multiplies MS.
    NOTE(review): presumably converts a dimensionless time to seconds for a
    mass of M solar masses - confirm against the callers.
    """
    light_speed = 2.99792458*10**8
    grav_const = 6.67259*10**(-11)
    solar_mass = 1.9885*10**30
    scale = (M*solar_mass*grav_const)/light_speed**3
    return scale*tau


# In[45]:


# Open the (l=2, m=2) mode of the first waveform file
gw = {simulation_number: h5py.File(simulation_path_1, 'r')}
gw_sxs_bbh_0305 = gw[simulation_number]["Extrapolated_N3.dir"]["Y_l2_m2.dat"]

# Same mode from the second waveform file
gw5 = {simulation_number: h5py.File(simulation_path_2, 'r')}
gw5_sxs_bbh_0305 = gw5[simulation_number]["Extrapolated_N3.dir"]["Y_l2_m2.dat"]

# The metadata JSON of the simulation provides the remnant mass and spin.
metadata = {}
with open(metadata_file) as file:
    metadata[simulation_number] = json.load(file)

# last component of the remnant spin vector and the remnant mass
af = metadata[simulation_number]['remnant_dimensionless_spin'][-1]
mf = metadata[simulation_number]['remnant_mass']

# Time axis of the first waveform, shifted so that its peak sits at t = 0
tmax = FindTmaximum(gw_sxs_bbh_0305)
times = gw_sxs_bbh_0305[:,0] - tmax

# Time axis of the second waveform, shifted the same way
tmax5 = FindTmaximum(gw5_sxs_bbh_0305)
times5 = gw5_sxs_bbh_0305[:,0] - tmax5


# In[46]:


#Select the data from t_align onwards
# First sample at or after t_align in each waveform
position = np.argmax(times >= (t_align))
position5 = np.argmax(times5 >= (t_align))

# Samples and time stamps are taken from the SAME rows. The data used to
# start at row position+1 while the time stamps started at row position,
# which attached every sample to the time of the preceding one.
# (The final row is still dropped, as before.)
gw_sxs_bbh_0305rd=gw_sxs_bbh_0305[position:-1]
gw_sxs_bbh_0305rd5=gw5_sxs_bbh_0305[position5:-1]
timesrd=gw_sxs_bbh_0305rd[:,0]-tmax
timesrd5=gw_sxs_bbh_0305rd5[:,0]-tmax5


# In[47]:


#Test plot real part (data was picked in the last cell). Aligning in time
plt.figure(figsize = (12, 8))
plt.plot(timesrd, gw_sxs_bbh_0305rd[:,1], "r", alpha=0.3, lw=3, label=r'$Lev6$: real')
# The red amplitude curve belongs to the Lev6 data; it was labelled Lev5.
plt.plot(timesrd, np.sqrt(gw_sxs_bbh_0305rd[:,1]**2+gw_sxs_bbh_0305rd[:,2]**2), "r", alpha=0.3, lw=3, label=r'$Lev6\,amp$')
plt.plot(timesrd5, gw_sxs_bbh_0305rd5[:,1], "b", alpha=0.3, lw=3, label=r'$Lev5: real$')
plt.plot(timesrd5, np.sqrt(gw_sxs_bbh_0305rd5[:,1]**2+gw_sxs_bbh_0305rd5[:,2]**2), "b", alpha=0.3, lw=3, label=r'$Lev5\,amp$')
plt.legend()


# In[48]:


#Test plot im part (data was picked in the last cell). Aligning in time
plt.figure(figsize = (12, 8))
plt.plot(timesrd, gw_sxs_bbh_0305rd[:,2], "r", alpha=0.3, lw=3, label=r'$Lev6: imag$')
# The red amplitude curve belongs to the Lev6 data; it was labelled Lev5.
plt.plot(timesrd, np.sqrt(gw_sxs_bbh_0305rd[:,1]**2+gw_sxs_bbh_0305rd[:,2]**2), "r", alpha=0.3, lw=3, label=r'$Lev6\,amp$')
plt.plot(timesrd5, gw_sxs_bbh_0305rd5[:,2], "b", alpha=0.3, lw=3, label=r'$Lev5: imag$')
plt.plot(timesrd5, np.sqrt(gw_sxs_bbh_0305rd5[:,1]**2+gw_sxs_bbh_0305rd5[:,2]**2), "b", alpha=0.3, lw=3, label=r'$Lev5\,amp$')
plt.legend()


# In[49]:


# One complex QNM frequency per tone n = 0 .. dim-1 of the (s, l, m) = (-2, 2, 2)
# mode, evaluated at spin af with the qnm package.
omegas = []
for n_tone in range(0, dim):
    mode_sequence = qnm.modes_cache(s=-2, l=2, m=2, n=n_tone)
    omegas.append(mode_sequence(a=af)[0])

# Real part -> frequencies, imaginary part -> damping times, rescaled with mf
w = np.real(omegas)/mf
tau = -1/np.imag(omegas)*mf


# In[50]:


# Cubic interpolants of columns 1 and 2 of each waveform on its own time axis
re_samples, im_samples = gw_sxs_bbh_0305rd[:,1], gw_sxs_bbh_0305rd[:,2]
gwnew_re = interpolate.interp1d(timesrd, re_samples, kind='cubic')
gwnew_im = interpolate.interp1d(timesrd, im_samples, kind='cubic')

re_samples5, im_samples5 = gw_sxs_bbh_0305rd5[:,1], gw_sxs_bbh_0305rd5[:,2]
gwnew_re5 = interpolate.interp1d(timesrd5, re_samples5, kind='cubic')
gwnew_im5 = interpolate.interp1d(timesrd5, im_samples5, kind='cubic')


# In[51]:


# Use the time grid that ends no later than the other one as the common grid
timesrd_final = timesrd if timesrd5[-1] >= timesrd[-1] else timesrd5

# Both waveforms evaluated on the common grid
gwdatanew_re, gwdatanew_im = gwnew_re(timesrd_final), gwnew_im(timesrd_final)
gwdatanew_re5, gwdatanew_im5 = gwnew_re5(timesrd_final), gwnew_im5(timesrd_final)

# Complex waveforms, built as (column 1) - 1j*(column 2)
gwdatanew = gwdatanew_re - 1j*gwdatanew_im
gwdatanew5 = gwdatanew_re5 - 1j*gwdatanew_im5


# In[52]:


# Mismatch between the two waveforms before the phase alignment
match_before = EasyMatchT(timesrd_final, gwdatanew, gwdatanew5, 0, 0+90)
mismatch = 1 - match_before
error = np.sqrt(2*mismatch)
print(mismatch)


# In[53]:


# Phase alignment: unwrapped phase of both waveforms before aligning
phas = np.unwrap(np.angle(gwdatanew))
phas5 = np.unwrap(np.angle(gwdatanew5))
for phase_curve, curve_color in ((phas, "r"), (phas5, "blue")):
    plt.plot(timesrd_final, phase_curve, curve_color, alpha=0.3, lw=3, label=r'$phase$')


# In[54]:


# Phase offset between the two waveforms at the first sample with t >= t_align
position = np.argmax(timesrd_final >= (t_align))
dphase = phas5[position] - phas[position]
print(dphase)

# Rotate the first waveform by that constant phase
gwdatanew = (gwdatanew_re - 1j*gwdatanew_im)*np.exp(1j*dphase)

# Unwrapped phases again, now after the rotation
phas = np.unwrap(np.angle(gwdatanew))
phas5 = np.unwrap(np.angle(gwdatanew5))
for phase_curve, curve_color in ((phas, "r"), (phas5, "blue")):
    plt.plot(timesrd_final, phase_curve, curve_color, alpha=0.3, lw=3, label=r'$phase$')


# In[55]:


# Mismatch after the phase alignment
match_after = EasyMatchT(timesrd_final, gwdatanew, gwdatanew5, 0, +90)
mismatch = 1 - match_after
print(mismatch)

# Complex square root of (gwdatanew - gwdatanew5)**2, written out term by term
error = np.sqrt(gwdatanew*gwdatanew-2*gwdatanew*gwdatanew5+gwdatanew5*gwdatanew5)


# In[56]:


# Real parts of both interpolated waveforms together with the error estimate
plt.figure(figsize=(12, 8))
curves = (
    (gwdatanew.real, "r", 'Lev6'),
    (gwdatanew5.real, "b", 'Lev5'),
    (error.real, "b", 'error'),
)
for curve_values, curve_color, curve_label in curves:
    plt.plot(timesrd_final, curve_values, curve_color, alpha=0.3, lw=2, label=curve_label)
plt.legend()


# In[57]:


# Real part, imaginary part and modulus of the error estimate
plt.figure(figsize=(12, 8))
error_modulus = np.sqrt(error.imag**2+error.real**2)
curves = (
    (error.real, "b", 'error real'),
    (error.imag, "r", 'error imag'),
    (error_modulus, "r", 'all error'),
)
for curve_values, curve_color, curve_label in curves:
    plt.plot(timesrd_final, curve_values, curve_color, alpha=0.3, lw=2, label=curve_label)
plt.legend()


# In[58]:


#Take the piece of waveform you want
def _first_index_at_or_after(time_grid, value):
    """Index of the first sample with time >= value, or len(time_grid) if none."""
    hits = np.flatnonzero(time_grid >= value)
    # np.argmax on an all-False mask returns 0, which silently produced an
    # empty (tend beyond the data) or misplaced (tshift beyond the data) window.
    return hits[0] if hits.size else len(time_grid)


position_in = _first_index_at_or_after(timesrd_final, tshift)
position_end = _first_index_at_or_after(timesrd_final, tend)

# Window [tshift, tend) of the time grid, the data columns and the error
timesrd_final_tsh = timesrd_final[position_in:position_end]
gwdatanew_re_tsh = gwdatanew_re[position_in:position_end]
gwdatanew_im_tsh = gwdatanew_im[position_in:position_end]
error_tsh=error[position_in:position_end]


# In[59]:


#Fitting
#RD model for nmax tones. Amplitudes are in (xn*Exp[i yn]) version. Used here.
def model_dv_q(theta):
    """Ringdown ansatz parametrised by frequencies and quality factors.

    theta is laid out as [w_0..w_n, q_0..q_n, x_0..x_n, y_0..y_n] with
    n + 1 == dim. Tone i contributes
        x_i*exp(1j*y_i) * exp(-t*pi*w_i/q_i) * (cos(w_i*t) - 1j*sin(w_i*t))
    evaluated on the module-level grid timesrd_final_tsh.
    """
    # Your nmax might not align with the dim of theta. Better check it here.
    assert int(len(theta)/4) == dim, 'Please recheck your n and parameters'

    freqs = theta[:dim]
    qfactors = theta[dim:2*dim]
    amps = theta[2*dim:3*dim]
    phases = theta[3*dim:]
    t = timesrd_final_tsh

    ansatz = 0
    for k in range(dim):
        complex_amp = amps[k]*np.exp(1j*phases[k])
        envelope = np.exp(-t*np.pi*freqs[k]/qfactors[k])
        # -1j to agree with SXS convention
        oscillation = np.cos(freqs[k]*t) - 1j*np.sin(freqs[k]*t)
        ansatz += complex_amp*envelope*oscillation
    return ansatz

def model_dv_tau(theta):
    """Ringdown ansatz parametrised by frequencies and damping times.

    theta is laid out as [w_0..w_n, tau_0..tau_n, x_0..x_n, y_0..y_n] with
    n + 1 == dim. Tone i contributes
        x_i*exp(1j*y_i) * exp(-t/tau_i) * (cos(w_i*t) - 1j*sin(w_i*t))
    evaluated on the module-level grid timesrd_final_tsh.
    """
    # Your nmax might not align with the dim of theta. Better check it here.
    assert int(len(theta)/4) == dim, 'Please recheck your n and parameters'

    freqs = theta[:dim]
    damping_times = theta[dim:2*dim]
    amps = theta[2*dim:3*dim]
    phases = theta[3*dim:]
    t = timesrd_final_tsh

    ansatz = 0
    for k in range(dim):
        complex_amp = amps[k]*np.exp(1j*phases[k])
        envelope = np.exp(-t/damping_times[k])
        # -1j to agree with SXS convention
        oscillation = np.cos(freqs[k]*t) - 1j*np.sin(freqs[k]*t)
        ansatz += complex_amp*envelope*oscillation
    return ansatz

def model_dv(theta):
    """Ringdown ansatz with frequencies and damping times held fixed.

    theta is laid out as [x_0..x_n, y_0..y_n] with n + 1 == dim; w and tau
    are the module-level arrays. Tone i contributes
        x_i*exp(1j*y_i) * exp(-t/tau[i]) * (cos(w[i]*t) - 1j*sin(w[i]*t))
    evaluated on the module-level grid timesrd_final_tsh.
    """
    amps = theta[:dim]
    phases = theta[dim:2*dim]
    t = timesrd_final_tsh

    ansatz = 0
    for k in range(dim):
        complex_amp = amps[k]*np.exp(1j*phases[k])
        envelope = np.exp(-t/tau[k])
        # -1j to agree with SXS convention
        oscillation = np.cos(w[k]*t) - 1j*np.sin(w[k]*t)
        ansatz += complex_amp*envelope*oscillation
    return ansatz

# Logprior distribution. It defines the allowed range my variables can vary over. 
#It works for the (xn*Exp[iyn]) version. 

def prior_transform(cube):
    """Rescale the first prior_dim entries of cube to the prior box.

    Entry k is mapped linearly from [0, 1] onto
    [priors_min[k], priors_max[k]]. cube is modified in place and returned.
    """
    for k in range(prior_dim):
        lower = priors_min[k]
        width = priors_max[k] - lower
        cube[k] = lower + cube[k]*width
    return cube

# LogLikelihood function. It is just a Gaussian loglikelihood based on computing the residuals^2
def log_likelihood(theta):
    """Gaussian log-likelihood of the windowed data given parameters theta.

    The squared residuals of both data columns are weighted by
    1 / (2*|error_tsh|^2) and summed; a NaN total is reported as -inf.
    """
    modelev = dict[model](theta)
    resid_re = gwdatanew_re_tsh - modelev.real
    resid_im = gwdatanew_im_tsh - modelev.imag
    variance = error_tsh.real**2+error_tsh.imag**2
    result = -np.sum((resid_re**2+resid_im**2)/(2*variance))
    return -np.inf if np.isnan(result) else result


# Logposterior distribution for the residuals case.
# The evidence is just a normalization factor
def log_probability(theta):
    """Return log-prior plus log-likelihood, or -inf outside the prior support.

    NOTE(review): log_prior is not defined in the visible part of this file,
    so calling this function would raise NameError unless it is provided
    elsewhere - confirm.
    """
    lp = log_prior(theta)
    if np.isfinite(lp):
        return lp + log_likelihood(theta)
    return -np.inf


# In[60]:


# Dispatch table: model name -> ansatz function.
# NOTE: the name shadows the builtin dict; it is kept because the likelihood
# and the final plot look the model up through it.
dict = {
    'w-tau': model_dv_tau,
    'w-q': model_dv_q,
    'w-tau-fixed': model_dv,
}


# In[64]:


# An initial guess is needed for each of the prior_dim parameters
np.random.seed(42)


def nll(*args):
    """Negative log-likelihood, the quantity handed to the minimiser."""
    return -log_likelihood(*args)


initial = np.ones(prior_dim)
soln = minimize(nll, initial)
vars_ml = soln.x
print("Maximum likelihood estimates:")
print(vars_ml)


# In[75]:


# Nested sampling over the prior box with npoints live points and
# random-walk proposals
f2 = dynesty.NestedSampler(
    log_likelihood,
    prior_transform,
    prior_dim,
    nlive=npoints,
    sample='rwalk',
)
f2.run_nested()


# In[91]:


# Label prefixes; the tone index and the closing '$' are appended below
wstr = r'$\omega_'
ampstr = r'$A_'
phasestr =  r'$\phi_'

if model == 'w-tau':
    taustr = r'$\tau_'
elif model == 'w-q':
    taustr = r'$q_'
elif model == 'w-tau-fixed':
    # placeholder: the damping labels are not used for this model
    taustr = r'$dumb_var}'

# One label per tone for each parameter family
tone_ids = [str(n_tone) for n_tone in range(dim)]
w_lab = [wstr + tone + '$' for tone in tone_ids]
tau_lab = [taustr + tone + '$' for tone in tone_ids]
amp_lab = [ampstr + tone + '$' for tone in tone_ids]
pha_lab = [phasestr + tone + '$' for tone in tone_ids]

# Same ordering as the sampled parameter vector
if model=='w-tau-fixed':
    labels = np.concatenate((amp_lab,pha_lab))
else:
    labels = np.concatenate((w_lab,tau_lab,amp_lab,pha_lab))


# In[92]:


if model=='w-tau-fixed':
    rg = (nmax+1)
else:
    rg = (nmax+1)*2

samps=f2.results.samples
samps_tr=np.transpose(samps)
# kept because later cells may still refer to it
half_points=int(round((len(samps_tr[0])/1.25)))

# Nested-sampling samples are not equally weighted posterior draws, so the
# median has to use the importance weights exp(logwt - logz_final) instead of
# the plain quantile of the last 20% of the samples.
_ns_results = f2.results
_posterior_weights = np.exp(_ns_results.logwt - _ns_results.logz[-1])

# Row of the first amplitude: 0 when only amplitudes and phases are sampled,
# otherwise after the frequency and damping rows.
_first_amp_row = 0 if model=='w-tau-fixed' else rg

# Posterior medians of the amplitudes followed by the phases
npamps = np.empty((nmax+1)*2)
for i in range(0,(nmax+1)*2):
    npamps[i] = dynesty.utils.quantile(
        samps_tr[i+_first_amp_row], [0.5], weights=_posterior_weights)[0]


# In[93]:


# Sampler results and a printed summary of the run
res = f2.results
res.samples_u.shape
res.summary()
samps = f2.results.samples

# Final log-evidence estimate and its error
evidence, evidence_error = res.logz[-1], res.logzerr[-1]


# In[80]:


# Append one line per run to the summary file; the header is written only
# when the file is created.
# NOTE: the 'dlogz' / 'dlogz_err' columns hold the final logz and its error;
# the header names are kept so existing summary files stay consistent.
summary_titles=['n','id','t_shift','dlogz','dlogz_err']

f = output_folder_1+'/summary'+str(simulation_number)+'_'+model+'_nmax_'+str(nmax)+'.csv'

summary_row = [nmax, simulation_number, tshift, evidence, evidence_error]
write_header = not os.path.exists(f)

# Rows are written from plain lists so every run is formatted the same way
# (going through np.array turned the integers into floats on appended rows).
# newline='' is what the csv module requires to avoid blank lines on Windows.
with open(f, 'a', newline='') as file:
    writer = csv.writer(file)
    if write_header:
        writer.writerow(summary_titles)
    writer.writerow(summary_row)


# In[83]:


samps=f2.results.samples
samps_tr=np.transpose(samps)

# Nested-sampling samples carry importance weights exp(logwt - logz_final);
# the quantiles must be weighted with them rather than taken from the raw
# tail of the sample list.
_ns_results = f2.results
_posterior_weights = np.exp(_ns_results.logwt - _ns_results.logz[-1])

# 10%, 50% and 90% posterior quantiles of every sampled parameter
sigma_vars_m = np.empty(prior_dim)
sigma_vars_p = np.empty(prior_dim)
sigma_vars = np.empty(prior_dim)
for i in range(prior_dim):
    q10, q50, q90 = dynesty.utils.quantile(
        samps_tr[i], [0.1, 0.5, 0.9], weights=_posterior_weights)
    sigma_vars_m[i] = q10
    sigma_vars[i] = q50
    sigma_vars_p[i] = q90


# In[85]:


# Rows: central value, lower bound, upper bound; one column per parameter
sigma_vars_all = np.stack((sigma_vars, sigma_vars_m, sigma_vars_p), axis=0)


# In[87]:


# One single-row frame per quantity, concatenated and written without the
# row labels
key = ['max val', 'lower bound', 'higher bound']
file = output_folder_1+'/best_values_'+str(simulation_number)+'tshift_'+str(tshift)+'_'+model+'_nmax_'+str(nmax)+'.csv'

dfslist = []
for row_name, row_values in zip(key, sigma_vars_all):
    dfslist.append(
        pd.DataFrame(row_values.reshape((-1, prior_dim)), columns=labels, index=[row_name])
    )
df2 = pd.concat(dfslist)
df2.to_csv(file, index=False)


# In[88]:


# Write the table of central values and bounds, one row per quantity.
# Iterating a DataFrame yields its column labels, so the previous
# writerows(map(..., df2)) stored only the labels and none of the values.
# NOTE(review): the layout (row name followed by the values) is a best guess
# at the intended content - confirm with whoever consumes this file.
with open(output_folder_1+'/best_sigmas_'+str(simulation_number)+'tshift_'+str(tshift)+'_'+model+'_'+str(nmax)+'.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['quantity'] + list(df2.columns))
    for row_name, row_values in df2.iterrows():
        writer.writerow([row_name] + list(row_values.values))


# In[89]:


# Reference values drawn on the corner plot: QNM predictions for the
# frequencies and damping variables, followed by the sampled medians in npamps.
if model == 'w-tau-fixed':
    truths = npamps
elif model in ('w-q', 'w-tau'):
    # the damping variable of the 'w-q' model is pi*w*tau
    tau_val = np.pi*w*tau if model == 'w-q' else tau
    truths = np.concatenate((w, tau_val, npamps))


# In[90]:


# Corner plot of the posterior with the reference values overlaid in red
corner_options = {
    'color': 'blue',
    'show_titles': True,
    'labels': labels,
    'quantiles': (0.05, 0.5, 0.95),
    'truths': truths,
    'truth_color': 'red',
}
fg, ax = dyplot.cornerplot(res, **corner_options)

corner_file = output_folder_1+'/Dynesty_'+str(simulation_number)+'_'+model+'_nmax='+str(nmax)+'_tshift='+str(tshift)+'_'+str(npoints)+'_chainplot.png'
fg.savefig(corner_file, format='png', bbox_inches='tight')


# In[ ]:


# Data, maximum-likelihood fit and a thinned set of sampled models
figband = plt.figure(figsize = (12, 9))
plt.plot(timesrd_final_tsh,gwdatanew_re_tsh, "green", alpha=0.9, lw=3, label=r'$res_{240}$')
plt.plot(timesrd_final_tsh,dict[model](vars_ml).real,'bo', alpha=0.9, lw=3, label=r'$fit$')

# Every 100th of the last samples (the very last one is excluded)
samples_res=samps[-20000:-1][0::100]
for i in samples_res:
    # The model is evaluated on timesrd_final_tsh, so that grid is the x axis.
    # The scalar tshift used before does not match the length of the model
    # array and makes plt.plot raise.
    plt.plot(timesrd_final_tsh, dict[model](i).real, "r-", alpha=0.01, lw=3)

plt.title(r'Comparison of the MC fit data and the $1-\sigma$ error band')
plt.legend()
plt.xlabel('t')
plt.ylabel('h')
plt.show()