    #!/usr/bin/env python
    # coding: utf-8
    
    # In[185]:
    
    
    import numpy as np
    import corner
    #get_ipython().run_line_magic('matplotlib', 'inline')
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator
    from matplotlib import rc
    plt.rcParams['font.family'] = 'DejaVu Sans'
    rc('text', usetex=False)
    plt.rcParams.update({'font.size': 19})
    import ptemcee
    import qnm
    from pycbc.pool import choose_pool
    import random
    import h5py
    import json
    
    
    # In[186]:
    
    
    rootpath="/work/francisco.jimenez/sio"
    project_path=rootpath+"/git/rdstackingproject"
    
    
    # In[187]:
    
    
    # nmax is the highest overtone index: QNM frequencies and damping times
    # for n = 0..nmax are loaded below from the qnm package
    nmax=1
    ndim = int(4*(nmax+1))  # 4 parameters (alpha, beta, x, y) per mode
    cores=12
    ntemps = 20
    nwalkers = 800
    npoints = 5000
    burnin = 4000
    numbins = 42
    tshift=0
    tend=150
    
    
    # In[188]:
    
    
    run_tag = f'Testnmax={nmax}_tshift={tshift}_tend={tend}_{npoints}pt'
    chain_file = project_path + '/plotsmc/' + run_tag + '_chain.png'
    chain_file_dat = project_path + '/plotsmc/' + run_tag + '_chain.csv'
    corner_file = project_path + '/plotsmc/' + run_tag + '_corner.png'
    
    
    # In[189]:
    
    
    datacolor = '#105670' #'#4fa3a7'
    pkcolor = '#f2c977' #'#ffb45f'
    mediancolor = '#f7695c' #'#9b2814'
    
    
    # In[190]:
    
    
    #TimeOfMaximum
    def FindTmaximum(y):
        #Determines the time at which the squared amplitude of the complex
        #waveform peaks (uses the argument y, not a global array)
        absval = y[:,1]*y[:,1] + y[:,2]*y[:,2]
        index = np.argmax(absval)
        timemax = y[index,0]
        return timemax
    
    
    # In[191]:
    
    
    #This loads the 22 mode data
    gw = {}
    gw["SXS:BBH:0305"] = h5py.File(project_path+"/SXS/BBH_SKS_d14.3_q1.22_sA_0_0_0.330_sB_0_0_-0.440/Lev6/rhOverM_Asymptotic_GeometricUnits_CoM.h5", 'r')
    gw_sxs_bbh_0305 = gw["SXS:BBH:0305"]["Extrapolated_N2.dir"]["Y_l2_m2.dat"]
    
    # Remember to download Lev6/metadata.json for simulation SXS:BBH:0305.
    # This postprocesses the metadata file to find the final mass and final spin.
    metadata = {}
    with open(project_path+"/SXS/BBH_SKS_d14.3_q1.22_sA_0_0_0.330_sB_0_0_-0.440/Lev6/metadata.json") as file:
        metadata["SXS:BBH:0305"] = json.load(file)
    
    af = metadata["SXS:BBH:0305"]['remnant_dimensionless_spin'][-1]
    mf = metadata["SXS:BBH:0305"]['remnant_mass']
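
    # Optional sanity check: SXS:BBH:0305 is the GW150914-like simulation, so
    # the remnant should come out near mf ~ 0.95 and af ~ 0.69 in total-mass
    # units (rough reference values; the fit uses the numbers read above).
    print("Remnant: mf =", mf, " af =", af)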
    
    
    # In[192]:
    
    
    #times: the time axis of the waveform data
    times = gw_sxs_bbh_0305[:,0]
    tmax = FindTmaximum(gw_sxs_bbh_0305)
    t0 = tmax + tshift
    
    #Select the data from t0 onwards
    position = np.argmax(times>=t0)
    gw_sxs_bbh_0305rd=gw_sxs_bbh_0305[position:-1]
    
    t0_end=t0+tend
    #Select the data in [t0, t0+tend]
    position = np.argmin(gw_sxs_bbh_0305rd[:,0]<=t0_end)
    gw_sxs_bbh_0305rd=gw_sxs_bbh_0305rd[0:position]
    
    timesrd=gw_sxs_bbh_0305rd[:,0]-tmax
    
    
    # In[193]:
    
    
    # Load the Kerr QNM (s=-2, l=2, m=2) frequencies and damping times for
    # n = 0..nmax, converting from remnant-mass to total-mass units
    omegas = []
    for i in range(0, nmax+1):
        grav_22n = qnm.modes_cache(s=-2, l=2, m=2, n=i)
        omega = grav_22n(a=af)[0]
        omegas.append(omega)
    w = np.real(omegas)/mf
    tau = -mf/np.imag(omegas)
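
    # Optional sanity check: for af ~ 0.69 the fundamental mode should give
    # roughly w[0] ~ 0.56 and tau[0] ~ 12 in total-mass units (rough values,
    # for orientation only)
    print("QNM frequencies w:", w)
    print("QNM damping times tau:", tau)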
    
    
    # In[194]:
    
    
    #Plot the ringdown portion of the waveform
    plt.figure(figsize = (12, 8))
    plt.plot(timesrd, gw_sxs_bbh_0305rd[:,1], label = r'Real')
    plt.plot(timesrd, gw_sxs_bbh_0305rd[:,2], label = r'Imag')
    plt.legend()
    plt.show()
    
    
    # In[195]:
    
    
    def modelmock_v2(theta):
        """
        Ringdown ansatz: theta packs (alpha_n, beta_n, x_n, y_n) for each
        overtone n = 0..nmax.
        """

        assert int(len(theta)/4) == nmax + 1, 'Please recheck nmax and the parameter count'
        dim = int(len(theta)/4)
        
        avars = [theta[4*i] for i in range (0, dim)]
        bvars = [theta[4*i+1] for i in range (0, dim)]
        xvars = [theta[4*i+2] for i in range (0, dim)]
        yvars = [theta[4*i+3] for i in range (0, dim)]        
    
        ansatz = 0
        for i in range (0,dim):
            tauvar=tau[i]*(1+bvars[i])
            wvar=w[i]*(1+avars[i])
            ansatz += (xvars[i]*np.exp(1j*yvars[i]))*np.exp(-timesrd/tauvar)*(np.cos(wvar*timesrd)-1j*np.sin(wvar*timesrd))
           
        return ansatz
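
    # The ansatz above is a superposition of damped sinusoids,
    #   h(t) = sum_n x_n * exp(i*y_n) * exp(-t/tau_n') * exp(-i*w_n'*t),
    # with w_n' = w[n]*(1+alpha_n) and tau_n' = tau[n]*(1+beta_n), so alpha_n
    # and beta_n are fractional deviations from the Kerr QNM predictions;
    # (cos - 1j*sin) in the code is exp(-i*w*t) written out explicitly.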
    
    
    
    def log_prior(theta):
        a_s = theta[0::4]
        b_s = theta[1::4]
        x_s = theta[2::4]
        y_s = theta[3::4]
        if (all(-0.4 <= t <= 0.4 for t in a_s) and all(-1 <= t <= 2 for t in b_s)
                and all(0 <= t <= 15 for t in x_s)
                and all(-np.pi <= t <= np.pi for t in y_s)):
            return 0.0
        return -np.inf
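
    # Note: these are flat priors, so within the box the posterior is
    # proportional to the likelihood; the bounds (|alpha_n| <= 0.4,
    # -1 <= beta_n <= 2, 0 <= x_n <= 15, |y_n| <= pi) are modeling choices,
    # not physical requirements.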
    
    # Log-likelihood: an unnormalized Gaussian log-likelihood built from the
    # sum of squared residuals of the real and imaginary parts
    def log_likelihood(theta):
        model_mock = modelmock_v2(theta)
        
        return  -np.sum((gw_sxs_bbh_0305rd[:,1] - model_mock.real)**2+(gw_sxs_bbh_0305rd[:,2] - model_mock.imag)**2)
    
    # Logposterior distribution for the residuals case.
    # The evidence is just a normalization factor
    def log_probability(theta):
        lp = log_prior(theta)
        if not np.isfinite(lp):
            return -np.inf
        return lp + log_likelihood(theta)
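
    # ptemcee tempers only the likelihood, so the sampler below takes
    # log_likelihood and log_prior separately; log_probability is kept here
    # for single-temperature spot checks and is not passed to the sampler.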
    
    
    # In[196]:
    
    
    paramlabels = []
    for i in range(nmax+1):
        paramlabels += [r'$\alpha_' + str(i) + '$', r'$\beta_' + str(i) + '$',
                        r'$x_' + str(i) + '$', r'$y_' + str(i) + '$']
    
    
    # In[197]:
    
    
    pool = choose_pool(cores)
    pool.size = cores
    np.random.seed(42)
    
    pos = [random.uniform(-0.1,0.1), random.uniform(-0.1,0.1), random.uniform(0.5, 10), random.uniform(-3, 3)]
    for i in range(1, nmax+1):
        pos += [random.uniform(-0.1,0.1), random.uniform(-0.1,0.1), random.uniform(0.5, 10), random.uniform(-3, 3)]
    pos += 1e-5 * np.random.randn(ntemps, nwalkers, ndim)

    # Pass the pool to the sampler; without it the sampling runs on one core
    sampler = ptemcee.Sampler(nwalkers, ndim, log_likelihood, log_prior, ntemps=ntemps, pool=pool)
    sampler.run_mcmc(pos, npoints)
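
    # A crude convergence check (a sketch, not a substitute for a proper
    # autocorrelation analysis): a Gelman-Rubin-style R-hat across the
    # cold-chain walkers; values near 1 (commonly < ~1.1) indicate mixing
    cold = sampler.chain[0][:, burnin:, :]              # (nwalkers, nsteps, ndim)
    nsteps = cold.shape[1]
    W = cold.var(axis=1, ddof=1).mean(axis=0)           # mean within-walker variance
    B = nsteps * cold.mean(axis=1).var(axis=0, ddof=1)  # between-walker variance
    rhat = np.sqrt(((nsteps - 1)/nsteps * W + B/nsteps) / W)
    print("R-hat per parameter:", rhat)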
    
    
    # In[198]:
    
    
    #Chain plot: one panel per parameter, cold (T=0) chain only
    fig, axes = plt.subplots(ndim, 1, sharex=True, figsize=(12, 2*ndim))
    for i in range(ndim):
        axes[i].plot(sampler.chain[0,:, :, i].T, color="k", alpha=0.4, rasterized=True)
        axes[i].yaxis.set_major_locator(MaxNLocator(5))
        axes[i].set_ylabel(paramlabels[i])
    axes[-1].set_xlabel('Iterations')
    plt.show()
    
    fig.savefig(chain_file, format = 'png', dpi = 384, bbox_inches = 'tight')
    out = np.concatenate(sampler.chain[0,:])
    # fmt='%d' would truncate the samples to integers; keep full precision
    np.savetxt(chain_file_dat, out)
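
    # The flattened cold chain can be reloaded later with, e.g.:
    #   out = np.loadtxt(chain_file_dat)   # shape (nwalkers*npoints, ndim)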
    
    
    # In[199]:
    
    
    #Discard the burn-in, locate the peak-likelihood sample and make the corner plot
    samples = sampler.chain[0,:, burnin:, :].reshape((-1, ndim))
    samples_corn = samples

    lglk = np.array([log_likelihood(samples[i]) for i in range(len(samples))])
    pk = samples[np.argmax(lglk)]
    pk_corn = pk

    #Median (50th percentile) of each marginal posterior
    median = np.median(samples_corn, axis=0)
    
    figcorn = corner.corner(samples_corn, bins = numbins, hist_bin_factor = 5,
                            color = datacolor, truths = pk_corn, truth_color = pkcolor,
                            plot_contours = True, labels = paramlabels,
                            quantiles = (0.05, 0.16, 0.5, 0.84, 0.95),
                            levels = [1-np.exp(-0.5), 1-np.exp(-1.64**2/2)],
                            show_titles = True)
    
    
    #Extract the axes in order to add more important line plots
    naxes = len(pk_corn)
    axes = np.array(figcorn.axes).reshape((naxes, naxes))
    
    # Loop over the diagonal
    for i in range(naxes):
        ax = axes[i, i]
        ax.axvline(median[i], color=mediancolor)
    
    # Loop over the off-diagonal 2D panels
    for yi in range(naxes):
        for xi in range(yi):
            ax = axes[yi, xi]
            ax.axvline(median[xi], color=mediancolor)
            ax.axhline(median[yi], color=mediancolor)
            ax.plot(median[xi], median[yi], color = mediancolor, marker = 's')
            
    
    figcorn.savefig(corner_file, format = 'png', dpi = 384, bbox_inches = 'tight')
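

    # A short follow-up sketch: report medians with 90% credible intervals
    # from the burned-in samples (the 5%/50%/95% quantiles match the corner
    # plot's settings)
    lo, med, hi = np.percentile(samples_corn, [5, 50, 95], axis=0)
    for lbl, l, m, h in zip(paramlabels, lo, med, hi):
        print(f"{lbl}: {m:.4f} (+{h-m:.4f} / -{m-l:.4f})")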