#!/usr/bin/env python
from __future__ import print_function, division, absolute_import
from builtins import range

import sys
import os
import string
import numpy as np
import h5py
import xarray as xr
from scipy.interpolate import interp1d
from ncutils import read_nc

import matplotlib.pyplot as plt

import trunc

# libRadtran 2.0.1 water-cloud Mie property tables: solar ("sol") and
# thermal ("trm") wavelength ranges.
wc_sol_path = "/home/mathieu/work/RTM/libradtran/libRadtran-2.0.1/data/wc/mie/wc.sol.mie.cdf"
wc_trm_path = "/home/mathieu/work/RTM/libradtran/libRadtran-2.0.1/data/wc/mie/wc.trm.mie.cdf"


# Output scattering-angle grid (degrees). It reuses the "phase_angles" of the
# Baum general-habit-mixture file below for angles <= 80 deg and >= 176 deg,
# and fills 80.5-175.5 deg with a regular 0.5 deg step.
baum_dir  = '/rfs/data/baum_opt/'
cdffile   = 'GeneralHabitMixture_SeverelyRough_AllWavelengths_FullPhaseMatrix.nc'
ang_s = read_nc(baum_dir + cdffile, 'phase_angles').astype(np.float64)
ang = np.concatenate((
    ang_s[ang_s <= 80.0],
    np.linspace(80.5, 175.5, num=191, endpoint=True),
    ang_s[ang_s >= 176.0],
))
# Angles are stored in descending order, so mu = cos(ang) is ascending.
ang = np.sort(ang)[::-1]
mu = np.cos(ang*np.pi/180.0)
nang = len(ang)


rho_liq_wat = 1.0e-12  # liquid water density in g micron-3 (i.e. 1 g cm-3)

######################################################################

def alpha_b(reff, veff):
    """Convert (effective radius, effective variance) to gamma parameters.

    Returns the tuple (alpha, b) with alpha = 1/veff - 3 and
    b = 1/(reff*veff); this is the inverse of ``reff_veff``.
    """
    return 1./veff - 3., 1./(reff*veff)

def reff_veff(alpha, b):
    """Convert gamma parameters (alpha, b) to (effective radius, effective variance).

    veff = 1/(alpha+3) and reff = 1/(b*veff); this is the inverse of
    ``alpha_b``. Returns the tuple (reff, veff).
    """
    variance = 1.0 / (alpha + 3.0)
    return 1.0 / (b * variance), variance


# def gamma_size_dist(r, b, alpha):
#     #from scipy.special import gammaln
#     from scipy.special import loggamma
#     alpha1= alpha+1.0
#     logC  = alpha1 * np.log(b) - loggamma(alpha1)
#     n = np.exp( logC + alpha * np.log(r) - b * r)
#     return n


def gamma_size_dist(r, reff, veff):
    """Evaluate the normalised gamma size distribution n(r) (as defined in mie3).

        n(r) = b**(alpha+1) / Gamma(alpha+1) * r**alpha * exp(-b*r)

    with alpha = 1/veff - 3 and b = 1/(veff*reff). The expression is built
    in log space (via ``loggamma``) to avoid overflow of the power and
    Gamma terms.

    reff is the effective radius and must be positive. veff is the effective
    variance; it must be positive and less than 0.5 (so that alpha > -1).
    If veff is larger than 1/3, alpha < 0 and n(r) is singular at r = 0.
    """
    from scipy.special import loggamma

    shape = 1.0/veff - 3.0
    rate = 1.0/(veff*reff)
    # log of the normalisation constant b**(alpha+1) / Gamma(alpha+1)
    log_norm = (shape + 1.0) * np.log(rate) - loggamma(shape + 1.0)
    return np.exp(log_norm + shape * np.log(r) - rate * r)

######################################################################




def _trapz(y, x, axis=-1):
    """Return the trapezoidal integral of ``y`` over sample points ``x`` along ``axis``.

    Prefers ``np.trapezoid`` (NumPy >= 2.0, where ``np.trapz`` is deprecated)
    and falls back to ``np.trapz`` on older NumPy.
    """
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(y, x=x, axis=axis)


