# Copyright (C) 2021 Xisco Jimenez Forteza
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
# Module to generate RD waveforms.

import numpy as np
import qnm
import os


# Fit coefficients used to invert a QNM (angular frequency w, damping time tau)
# pair into final mass M and spin a; one row per overtone n = 0..4. As used by
# ``mass_from_wtau`` / ``spin_from_wtau`` below:
#   f_fpars[n] = (f1, f2, f3):  M * w       = f1 + f2 * (1 - a)**f3
#   q_fpars[n] = (q1, q2, q3):  w * tau / 2 = q1 + q2 * (1 - a)**q3
# NOTE(review): the accessor names suggest these are the Berti et al.
# frequency / quality-factor fits, presumably for the l = m = 2 mode --
# confirm against the source table.
f_fpars = [
    [2.95845, -2.58697, 0.0533469],
    [2.12539, -1.78054, 0.0865503],
    [1.74755, -1.44776, 0.123666],
    [1.78287, -1.53203, 0.129475],
    [2.04028, -1.83224, 0.112497],
]
q_fpars = [
    [0.584077, 1.52053, -0.480658],
    [0.00561441, 0.630715, -0.432664],
    [-0.197965, 0.515956, -0.369706],
    [-0.275097, 0.455691, -0.331543],
    [-0.287596, 0.398514, -0.309799],
]


