NR_Interpolate-0001_t_10M_wandt.py
    #!/usr/bin/env python
    # coding: utf-8
    
    # ### Let's try the NR_Interpolate for the 0.0001 stepsize.
    
    # In[99]:
    
    
    #Import relevant modules and set up the plotting environment
    import numpy as np
    from scipy import interpolate
    import corner
    import os
    os.environ['MPLCONFIGDIR'] = '/home/rayne.liu/.config/matplotlib'
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator
    from matplotlib import rc
    #plt.rcParams['font.family'] = 'DejaVu Sans'
    #rc('text', usetex=True)
    plt.rcParams.update({'font.size': 16.5})
    
    import ptemcee
    #from pycbc.pool import choose_pool
    from multiprocessing import Pool
    import h5py
    import inspect
    import pandas as pd
    import json
    import qnm
    import random
    
    #Remember to change the following global variables:
    #rootpath: root path to the NR data
    #npoints: number of points you're using for your sampling
    #nmax: tone index --> nmax = 0 if fitting the fundamental tone only
    #tshift: time shift after the strain peak
    #vary_fund: whether you vary the fundamental frequency. Used in the model_dv function.
    
    rootpath = "/work/rayne.liu/git/rdstackingproject"  #"/Users/RayneLiu/git/rdstackingproject"
    nmax=1
    tshift=10
    vary_fund = True
    
    #sampler parameters
    npoints = 10
    nwalkers = 20
    ntemps=12
    dim = nmax+1
    ndim = 4*dim
    burnin = 5  #How many points to burn before doing the corner plot. You need to watch the convergence of the chain plot a bit.
                #This is trivial but often forgotten: burnin cannot be more than npoints! I usually use half the points.
    numbins = 42 #corner plot parameter - how many bins you want
    datacolor = '#105670' #'#4fa3a7'
    pkcolor = '#f2c977' #'#ffb45f'
    mediancolor = '#f7695c' #'#9b2814'
    
    #Import data and necessary functions
    
    #TimeOfMaximum
    def FindTmaximum(y):
        #Determines the time of the peak of the complex strain amplitude
        #(the squared magnitude peaks at the same index as the magnitude itself)
        absval = y[:,1]*y[:,1] + y[:,2]*y[:,2]
        index = np.argmax(absval)
        timemax = y[index,0]  #use y itself, not the global gw_sxs_bbh_0305
        return timemax
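
    # A minimal sanity check for FindTmaximum on a synthetic damped sinusoid (hypothetical
    # input, not SXS data), stored in the same [time, Re, Im] column layout as the strain
    # loaded below; its amplitude peaks at t = 0 by construction.
    _t = np.linspace(-50, 50, 2001)
    _h = np.exp(-np.abs(_t)/10) * np.exp(-1j*0.5*_t)
    print(FindTmaximum(np.column_stack([_t, _h.real, _h.imag])))  #should print 0.0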
    
    
    
    
    #This loads the 22 mode data
    gw = {}
    gw["SXS:BBH:0305"] = h5py.File(rootpath+"/SXS/BBH_SKS_d14.3_q1.22_sA_0_0_0.330_sB_0_0_-0.440/Lev6/rhOverM_Asymptotic_GeometricUnits_CoM.h5", 'r')
    gw_sxs_bbh_0305 = gw["SXS:BBH:0305"]["Extrapolated_N2.dir"]["Y_l2_m2.dat"]
    
    # Remember to download metadata.json from simulation 0305 (Lev6/metadata.json).
    # This postprocesses the metadata file to find the final mass and final spin
    metadata = {}
    with open(rootpath+"/SXS/BBH_SKS_d14.3_q1.22_sA_0_0_0.330_sB_0_0_-0.440/Lev6/metadata.json") as file:
        metadata["SXS:BBH:0305"] = json.load(file)
    
    af = metadata["SXS:BBH:0305"]['remnant_dimensionless_spin'][-1]
    mf = metadata["SXS:BBH:0305"]['remnant_mass']
    
    
    
    #times --> x axis of your data
    times = gw_sxs_bbh_0305[:,0]
    tmax=FindTmaximum(gw_sxs_bbh_0305)
    t0=tmax +tshift
    
    #Select the data from t0 onwards
    position = np.argmax(times >= (t0))
    gw_sxs_bbh_0305rd = gw_sxs_bbh_0305[position:-1]
    timesrd = gw_sxs_bbh_0305rd[:,0][:920]
    #print(timesrd[0])
    #print(t0) #(This checks that timesrd[0] is indeed t0 - actually this is a bit off due to stepsize issues,
              #but nvm, we fix it right away below)
    t0 = timesrd[0]
    #print(t0)
    timespan = timesrd - t0
    gwdata_re = gw_sxs_bbh_0305rd[:,1][:920]
    gwdata_im = gw_sxs_bbh_0305rd[:,2][:920]
    
    # Depending on nmax, load dim = nmax+1 frequencies and damping times from the qnm package
    omegas = [qnm.modes_cache(s=-2, l=2, m=2, n=i)(a=af)[0] for i in range(dim)]
    w = np.real(omegas)/mf
    tau = -mf/np.imag(omegas)
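
    # Optional sanity check: print the Kerr QNM predictions that motivate the hand-tuned
    # prior windows in log_prior below (w and tau are in units of the total mass M).
    print('Kerr QNM frequencies w:', w)
    print('Kerr QNM damping times tau:', tau)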
    
    
    # In[84]:
    
    gwnew_re = interpolate.interp1d(timespan, gwdata_re, kind = 'cubic')
    gwnew_im = interpolate.interp1d(timespan, gwdata_im, kind = 'cubic')
    
    
    # In[87]:
    
    
    timespan_new = np.linspace(0, timespan[-1], len(timespan)*1000)
    gwdatanew_re = gwnew_re(timespan_new)
    gwdatanew_im = gwnew_im(timespan_new)
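
    # The resampled grid should now have roughly the 0.0001M spacing named in the title;
    # a quick check of the old vs. new effective step:
    print('old step:', timespan[1] - timespan[0])
    print('new step:', timespan_new[1] - timespan_new[0])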
    
    
    
    # In[92]:
    
    
    #Test the new interpolated data
    figtest = plt.figure(figsize = (12, 8))
    plt.plot(timespan, gwdata_re, "r", alpha=0.3, lw=2, label='Before_re')
    plt.plot(timespan_new, gwdatanew_re, "b", alpha=0.3, lw=2, label='After_re')
    plt.plot(timespan, gwdata_im, alpha=0.3, lw=2, label='Before_im')
    plt.plot(timespan_new, gwdatanew_im, alpha=0.3, lw=2, label='After_im')
    plt.legend()
    figtest.savefig(rootpath + '/plotsmc/0001_interpolated_datatest_wandt.png', format='png', bbox_inches='tight', dpi=300)
    
    
    # ### The interpolation looks good in the plot above...let's start sampling!
    
    # In[100]:
    
    
    #Fitting
    #RD model for nmax tones. Amplitudes are in the (xn*Exp[i yn]) form, used here.
    def model_dv(theta):
        #Your nmax might not align with the dim of theta. Better to check it here.
        assert int(len(theta)/4) == dim, 'Please recheck your n and parameters'
        
        wvars = theta[ : (dim)]
        tvars = theta[(dim) : 2*(dim)]
        xvars = theta[2*(dim) : 3*(dim)]
        yvars = theta[3*(dim) : ]
        
        #if vary_fund == False:
        #    avars[0]=0
        #    bvars[0]=0
            
        ansatz = 0
        for i in range(dim):
            ansatz += (xvars[i]*np.exp(1j*yvars[i])) * np.exp(-timespan_new/tvars[i]) \
                      * (np.cos(wvars[i]*timespan_new) - 1j*np.sin(wvars[i]*timespan_new))
        #the -1j sine term agrees with the SXS convention
        return ansatz
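
    # A minimal smoke test for model_dv (a hypothetical parameter point, not a fit result):
    # unit amplitudes and zero phases at the Kerr values loaded above; it just confirms
    # the ansatz is evaluated on the interpolated grid.
    _theta_test = np.concatenate([w, tau, np.ones(dim), np.zeros(dim)])
    assert model_dv(_theta_test).shape == timespan_new.shape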
    
    # Logprior distribution. It defines the allowed ranges the variables can vary over.
    #It works for the (xn*Exp[iyn]) version.
    def log_prior(theta): 
        #Warning: we are specifically working with nmax=1, so the individual priors on the parameters are manually adjusted here. This does not apply to other nmax's.
        #avars = theta[ : (dim)]
        #bvars = theta[(dim) : 2*(dim)]
        #xvars = theta[2*(dim) : 3*(dim)]
        #yvars = theta[3*(dim) : ]
        omega0, omega1, tau0, tau1, xvar0, xvar1, yvar0, yvar1 = theta
        if tshift == 0:
            if all([0.45 <= omega0 <= 0.63, 0.27 <= omega1 <= 0.6, 0. <= tau0 <= 30., 0. <= tau1 <= 20., \
                0 <= xvar0 <= 6, 0 <= xvar1 <= 6, -np.pi <= yvar0 <= np.pi, 0. <= yvar1 <= 2*np.pi]):        
                return 0.0
        
        elif tshift == 10:
            if all([0.56 <= omega0 <= 0.64, 0.4 <= omega1 <= 0.56, 6.45 <= tau0 <= 13.6, 1.94 <= tau1 <= 10., \
                0. <= xvar0 <= 1.5, 0. <= xvar1 <= 1.2, 0 <= yvar0 <= 2*np.pi, 0. <= yvar1 <= 2*np.pi]):        
                return 0.0
        
        return -np.inf
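
    # Quick check that an in-range point is accepted by the tshift=10 windows above
    # (the numbers are just an example inside the bounds, not a fit result):
    assert log_prior(np.array([0.6, 0.5, 10., 3., 1., 1., 3., 3.])) == 0.0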
    
    
    # LogLikelihood function. It is just an unnormalized Gaussian loglikelihood (unit variance, constants dropped) based on the squared residuals
    def log_likelihood(theta):
        modelev = model_dv(theta)
        result = -np.sum((gwdatanew_re - (modelev.real))**2+(gwdatanew_im - (modelev.imag))**2)
        if np.isnan(result):
            return -np.inf
        return result
    
    
    # Logposterior distribution for the residuals case.
    # The evidence is just a normalization factor
    def log_probability(theta):
        lp = log_prior(theta)
        if not np.isfinite(lp):
            return -np.inf
        return lp + log_likelihood(theta)
    
    
    # In[101]:
    
    
    #This cell uses the tshift=10 results
    #Set the number of cores of your processors
    #pool = choose_pool(1)
    #pool.size = 1
    np.random.seed(42)
    #Note: random.uniform comes from the stdlib random module, which np.random.seed does
    #not touch; seed it too so the initial guess is reproducible.
    random.seed(42)
    pos = np.array([random.uniform(0.58, 0.62), random.uniform(0.42, 0.54),
                    random.uniform(7., 12.),    random.uniform(2., 4.),
                    random.uniform(0.5, 1.),    random.uniform(0.5, 1.),
                    random.uniform(1., 2.),     random.uniform(1., 2.)])
    #Keep pos as an ndarray: broadcasting against the (ntemps, nwalkers, ndim) jitter
    #gives every walker its own slightly perturbed copy of the guess. (Converting pos
    #to a list first, as before, would extend the list instead of adding elementwise.)
    pos = pos + 1e-5 * np.random.randn(ntemps, nwalkers, ndim)
    with Pool() as pool:
        sampler = ptemcee.Sampler(nwalkers, ndim, log_likelihood, log_prior, ntemps=ntemps, pool=pool)
        sampler.run_mcmc(pos, npoints)
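
    #Optional post-run diagnostic, left commented out since the attribute name is an
    #assumption: ptemcee's Sampler should expose per-temperature, per-walker acceptance
    #fractions as acceptance_fraction, like the old emcee PTSampler.
    #print('mean acceptance fraction (T=0):', np.mean(sampler.acceptance_fraction[0]))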
    
    #dim is already nmax+1 = 2 from the top of the script; no need to reset it here
    paramlabels_w = [r'$\omega_'+str(i)+'$' for i in range(dim)]
    paramlabels_t = [r'$\tau_'+str(i)+'$' for i in range(dim)]
    paramlabels_x = [r'$x_'+str(i)+'$' for i in range(dim)]
    paramlabels_y = [r'$y_'+str(i)+'$' for i in range(dim)]
    
    paramlabels = paramlabels_w + paramlabels_t + paramlabels_x + paramlabels_y
    
    print('The chain plot:')
    #Chain plot
    figchain, axes = plt.subplots(ndim, 1, sharex=True, figsize=(12, 4*(4)))
    for i in range(ndim):
        axes[i].plot(sampler.chain[0,:, :, i].T, color="k", alpha=0.4, rasterized=True)
        axes[i].yaxis.set_major_locator(MaxNLocator(5))
        axes[i].set_ylabel(paramlabels[i])
    axes[-1].set_xlabel('Iterations')
    figchain.savefig(rootpath + '/plotsmc/0001_10M_interpolated_chainplot_wandt_'+str(nwalkers)+'walkers_'+str(npoints)+'pts.png', format='png', bbox_inches='tight', dpi=300)
    
    for temp in range(ntemps):
        dftemp = pd.DataFrame(sampler.chain[temp,:, :, :].reshape((-1, ndim)), columns=paramlabels)
        dftemp.to_csv(rootpath+'/plotsmc/0001_10M_interpolated'+'_nmax'+str(nmax)+'_tshift'+str(tshift)+'_'+str(npoints)+'pt_temp'+str(temp)+'_chain.csv', index = False)
    
    print('We\'re using ptemcee. Our constraints:')
    #Burn samples, calculate peak likelihood value (not necessarily so in atlas) and make corner plot
    samples = sampler.chain[0,:, burnin:, :].reshape((-1, ndim))
    #samples for corner plot
    samples_corn = samples #if vary_fund == True else np.delete(samples, np.s_[0,2], 1)
    
    #print('Values with peak likelihood:')
    lglk = np.array([log_likelihood(samples[i]) for i in range(len(samples))])
    pk = samples[np.argmax(lglk)]
    #print('pk:')
    #print(pk)
    pk_corn = pk #if vary_fund == True else np.delete(pk, [0,2])
    #The y_0 range needs some massaging to make the plot. But in order to make the whole picture consistent, better change the range of y_1 too.
    #if vary_fund == False:
    #    samples_corn.T[-dim:] -= np.pi #This indeed changes samples_corn itself
    #    pk[-dim:] -= np.pi
    
    #print('pkFalse:')
    #print(pk)
        
    #print(pk) 
    #Now calculate median (50-percentile) value
    median = np.median(samples_corn, axis=0)
    #print(samples)
    #print(samples_corn)
    
    figcorn = corner.corner(samples_corn, bins=numbins, hist_bin_factor=5, color=datacolor,
                            truths=pk_corn, truth_color=pkcolor, plot_contours=True,
                            labels=paramlabels, quantiles=(0.05, 0.16, 0.5, 0.84, 0.95),
                            levels=[1-np.exp(-0.5), 1-np.exp(-1.64**2/2)], show_titles=True)
    
    
    #Extract the axes in order to overlay the median lines
    naxes = len(pk_corn)
    axes = np.array(figcorn.axes).reshape((naxes, naxes))
    
    # Loop over the diagonal
    for i in range(naxes):
        ax = axes[i, i]
        ax.axvline(median[i], color=mediancolor)
    
    # Loop over the off-diagonal 2D panels
    for yi in range(naxes):
        for xi in range(yi):
            ax = axes[yi, xi]
            ax.axvline(median[xi], color=mediancolor)
            ax.axhline(median[yi], color=mediancolor)
            ax.plot(median[xi], median[yi], color = mediancolor, marker = 's')
    figcorn.savefig(rootpath + '/plotsmc/0001_10M_interpolated_cornerplot_wandt_'+'nmax'+str(nmax)+'_tshift'+str(tshift)+'_'+str(nwalkers)+'walkers_'+str(npoints)+'pts.png', format='png', bbox_inches='tight', dpi=300)