def _mean_particle_volume(reff, veff):
    """Return the mean particle volume (micron^3) of the gamma size distribution.

    Uses the same parameterisation as ``gamma_size_dist``
    (alpha = 1/veff - 3, b = 1/(reff*veff)). For that distribution the third
    moment is exactly <r^3> = (alpha+1)(alpha+2)(alpha+3) / b^3 (valid for
    alpha > -1). The closed form replaces a trapezoidal integral on a ~1 micron
    grid, which under-resolved the distribution for small effective radii.
    ``reff`` may be a scalar or a numpy array.
    """
    alpha = 1.0 / veff - 3.0
    b = 1.0 / (reff * veff)
    mean_r3 = (alpha + 1.0) * (alpha + 2.0) * (alpha + 3.0) / b ** 3
    return 4.0 / 3.0 * np.pi * mean_r3


def _read_wc_table(path, angles):
    """Load one libRadtran water-cloud Mie table and resample its phase matrix.

    Returns a dict of numpy arrays. The keys "wavelen", "reff", "ext" and
    "ssa" are copied from the file. The keys "p11", "p21", "p44" and "p34"
    are each shaped (nwavelen, nreff, len(angles)) and hold linear
    interpolations onto ``angles`` of phase-matrix indices 0, 1, 2 and 3,
    respectively. The netCDF dataset is closed before returning.
    """
    keys = ("wavelen", "reff", "ext", "ssa", "ntheta", "theta", "phase")
    with xr.open_dataset(path) as ds:
        table = {key: np.array(ds[key]) for key in keys}

    ntheta = table.pop("ntheta")
    theta = table.pop("theta")
    phase = table.pop("phase")
    nwvl, nreff = ntheta.shape[0], ntheta.shape[1]

    names = ("p11", "p21", "p44", "p34")  # phase-matrix index 0, 1, 2, 3
    resampled = np.zeros((len(names), nwvl, nreff, len(angles)))
    for ireff in range(nreff):
        print(table["reff"][ireff])
        for iwvl in range(nwvl):
            for imat in range(len(names)):
                # theta/phase are sized nthetamax; each curve only uses its
                # first ntheta[...] samples.
                nt = ntheta[iwvl, ireff, imat]
                resampled[imat, iwvl, ireff, :] = interp1d(
                    theta[iwvl, ireff, imat, :nt],
                    phase[iwvl, ireff, imat, :nt])(angles)
    for imat, name in enumerate(names):
        table[name] = resampled[imat]
    return table