class Ringdown_Spectrum:
    """RDown model generator.

    Builds complex ringdown templates as sums of damped sinusoids over the
    overtones ``0 .. n`` of a single (l, m) mode. QNM frequencies come from
    ``qnm.modes_cache``.

    Parameters
    ----------
    mf : float
        Final mass. QNM frequencies are divided by it and damping times are
        multiplied by it.
    af : float
        Final spin, passed as ``a=`` to the qnm mode callables.
    l, m : int
        Angular indices of the mode.
    n : int, optional
        Highest overtone index; ``n + 1`` overtones (``self.dim``) are modelled.
    s : int, optional
        Spin weight passed to ``qnm.modes_cache``.
    time : array_like or None, optional
        Sample times. ``None`` or an empty sequence selects
        ``np.arange(0, 100, 0.1)``. The value is converted to a numpy array.
    fixed : bool, optional
        If True, precompute ``self.w`` / ``self.tau`` at ``(mf, af)`` for use
        by ``rd_model_wq_fixed``.
    """

    def __init__(self, mf, af, l, m, n=4, s=-2, time=None, fixed=False):
        self.mf = mf
        self.af = af
        self.l = l
        self.m = m
        self.n = n
        self.dim = self.n + 1
        # One qnm mode callable per overtone 0..n (despite the name, these are
        # the (l, m) modes requested by the caller, not necessarily 2,2).
        self.grav_220 = [qnm.modes_cache(s=s, l=self.l, m=self.m, n=i)
                         for i in range(self.dim)]
        self.fixed = fixed

        # ``None`` replaces the old shared ``[]`` default. asarray lets callers
        # pass plain lists, which the models need for ``-self.time / tau``.
        if time is None or len(time) == 0:
            self.time = np.arange(0, 100, 0.1)
        else:
            self.time = np.asarray(time)

        if self.fixed:
            self.w, self.tau = self.QNM_spectrum()

    def QNM_spectrum(self, mf=None, af=None):
        """Compute the RD frequencies and damping times in NR units.

        Parameters
        ----------
        mf, af : float, optional
            Mass and spin at which to evaluate the spectrum. They default to
            ``self.mf`` and ``self.af``.

        Returns
        -------
        (w, tau) : tuple of ndarray
            For each overtone, ``w = Re(omega) / mf`` and
            ``tau = -1 / Im(omega) * mf``. Here ``omega`` is element ``[0]``
            of the qnm mode call (presumably the complex dimensionless QNM
            frequency -- see the qnm package docs).
        """
        mf = self.mf if mf is None else mf
        af = self.af if af is None else af
        omegas = np.asarray([mode(a=af)[0] for mode in self.grav_220])
        w_m_a = np.real(omegas) / mf
        tau_m_a = -1 / np.imag(omegas) * mf
        return (w_m_a, tau_m_a)

    def w_fpars_Berti(self, n):
        """Return the frequency-fit coefficients ``(f1, f2, f3)`` for overtone ``n``."""
        return f_fpars[n]

    def tau_qpars_Berti(self, n):
        """Return the quality-factor-fit coefficients ``(q1, q2, q3)`` for overtone ``n``."""
        return q_fpars[n]

    def _one_minus_spin(self, n, w, tau):
        """Invert ``w*tau/2 = q1 + q2*(1 - a)**q3`` for ``1 - a`` (overtone ``n``)."""
        q1, q2, q3 = self.tau_qpars_Berti(n)
        return 2**(-1/q3) * ((-2*q1 + w*tau)/q2)**(1/q3)

    def mass_from_wtau(self, n, w, tau):
        """Return the mass implied by angular frequency ``w`` and damping time ``tau``.

        Uses ``M * w = f1 + f2 * (1 - a)**f3``, with ``1 - a`` obtained from
        the quality-factor fit for overtone ``n``.
        """
        f1, f2, f3 = self.w_fpars_Berti(n)
        return (f1 + f2*self._one_minus_spin(n, w, tau)**f3) / w

    def spin_from_wtau(self, n, w, tau):
        """Return the spin implied by angular frequency ``w`` and damping time ``tau``."""
        return 1 - self._one_minus_spin(n, w, tau)

    def mass_from_wtau_loop(self, w, tau, l, m):
        """Apply ``mass_from_wtau`` to overtones ``0 .. self.dim - 1``.

        ``w`` and ``tau`` are indexed by overtone. ``l`` and ``m`` are accepted
        for interface compatibility but are not used. Returns a list.
        """
        return [self.mass_from_wtau(n, w[n], tau[n]) for n in range(self.dim)]

    def spin_from_wtau_loop(self, w, tau, l, m):
        """Apply ``spin_from_wtau`` to overtones ``0 .. self.dim - 1``.

        ``w`` and ``tau`` are indexed by overtone. ``l`` and ``m`` are accepted
        for interface compatibility but are not used. Returns a list.
        """
        return [self.spin_from_wtau(n, w[n], tau[n]) for n in range(self.dim)]

    def _oscillation(self, w):
        """Return ``cos(w t) - 1j*sin(w t)`` evaluated on ``self.time``.

        The -1j sign agrees with the SXS convention.
        """
        return np.cos(w*self.time) - 1j*np.sin(w*self.time)

    def rd_model_wtau(self, theta):
        """RD model parametrized with the damping time tau.

        ``theta`` is laid out as ``[w_0..w_{d-1}, tau_0.., x_0.., y_0..]``
        with ``d = self.dim``. Returns
        ``sum_i x_i e^{i y_i} e^{-t/tau_i} (cos(w_i t) - i sin(w_i t))``.
        """
        assert int(len(theta)/4) == self.dim, 'Please recheck your n and parameters'

        d = self.dim
        wvars = theta[:d]
        tvars = theta[d:2*d]
        xvars = theta[2*d:3*d]
        yvars = theta[3*d:]

        ansatz = 0
        for i in range(d):
            ansatz += (xvars[i]*np.exp(1j*yvars[i]))*np.exp(-self.time/tvars[i])*self._oscillation(wvars[i])
        return ansatz

    def rd_model_wq(self, theta):
        """RD model parametrized with the quality factor q.

        ``theta`` is laid out as ``[w_0..w_{d-1}, q_0.., x_0.., y_0..]``
        with ``d = self.dim``. The envelope of mode i is ``exp(-t*pi*w_i/q_i)``.
        """
        assert int(len(theta)/4) == self.dim, 'Please recheck your n and parameters'

        d = self.dim
        wvars = theta[:d]
        qvars = theta[d:2*d]
        xvars = theta[2*d:3*d]
        yvars = theta[3*d:]

        ansatz = 0
        for i in range(d):
            # NOTE(review): this implies tau = q/(pi*w). The Berti inversion
            # above uses w*tau/2 = q1 + q2*(1-a)**q3, i.e. tau = 2q/w, while
            # the oscillation uses w as an angular frequency. The two agree only
            # if w here is a cyclic frequency -- confirm the intended convention.
            ansatz += (xvars[i]*np.exp(1j*yvars[i]))*np.exp(-self.time*np.pi*wvars[i]/qvars[i])*self._oscillation(wvars[i])
        return ansatz

    def rd_model_wq_fixed(self, theta):
        """RD model with the QNM spectrum fixed to GR at ``(self.mf, self.af)``.

        ``theta`` is laid out as ``[x_0..x_{d-1}, y_0..y_{d-1}]`` (amplitudes
        and phases) with ``d = self.dim``. The spectrum precomputed with
        ``fixed=True`` is used if present; otherwise it is computed here.
        """
        d = self.dim
        xvars = theta[:d]
        yvars = theta[d:2*d]
        if hasattr(self, 'w') and hasattr(self, 'tau'):
            w, tau = self.w, self.tau
        else:
            # Built without ``fixed=True``: evaluate the GR spectrum on demand.
            w, tau = self.QNM_spectrum()

        ansatz = 0
        for i in range(d):
            ansatz += (xvars[i]*np.exp(1j*yvars[i]))*np.exp(-self.time/tau[i])*self._oscillation(w[i])
        return ansatz

    def rd_model_wq_m_a(self, theta):
        """RD model with the QNM spectrum fixed to GR and given by mass and spin.

        ``theta`` is laid out as ``[x_0..x_{d-1}, y_0..y_{d-1}, mass, spin]``
        with ``d = self.dim``. The spectrum is evaluated at ``(mass, spin)``
        through ``QNM_spectrum``.
        """
        assert len(theta) == 2*self.dim + 2, 'Please recheck your n and parameters'

        d = self.dim
        xvars = theta[:d]
        yvars = theta[d:2*d]
        mass_vars = theta[-2]
        spin_vars = theta[-1]

        w_m_a, tau_m_a = self.QNM_spectrum(mass_vars, spin_vars)

        ansatz = 0
        for i in range(d):
            ansatz += (xvars[i]*np.exp(1j*yvars[i]))*np.exp(-self.time/tau_m_a[i])*self._oscillation(w_m_a[i])
        return ansatz