#!/usr/bin/env python
from __future__ import print_function, division, absolute_import
from builtins import range

import sys
import os
import string
import numpy as np
import h5py
from ncutils import read_nc
from scipy.interpolate import interp1d
from itertools import product

import matplotlib.pyplot as plt

import trunc

# Root directories of the ARTDECO python and fortran trees.  They are also
# exported as environment variables (presumably consumed by ARTDECO tooling
# launched from this process -- confirm).
dir_pyartdeco = "/home/mathieu/work/RTM/artdeco/pyartdeco/"
dir_artdeco = "/home/mathieu/work/RTM/artdeco/fortran/"
os.environ.update({
    "DIR_PYARTDECO": dir_pyartdeco,
    "DIR_ARTDECO": dir_artdeco,
})

# Output directory for the generated optical-property tables.
dir_opt = '/rfs/proj/artdeco_lib/opt/'

# Name of the phase-function truncation model.
trunc_model = 'dm'

# HDF5 file holding the Baran parameterization coefficients.
baran_coef_files = "/home/mathieu/work/RTM/baran/GIANLUCA/database/ens_tot_sol_ir_opt_prop_m_0_0257_f07T_013_nophase_param.H5"



def get_baran_coef(filename=None):
    """Read the Baran parameterization coefficients from an HDF5 file.

    Parameters
    ----------
    filename : str, optional
        Path of the coefficient file.  Defaults to the module-level
        ``baran_coef_files``.

    Returns
    -------
    wvl, sca_coef, g_coef, abs_coef : ndarray
        In-memory copies of the ``WAVELENGTH``, ``SCATTERING_COEFFICIENTS``,
        ``G_COEFFICIENTS`` and ``ABSORPTION_COEFFICIENTS`` datasets.
        NOTE(review): the coefficient arrays are presumably 2-D
        (wavelength, term), given how ``get_ext_omega0_asym`` indexes them
        -- confirm against the file.
    """
    if filename is None:
        filename = baran_coef_files

    # The context manager closes the file even if one of the reads fails.
    with h5py.File(filename, "r") as fb:
        wvl = np.copy(fb['WAVELENGTH'])
        sca_coef = np.copy(fb['SCATTERING_COEFFICIENTS'])
        g_coef = np.copy(fb['G_COEFFICIENTS'])
        abs_coef = np.copy(fb['ABSORPTION_COEFFICIENTS'])

    return wvl, sca_coef, g_coef, abs_coef