def _compute_betal(p11, p21, p44, p34, mu, nbetal, reff):
    """Run delta-M truncation for every order in ``nbetal`` and recompose on ``mu``.

    p11, p21, p44, p34 are shaped (nwvl, nreff, len(mu)). Returns a dict of
    NaN-initialised arrays:
      "alpha1".."alpha4", "beta1", "beta2": (nwvl, nreff, max(nbetal)+1, len(nbetal))
      "p11_recomp", "p22_recomp", "p33_recomp", "p44_recomp",
      "p21_recomp", "p34_recomp":          (nwvl, nreff, len(mu), len(nbetal))
      "trunc_coeff":                        (nwvl, nreff, len(nbetal))
    """
    nwvl, nreff, nmu = p11.shape
    nbetalmax = np.max(nbetal)
    n_nbetal = len(nbetal)

    # Row order of the Fortran betal array: alpha1..alpha4, beta1, beta2.
    betal_names = ("alpha1", "alpha2", "alpha3", "alpha4", "beta1", "beta2")
    # Row order of the Fortran F array: F11, F22, F33, F44, F12, F34.
    recomp_names = ("p11_recomp", "p22_recomp", "p33_recomp",
                    "p44_recomp", "p21_recomp", "p34_recomp")

    out = {}
    for name in betal_names:
        out[name] = np.full((nwvl, nreff, nbetalmax + 1, n_nbetal), np.nan)
    for name in recomp_names:
        out[name] = np.full((nwvl, nreff, nmu, n_nbetal), np.nan)
    out["trunc_coeff"] = np.full((nwvl, nreff, n_nbetal), np.nan)

    # Fortran-ordered work arrays passed to the trunc extension.
    F = np.zeros((6, nmu), order="F")
    F_recomp = np.zeros((6, nmu), order="F")
    coeftr = np.zeros((1), order="F")

    # The truncation is evaluated on Gauss-Legendre nodes, so resample the
    # phase matrix from the mu grid onto them.
    gauss_mu, gauss_wght = np.polynomial.legendre.leggauss(nmu)
    p11_gauss = interp1d(mu, p11, axis=2)(gauss_mu)
    p21_gauss = interp1d(mu, p21, axis=2)(gauss_mu)
    p44_gauss = interp1d(mu, p44, axis=2)(gauss_mu)
    p34_gauss = interp1d(mu, p34, axis=2)(gauss_mu)

    for ib, nb in enumerate(nbetal):
        print(nb)
        betal = np.zeros((6, nb + 1), order="F")
        for ireff in range(nreff):
            print(reff[ireff])
            for iwvl in range(nwvl):
                # F22 is set to F11 and F33 to F44 (Mie / spherical symmetry).
                F[0, :] = p11_gauss[iwvl, ireff, :]
                F[1, :] = p11_gauss[iwvl, ireff, :]
                F[2, :] = p44_gauss[iwvl, ireff, :]
                F[3, :] = p44_gauss[iwvl, ireff, :]
                F[4, :] = p21_gauss[iwvl, ireff, :]
                F[5, :] = p34_gauss[iwvl, ireff, :]

                trunc.mtrunc.trunc_dm(4, nmu, gauss_mu, gauss_wght, F, nb, betal, coeftr)
                F_recomp[:, :] = 0.0
                trunc.mtrunc.get_f_recomp(4, nb, nmu, mu, betal, F_recomp)

                for k, name in enumerate(betal_names):
                    out[name][iwvl, ireff, :nb + 1, ib] = betal[k, :]
                for k, name in enumerate(recomp_names):
                    out[name][iwvl, ireff, :, ib] = F_recomp[k, :]
                out["trunc_coeff"][iwvl, ireff, ib] = coeftr[0]
    return out


def _write_dataset(grp, name, data, dims, **attrs):
    """Create float32 dataset ``name`` in ``grp`` with string attributes.

    A "dimensions" attribute is set to ``dims``; every extra keyword becomes
    another bytes-string attribute (e.g. unit=..., truncation=...).
    ``np.bytes_`` is used because ``np.string_`` was removed in NumPy 2.0.
    """
    dset = grp.create_dataset(name, data=data, dtype='float32')
    dset.attrs.create("dimensions", np.bytes_(dims))
    for key, value in attrs.items():
        dset.attrs.create(key, np.bytes_(value))
    return dset