def baran_phase(g, ang, mu, rad):
    """Evaluate the piecewise analytic Baran P11 phase function for one g.

    Parameters
    ----------
    g : float
        Asymmetry parameter.
    ang : ndarray
        Scattering angles in degrees.  All angular thresholds below are in
        degrees.
    mu : ndarray
        Cosine of the scattering angles, same length as ``ang``.
    rad : ndarray
        Scattering angles in radians, same length as ``ang``.

    Returns
    -------
    phase : ndarray, shape (len(ang),)
        Phase-function values on the ``ang`` grid.  No normalisation is
        applied here.

    Notes
    -----
    * ``g < 0.2``: Henyey-Greenstein form.  Negative ``g`` also takes this
      branch; previously it fell through to the ``g >= 0.7`` branch and
      produced NaNs (``sqrt`` of a negative ``g``, ``beta = nan``).
    * ``0.2 <= g < 0.7``: two angular segments split at 54.8 deg.
    * ``g >= 0.7``: five angular segments (3, 30, 54.8 and 95 deg).  Angles
      above 95 deg are held at the value reached at the largest angle in
      [54.8, 95].  If that segment is empty, the same formula is evaluated at
      exactly 95 deg instead of raising ``IndexError``.
    """
    phase = np.zeros(len(ang))
    aa = 1.0 - g * g

    if g < 0.2:
        # Henyey-Greenstein over the whole angular range.
        bb = (1.0 + g * g - 2.0 * g * mu) ** 1.5
        phase = aa / bb

    elif g < 0.7:
        # Forward-peak scaling factor depends on the g sub-range.
        if g >= 0.6:
            alpha = 1.0 / np.sqrt(1.095 * g)
        elif g >= 0.45:
            alpha = 1.0 / np.sqrt(1.23 * g)
        elif g >= 0.3:
            alpha = 1.0 / np.sqrt(1.5 * g)
        else:
            alpha = (1.0 / np.sqrt(1. - g)) * 1.25

        fwd = ang < 54.8
        side = ang >= 54.8

        bb = (1.0 + g * g - 2.0 * g * mu[fwd]) ** 1.5
        phase[fwd] = aa / bb * mu[fwd] * alpha
        bb = (1.0 + g * g - 1.8 * g * mu[side] * np.sin(rad[side])) ** 1.5
        phase[side] = aa / bb

    else:  # g >= 0.7 (a NaN g also ends up here)
        if g < 0.8:
            norm = 0.1481e3 - 0.2025e3 * g + 0.4949e2 * g * g
        elif g <= 0.9:
            norm = 0.2771e3 - 0.5102e3 * g + 0.2329e3 * g * g
        else:
            norm = 0.4219e3 - 0.8271e3 * g + 0.4063e3 * g * g
        alpha = norm / np.sqrt(g)

        dd = ((1.0 - g) / 4.6) + g
        if g < 0.9:
            beta = 0.68
        elif g >= 0.9:
            beta = 0.71
        else:
            beta = np.nan

        s1 = ang <= 3.0
        s2 = (ang > 3.0) & (ang < 30.0)
        s3 = (ang >= 30.0) & (ang < 54.8)
        s4 = (ang >= 54.8) & (ang <= 95.0)
        s5 = ang > 95.0

        bb = (1.0 + g * g - 2.0 * g * mu[s1]) ** 1.5
        phase[s1] = aa / bb * mu[s1] ** 128.0 * alpha

        bb = (1.0 + g * g - 2.0 * g * np.cos(1.3 * rad[s2])) ** 1.2
        phase[s2] = aa / bb * mu[s2]

        bb = (1.0 + g * g - 2.0 * g * np.cos(dd * rad[s3])) ** beta
        phase[s3] = aa / bb * mu[s3]

        bb = (1.0 + g * g - 1.5 * g * mu[s4] * np.sin(rad[s4])) ** 1.5
        phase[s4] = aa / bb

        if np.any(s5):
            if bb.size:
                # Hold the value reached at the largest angle <= 95 deg.
                # argmax keeps this correct even if ``ang`` is not sorted.
                phase[s5] = aa / bb[np.argmax(ang[s4])]
            else:
                # No grid point in [54.8, 95]: evaluate the formula at 95 deg.
                r95 = 95.0 * np.pi / 180.0
                bb95 = (1.0 + g * g - 1.5 * g * np.cos(r95) * np.sin(r95)) ** 1.5
                phase[s5] = aa / bb95

    return phase






def get_ext_omega0_asym(logiwc, temp, coef_sca, coef_g, coef_abs):
    """Evaluate the Baran bulk-property fits at one (logiwc, temp) point.

    Scattering and absorption share a 6-term quadratic fit in (temp, logiwc)
    whose value is a base-10 logarithm.  The asymmetry parameter is a 3-term
    linear fit.  Each coefficient array is indexed as ``coef[:, k]``, so every
    output has one entry per row of the coefficient arrays.

    Returns
    -------
    exti : ndarray
        Scattering + absorption coefficient, converted from km^-1 to m^-1.
    ssai : ndarray
        Single-scattering albedo (scattering / extinction).
    asymi : ndarray
        Asymmetry parameter.
    """
    def log10_fit(c):
        # c0 + c1*T + c2*L + c3*T*T + c4*L*L + c5*T*L, summed left to right
        return (c[:, 0] + c[:, 1] * temp + c[:, 2] * logiwc
                + c[:, 3] * temp * temp + c[:, 4] * logiwc * logiwc
                + c[:, 5] * temp * logiwc)

    # The fits give log10 of the coefficient in km^-1; /1000 converts to m^-1.
    absorption = 10.0 ** log10_fit(coef_abs) / 1000.
    scattering = 10.0 ** log10_fit(coef_sca) / 1000.

    extinction = scattering + absorption
    albedo = scattering / extinction
    asymmetry = coef_g[:, 0] + coef_g[:, 1] * temp + coef_g[:, 2] * logiwc

    return extinction, albedo, asymmetry











def _write_dataset(group, name, shape, dtype='float64', data=None, **attrs):
    """Require dataset ``name`` in ``group``, tag it with string attributes and optionally fill it.

    Attribute values are stored with ``np.bytes_``, the NumPy 1.x alias of
    ``np.string_``; ``np.string_`` itself was removed in NumPy 2.0.  When
    ``data`` is None the dataset is only created, not written.
    """
    dset = group.require_dataset(name, shape, dtype=dtype)
    for key, value in attrs.items():
        dset.attrs.create(key, np.bytes_(value))
    if data is not None:
        dset[...] = data
    return dset


def get_baran_opt(logiwc, temp, ang, nbetal, res_file):
    """Compute Baran ice optical properties, truncate P11 and store them in HDF5.

    Parameters
    ----------
    logiwc : ndarray
        log10 of the ice water content (g m-3) grid.
    temp : ndarray
        Temperature grid (K).
    ang : ndarray
        Scattering angles in degrees (sorted internally).
    nbetal : sequence of int
        Truncation orders passed to ``trunc.mtrunc.trunc_dm``.
    res_file : str
        Output HDF5 path.  If the file already exists nothing is computed.

    For every (logiwc, temp, wavelength) point the function computes
    extinction, single-scattering albedo, asymmetry parameter and the Baran P11
    phase function.  The phase function is interpolated onto Gauss-Legendre
    nodes, truncated for each order in ``nbetal`` and recomposed on the ``mu``
    grid.  Results go to the ``baran/axis`` and ``baran/data`` groups of
    ``res_file``.  Returns None.
    """
    nmat = 1

    if os.path.isfile(res_file):
        print("")
        print(" result file %s already present" % res_file)
        print("")
        return

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    ang = np.sort(ang)
    mu = np.cos(ang * np.pi / 180.0)
    rad = ang * np.pi / 180.0

    wvl, sca_coef, g_coef, abs_coef = get_baran_coef()

    nl, nt, nw, na = len(logiwc), len(temp), len(wvl), len(ang)

    exti = np.full((nl, nt, nw), np.nan)
    Sexti = np.full((nl, nt, nw), np.nan)
    ssai = np.full((nl, nt, nw), np.nan)
    asymi = np.full((nl, nt, nw), np.nan)
    phase = np.full((na, nl, nt, nw), np.nan)

    for it in range(nt):
        for iiwc in range(nl):
            Sexti[iiwc, it, :], ssai[iiwc, it, :], asymi[iiwc, it, :] = get_ext_omega0_asym(
                logiwc[iiwc], temp[it], sca_coef, g_coef, abs_coef)
            # extinction (m-1) / IWC (g m-3) -> m2 g-1
            exti[iiwc, it, :] = Sexti[iiwc, it, :] / pow(10., logiwc[iiwc])

            for iwvl in range(nw):
                phase[:, iiwc, it, iwvl] = baran_phase(asymi[iiwc, it, iwvl], ang, mu, rad)

    # Flip so that mu is ascending for the integration and interpolation below.
    phase = phase[::-1, :, :, :]
    ang = ang[::-1]
    mu = mu[::-1]
    rad = rad[::-1]

    integrate_p11 = trapezoid(phase, x=mu, axis=0)

    ##############
    #   Betal

    nbetalmax = np.max(nbetal)
    n_nbetal = len(nbetal)
    alpha1 = np.full((nbetalmax + 1, n_nbetal, nl, nt, nw), np.nan)
    trunc_coeff = np.full((n_nbetal, nl, nt, nw), np.nan)
    p11_recomp = np.full((na, n_nbetal, nl, nt, nw), np.nan)

    # Phase function resampled on Gauss-Legendre nodes.  interp1d raises by
    # default if a node falls outside the mu range covered by ``ang``.
    gauss_mu, gauss_wght = np.polynomial.legendre.leggauss(na)
    phase_gauss = interp1d(mu, phase, axis=0)(gauss_mu)

    # Fortran-ordered work buffers handed to the trunc extension module.
    F = np.zeros((6, na), order="F")
    F_recomp = np.zeros((6, na), order="F")
    coeftr = np.zeros((1), order="F")

    for ib, nb in enumerate(nbetal):

        print(nb)

        betal = np.zeros((6, nb + 1), order="F")

        for it in range(nt):
            for iiwc in range(nl):
                for iwvl in range(nw):
                    F[0, :] = phase_gauss[:, iiwc, it, iwvl]
                    trunc.mtrunc.trunc_dm(nmat, na, gauss_mu, gauss_wght, F, nb, betal, coeftr)
                    F_recomp[:, :] = 0.0
                    trunc.mtrunc.get_f_recomp(nmat, nb, na, mu, betal, F_recomp)

                    alpha1[:nb + 1, ib, iiwc, it, iwvl] = betal[0, :]
                    p11_recomp[:, ib, iiwc, it, iwvl] = F_recomp[0, :]
                    trunc_coeff[ib, iiwc, it, iwvl] = coeftr[0]

    p11_recomp_integrate = trapezoid(p11_recomp, x=mu, axis=0)

    # ---------------------------------------------------------------- output
    dims_3d = "logiwc,temp,wavelengths"
    dims_mu = "mu,logiwc,temp,wavelengths"
    dims_nb = "nbetal,logiwc,temp,wavelengths"
    dims_ib = "ibetal,nbetal,logiwc,temp,wavelengths"
    dims_rec = "mu,nbetal,logiwc,temp,wavelengths"

    shape_3d = (nl, nt, nw)
    shape_mu = (na, nl, nt, nw)
    shape_nb = (n_nbetal, nl, nt, nw)
    shape_ib = (nbetalmax + 1, n_nbetal, nl, nt, nw)
    shape_rec = (na, n_nbetal, nl, nt, nw)

    # The with-block guarantees the file is closed (the original `f.close`
    # lacked parentheses and never closed it).
    with h5py.File(res_file, "a") as f:
        grp = f.require_group("baran")

        axis = grp.require_group("axis")
        _write_dataset(axis, "wavelengths", (nw,), data=wvl)
        _write_dataset(axis, "temp", (nt,), data=temp, unit="K")
        _write_dataset(axis, "logiwc", (nl,), data=logiwc, unit="g m-3")
        _write_dataset(axis, "mu", (na,), data=mu)
        _write_dataset(axis, "nbetal", (n_nbetal,), dtype='i', data=nbetal)
        _write_dataset(axis, "ibetal", (nbetalmax + 1,), dtype='i',
                       data=np.arange(nbetalmax + 1))

        dgrp = grp.require_group("data")
        _write_dataset(dgrp, "single_scattering_albedo", shape_3d, data=ssai,
                       dimensions=dims_3d)
        _write_dataset(dgrp, "extinction_coefficient_over_wc", shape_3d, data=exti,
                       dimensions=dims_3d, unit="m^2g-1")
        # NOTE(review): "Cext" receives the same per-IWC values as
        # extinction_coefficient_over_wc, while Sexti (m-1) is computed but
        # never stored -- confirm which quantity is intended here.
        _write_dataset(dgrp, "Cext", shape_3d, data=exti,
                       dimensions=dims_3d, unit="not_def")
        _write_dataset(dgrp, "integrate_p11", shape_3d, data=integrate_p11,
                       dimensions=dims_3d)

        _write_dataset(dgrp, "p11_phase_function", shape_mu, data=phase,
                       dimensions=dims_mu)
        # Other phase-matrix elements are created but never written here.
        for elem in ("p21", "p34", "p22", "p33", "p44"):
            _write_dataset(dgrp, elem + "_phase_function", shape_mu,
                           dimensions=dims_mu)

        _write_dataset(dgrp, "integrate_p11_recomp", shape_nb,
                       data=p11_recomp_integrate, dimensions=dims_nb)
        _write_dataset(dgrp, "truncation_coefficient", shape_nb, data=trunc_coeff,
                       dimensions=dims_nb, truncation="dm")

        _write_dataset(dgrp, "alpha1_betal", shape_ib, data=alpha1,
                       dimensions=dims_ib, truncation="dm")
        for coef in ("alpha2", "alpha3", "alpha4", "beta1", "beta2"):
            _write_dataset(dgrp, coef + "_betal", shape_ib,
                           dimensions=dims_ib, truncation="dm")

        _write_dataset(dgrp, "recomposed_p11_phase_function", shape_rec,
                       data=p11_recomp, dimensions=dims_rec, truncation="dm")
        for elem in ("p22", "p33", "p21", "p34", "p44"):
            _write_dataset(dgrp, "recomposed_" + elem + "_phase_function", shape_rec,
                           dimensions=dims_rec, truncation="dm")

    return






if __name__ == '__main__':

    # The scattering-angle grid is taken from a Baum ice-optics netCDF file.
    baum_root = '/rfs/data/baum_opt/'
    baum_name = 'GeneralHabitMixture_SeverelyRough_AllWavelengths_FullPhaseMatrix.nc'
    ang = read_nc(baum_root + baum_name, 'phase_angles').astype(np.float64)

    # log10 of ice water content (g m-3) and temperature (K) grids.
    logiwc = np.linspace(-9, 1, num=50)
    temp = np.linspace(160, 280, num=30)

    # Truncation orders to compute.
    nbetal = np.array([4, 8, 16], dtype=int)

    res_file = dir_opt + "opt_ice_baran_highres.h5"
    get_baran_opt(logiwc, temp, ang, nbetal, res_file)