def doit(nbetal, res_suffixe=""):
    """Build the libRadtran liquid-cloud optical-property table.

    Steps:
      1. read the libRadtran water-cloud Mie tables ``wc_sol_path`` and
         ``wc_trm_path``, and resample their phase matrices onto the
         module-level ``ang`` grid;
      2. convert extinction per unit liquid water content into a
         per-particle extinction cross section (micron^2), assuming a gamma
         size distribution with alpha = 7;
      3. apply delta-M truncation (``trunc.mtrunc.trunc_dm``) for every
         order in ``nbetal`` and recompose the phase matrix on ``mu``;
      4. write everything into group "libradtran_liquid" of ``res_file``.
         The file is opened in append mode. h5py refuses to create datasets
         that already exist, so re-running against a populated file fails.

    Parameters
    ----------
    nbetal : sequence of int
        Truncation orders; order ``nb`` yields ``nb + 1`` coefficients.
    res_suffixe : str, optional
        Not used; kept for interface compatibility.

    Raises
    ------
    ValueError
        If the solar and thermal tables do not share the same reff grid.
    """
    res_file = "/rfs/proj/artdeco_lib/opt/opt_libradtran_liquid.h5"

    sol = _read_wc_table(wc_sol_path, ang)
    trm = _read_wc_table(wc_trm_path, ang)
    reff = sol["reff"]
    # The two tables are stacked along wavelength, so they must share reff.
    if reff.shape != trm["reff"].shape or not np.allclose(reff, trm["reff"]):
        raise ValueError("wc.sol and wc.trm tables have different reff grids")

    wvl = np.concatenate((sol["wavelen"], trm["wavelen"]))
    ssa = np.concatenate((sol["ssa"], trm["ssa"]), axis=0)
    # "ext" is km-1 / (g m-3); dividing by 1000 gives m-1 / (g m-3) = m2 / g.
    Sext_o_WC = np.concatenate((sol["ext"], trm["ext"]), axis=0) / 1000.

    # Per-particle cross section: (m2 g-1) * particle mass (g) gives m2;
    # dividing by 1e-12 converts m2 to micron2.
    alpha = 7.0
    veff = 1.0 / (alpha + 3.0)
    vol = _mean_particle_volume(reff, veff)  # micron^3, one value per reff
    Sext = Sext_o_WC * (vol * rho_liq_wat) / 1e-12

    # Phase-matrix elements on the ang grid, shaped (nwvl, nreff, nang).
    p11, p21, p44, p34 = [np.concatenate((sol[k], trm[k]), axis=0)
                          for k in ("p11", "p21", "p44", "p34")]
    p11_integrate = _trapz(p11, x=mu, axis=2)

    print(" compute Betal...")
    betal = _compute_betal(p11, p21, p44, p34, mu, nbetal, reff)
    p11_recomp_integrate = _trapz(betal["p11_recomp"], x=mu, axis=2)

    # (wvl, reff, x, nbetal) -> (x, nbetal, reff, wvl) for 4-D outputs.
    order4 = (2, 3, 1, 0)
    dims_rw = "reff,wavelengths"
    dims_mrw = "mu,reff,wavelengths"
    dims_nrw = "nbetal,reff,wavelengths"

    # The context manager guarantees the file is closed (the previous
    # ``f.close`` without parentheses never closed it).
    with h5py.File(res_file, "a") as f:
        grp = f.require_group("libradtran_liquid")

        agrp = grp.require_group("axis")
        agrp.create_dataset("wavelengths", data=wvl, dtype='float64')
        agrp.create_dataset("reff", data=reff, dtype='float64')
        agrp.create_dataset("phase_angles", data=ang, dtype='float64')
        agrp.create_dataset("mu", data=mu, dtype='float64')
        agrp.create_dataset("nbetal", data=nbetal, dtype='i')
        agrp.create_dataset("ibetal", data=np.arange(np.max(nbetal) + 1), dtype='i')

        # Stored arrays put wavelength last; "dimensions" lists the stored order.
        dgrp = grp.require_group("data")
        _write_dataset(dgrp, "single_scattering_albedo", ssa.T, dims_rw)
        _write_dataset(dgrp, "Cext", Sext.T, dims_rw, unit="micron^2")
        _write_dataset(dgrp, "extinction_coefficient_over_wc", Sext_o_WC.T, dims_rw,
                       unit="m^2g-1")
        _write_dataset(dgrp, "integrate_p11", p11_integrate.T, dims_rw)
        for name, values in (("p11", p11), ("p44", p44), ("p21", p21), ("p34", p34)):
            _write_dataset(dgrp, name + "_phase_function", values.T, dims_mrw)

        _write_dataset(dgrp, "integrate_p11_recomp", p11_recomp_integrate.T, dims_nrw)
        _write_dataset(dgrp, "truncation_coefficient", betal["trunc_coeff"].T, dims_nrw,
                       truncation="dm")

        for name in ("alpha1", "alpha2", "alpha3", "alpha4", "beta1", "beta2"):
            _write_dataset(dgrp, name + "_betal", np.transpose(betal[name], order4),
                           "ibetal,nbetal,reff,wavelengths", truncation="dm")

        for elem in ("p11", "p22", "p33", "p44", "p21", "p34"):
            _write_dataset(dgrp, "recomposed_" + elem + "_phase_function",
                           np.transpose(betal[elem + "_recomp"], order4),
                           "mu,nbetal,reff,wavelengths", truncation="dm")

    return
    
    

if __name__ == '__main__':
    # Truncation orders to tabulate; doit computes one coefficient set per entry.
    nbetal = np.array([4, 8, 16, 32], dtype=int)
    doit(nbetal, res_suffixe="")
    


