#!/usr/bin/env python

import sys
import os
import string
import numpy as np
import h5py
import pylab as plt
from ncutils import read_nc
from scipy.interpolate import interp1d

# Root of the pyartdeco installation. os.environ[...] raises KeyError when the
# variable is not set. The value is concatenated directly with 'tools', so it
# must end with a path separator for the appended path to be valid.
dir_pyartdeco = os.environ['DIR_PYARTDECO']
# Must run before the import below so that pyartdeco_utils can be found.
sys.path.append(dir_pyartdeco+'tools')
import pyartdeco_utils as ad

# Root of the ARTDECO installation. Like dir_pyartdeco, it is concatenated
# directly with sub-directory names in this script, so it must end with a
# path separator.
dir_artdeco = os.environ['DIR_ARTDECO']

# Directory passed to ad.get_opt for the native optical properties. It is
# also the directory where doit() writes its HDF5 output (dir_data).
dir_opt  = dir_artdeco+'lib/opt/'
dir_data = dir_opt

from scipy.integrate import simps

# doit() prints a warning when |sumf11 - 2| of a recomposed phase function
# exceeds this threshold.
delta_sumf11 = 1e-3

# NetCDF file whose 'phase_angles' variable seeds the output angular grid.
baum_dir  = '/rfs/proj/optical_properties/baum_opt/'
cdffile   = 'GeneralHabitMixture_SeverelyRough_AllWavelengths_FullPhaseMatrix.nc'


def _merge_store_angles(native):
    """Return the output scattering-angle grid in degrees, sorted descending.

    The angles of *native* that are <= 80 or >= 176 degrees are kept as is.
    The gap in between is filled with a regular 0.5-degree grid running from
    80.5 to 175.5 degrees (191 points).
    """
    kept_low = native[native <= 80.0]
    regular = np.linspace(80.5, 175.5, num=191, endpoint=True)
    kept_high = native[native >= 176.0]
    merged = np.concatenate((kept_low, regular, kept_high))
    return np.sort(merged)[::-1]


ang_s = read_nc(baum_dir+cdffile, 'phase_angles').astype(np.float64)
ang_store = _merge_store_angles(ang_s)
# cos of the output angles (the grid is in degrees).
mu_store = np.cos(ang_store*np.pi/180.0)
nang_store = len(ang_store)


#########################################################################################################



# Legendre-expansion coefficient arrays returned (in this order) by
# ad.get_betal. Each is stored in the HDF5 file as "<name>_betal".
_BETAL_NAMES = ('alpha1', 'alpha2', 'alpha3', 'alpha4', 'beta1', 'beta2')

# Phase-matrix elements returned (in this order) by ad.get_opt(..., recomp=...).
# Each is stored in the HDF5 file as "recomposed_<name>_phase_function".
_RECOMP_NAMES = ('p11', 'p22', 'p33', 'p44', 'p21', 'p34')

# ARTDECO main input file, rewritten before every ARTDECO run.
_ARTDECO_INPUT = 'artdeco_in_get_kokha_opt.dat'


def _abort(message):
    """Print *message* between blank lines and stop with a non-zero exit status."""
    print('')
    print(message)
    print('')
    sys.exit(1)


def _check_angles(ang_ref, ang_new):
    """Abort unless *ang_new* holds exactly the same angles as *ang_ref*.

    The length and the values are tested separately so that each failure keeps
    its own diagnostic message.
    """
    if len(ang_new) != len(ang_ref):
        _abort('There is a problem with nang')
    if not np.array_equal(ang_new, ang_ref):
        _abort('There is a problem with ang')


def _read_native_opt(ptcle, wvl):
    """Read the native optical properties of *ptcle* at each wavelength of *wvl*.

    Calls ``ad.get_opt(dir_opt, ptcle, w)`` for every wavelength ``w`` and
    stacks the per-wavelength results column-wise.

    Returns ``(ang, p11, p44, p21, p34, Cext, ssa, g)``:
    - ``ang`` is a float64 copy of the angle grid of the first wavelength.
      Every other wavelength must use exactly the same grid.
    - ``p11``/``p44``/``p21``/``p34`` are arrays of shape (nang, nwvl).
    - ``Cext``/``ssa``/``g`` are arrays of shape (nwvl,).
    """
    nwvl = len(wvl)
    Cext = np.zeros(nwvl)
    ssa = np.zeros(nwvl)
    g = np.zeros(nwvl)

    ang = None
    for iwvl in range(nwvl):
        (ang_tmp, p11_tmp, p44_tmp, p21_tmp, p34_tmp,
         Cext[iwvl], ssa[iwvl], g[iwvl]) = ad.get_opt(dir_opt, ptcle, wvl[iwvl])

        if ang is None:
            nang = len(ang_tmp)
            ang = np.zeros(nang)
            ang[:] = ang_tmp
            p11 = np.zeros((nang, nwvl))
            p44 = np.zeros((nang, nwvl))
            p21 = np.zeros((nang, nwvl))
            p34 = np.zeros((nang, nwvl))

        # Validate the grid BEFORE storing. A grid of a different length would
        # otherwise fail in the column assignment with an unhelpful error.
        _check_angles(ang, ang_tmp)

        p11[:, iwvl] = p11_tmp
        p44[:, iwvl] = p44_tmp
        p21[:, iwvl] = p21_tmp
        p34[:, iwvl] = p34_tmp

    return ang, p11, p44, p21, p34, Cext, ssa, g


def _write_artdeco_input(path, ptcle, wvl, keywords, saveroot, trunc_model, nbetal_i):
    """Write the ARTDECO main input file to *path*, requesting *nbetal_i* Betal terms."""
    wavelengths = ''.join('  %.9f' % w for w in wvl)
    lines = [
        '# \n',
        '# MAIN INPUT FILE FOR ARTDECO PROGRAM  \n',
        '#   \n',
        '######################## \n',
        '# keywords (Ex: none, verbose,...) \n',
        keywords+' \n',
        '######################## \n',
        '# outfiles root name \n',
        saveroot+' \n',
        '######################## \n',
        '# mode \n',
        ' mono \n',
        '#######################\n',
        '#     filters\n',
        '  none  \n',
        '######################## \n',
        '# Wavelengths (microns) \n',
        wavelengths+'   \n',
        '###################### \n',
        '# Particles \n',
        '# type           interp.       H-G         Tau_550          vertical distribution type     vdist parameters (km)  \n',
        ptcle+'    no \n',
        '# \n',
        '########################## \n',
        '# Truncation method (none, dfit, DM, Potter) \n',
        trunc_model+'    \n',
        '########################## \n',
        '# Number of Betal \n',
        '%4i'%nbetal_i+'   \n',
        '#  \n',
    ]
    with open(path, 'w') as fh:
        fh.writelines(lines)


def _compute_betal(ptcle, wvl, ang, nbetal, keywords, saveroot, trunc_model):
    """Run ARTDECO once per entry of *nbetal* and collect the truncated expansions.

    For each number of terms, this writes the ARTDECO input file and runs
    ARTDECO. For every wavelength it then reads back the expansion
    coefficients (``ad.get_betal``) and the phase matrix recomposed from them
    (``ad.get_opt(..., recomp=...)``).

    Returns ``(coeffs, trunc_coeff, recomp, p11_recomp_integrate)``:
    - ``coeffs`` maps each name in _BETAL_NAMES to an array of shape
      (max(nbetal)+1, len(nbetal), nwvl). Entries beyond the current number
      of terms stay 0.
    - ``trunc_coeff`` has shape (len(nbetal), nwvl).
    - ``recomp`` maps each name in _RECOMP_NAMES to an array of shape
      (nang, len(nbetal), nwvl).
    - ``p11_recomp_integrate`` has shape (len(nbetal), nwvl) and holds the
      Simpson integral of the recomposed p11 over cos(angle).
    """
    nwvl = len(wvl)
    nang = len(ang)
    n_nbetal = len(nbetal)
    nbetalmax = np.amax(nbetal)
    mu = np.cos(ang*np.pi/180.0)

    coeffs = {name: np.zeros((nbetalmax+1, n_nbetal, nwvl)) for name in _BETAL_NAMES}
    trunc_coeff = np.zeros((n_nbetal, nwvl))
    recomp = {name: np.zeros((nang, n_nbetal, nwvl)) for name in _RECOMP_NAMES}
    p11_recomp_integrate = np.zeros((n_nbetal, nwvl))

    # Paths and command are the same for every run, so compute them once.
    input_path = dir_artdeco+'input/'+_ARTDECO_INPUT
    command = dir_artdeco.strip()+"src/artdeco "+_ARTDECO_INPUT
    out_root = dir_artdeco+'out/'+saveroot
    betal_file = out_root+'/Betal_'+ptcle+'_'+trunc_model+'.dat'
    recomp_file = out_root+'/opt_recomp_'+ptcle+'_'+trunc_model+'.dat'

    for i_nbetal, nbetal_i in enumerate(nbetal):
        _write_artdeco_input(input_path, ptcle, wvl, keywords, saveroot,
                             trunc_model, nbetal_i)

        # One ARTDECO run covers every wavelength for this nbetal. After a
        # failed run, the files read below could be left over from an
        # earlier run, so stop instead of reading them.
        status = os.system(command)
        if status != 0:
            _abort('ARTDECO run failed (%s returned %d)' % (command, status))

        for iwvl in range(nwvl):
            (a1, a2, a3, a4, b1, b2, nbetal_tmp, trunc_coef_tmp) = \
                ad.get_betal(ptcle, wvl[iwvl], betal_file)

            if nbetal_tmp != nbetal_i:
                _abort('There is a problem with nbetal')

            for name, values in zip(_BETAL_NAMES, (a1, a2, a3, a4, b1, b2)):
                coeffs[name][0:nbetal_tmp+1, i_nbetal, iwvl] = values
            trunc_coeff[i_nbetal, iwvl] = trunc_coef_tmp

            # The "toto" directory argument is the value the script has
            # always passed when a recomp file is given.
            (ang_tmp, p11_tmp, p22_tmp, p33_tmp, p44_tmp, p21_tmp, p34_tmp,
             _cext, _ssa, _g, _nbetal, _nteta, _trunc, sumf11_tmp) = \
                ad.get_opt("toto", ptcle, wvl[iwvl], recomp=recomp_file)

            # The recomposed phase matrix must share the native angle grid.
            _check_angles(ang, ang_tmp)

            if abs(sumf11_tmp - 2.00) > delta_sumf11:
                print('sumf11_tmp != 2.0,  %s %s %s' % (ptcle, nbetal_i, sumf11_tmp))

            phase = (p11_tmp, p22_tmp, p33_tmp, p44_tmp, p21_tmp, p34_tmp)
            for name, values in zip(_RECOMP_NAMES, phase):
                recomp[name][:, i_nbetal, iwvl] = values

            p11_recomp_integrate[i_nbetal, iwvl] = simps(p11_tmp, mu)
            print('p11_recomp_integrate[i_nbetal,iwvl] =  %s'
                  % p11_recomp_integrate[i_nbetal, iwvl])

    return coeffs, trunc_coeff, recomp, p11_recomp_integrate


def _interp_to_store(ang, values):
    """Linearly interpolate *values* from the angles *ang* onto ``ang_store``.

    Axis 0 of *values* must run over *ang*. Any trailing axes (wavelength,
    nbetal, ...) are interpolated one column at a time and kept in the result,
    whose axis 0 has length ``nang_store``. With interp1d's default
    ``bounds_error=True``, this raises ValueError if ``ang_store`` reaches
    outside the range of *ang*.
    """
    values = np.asarray(values)
    columns = values.reshape(values.shape[0], -1)
    out = np.zeros((nang_store, columns.shape[1]))
    for j in range(columns.shape[1]):
        out[:, j] = interp1d(ang, columns[:, j])(ang_store)
    return out.reshape((nang_store,) + values.shape[1:])


def _put(group, name, data, dims=None, dtype='float64', truncation=None):
    """Write *data* to dataset *name* of the h5py *group*, creating it if needed.

    ``require_dataset`` reuses an existing dataset of the same shape and dtype,
    so a re-run overwrites the values in place. When *dims* or *truncation*
    is given, it is stored as the "dimensions" or "truncation" attribute.
    """
    data = np.asarray(data)
    dset = group.require_dataset(name, data.shape, dtype=dtype)
    dset[...] = data
    if dims is not None:
        dset.attrs.create("dimensions", dims)
    if truncation is not None:
        dset.attrs.create("truncation", truncation)
    return dset


def doit(ptcle):
    """Build the truncated optical-property tables of *ptcle* and store them in HDF5.

    Steps:
    1. Read the native p11, p21, p34 and p44, plus Cext, ssa and g, at the
       hard-coded wavelengths via ``ad.get_opt``.
    2. For each number of Legendre terms in ``nbetal``, run ARTDECO with
       D-M ('dm') truncation. Read back the expansion coefficients and the
       phase matrix recomposed from them.
    3. Linearly interpolate the native and recomposed phase matrices onto the
       module-level grid ``ang_store``.
    4. Write everything under group *ptcle* of ``dir_data + 'opt_kokha.h5'``.
       The file is opened in append mode and existing datasets are reused.

    The script stops with exit status 1 if ARTDECO fails, or if its output
    disagrees with the native data (angle grid or number of coefficients).
    """
    wvl = np.array([0.1, 100.])
    nwvl = len(wvl)

    print('Read opt ...')
    ang, p11, p44, p21, p34, Cext, ssa, g = _read_native_opt(ptcle, wvl)

    # Legendre polynomial expansion with D-M truncation
    print('Compute Legendre poly. exp. ...')
    keywords = '  betal_only  print_betal  print_recomp  verbose'
    nbetal = np.array([8, 16, 32, 64], dtype=int)
    nbetalmax = np.amax(nbetal)
    print('nbetalmax= %d' % nbetalmax)
    trunc_model = 'dm'
    saveroot = "get_betal_opac"

    # Simpson integral of the native p11 over cos(angle), one value per wavelength.
    mu = np.cos(ang*np.pi/180.0)
    p11_integrate = np.array([simps(p11[:, iwvl], mu) for iwvl in range(nwvl)])

    coeffs, trunc_coeff, recomp, p11_recomp_integrate = _compute_betal(
        ptcle, wvl, ang, nbetal, keywords, saveroot, trunc_model)

    # Re-interpolate the phase matrices onto the lighter output grid.
    print("interpolate phase matrices on lighter angular grid")
    p11_store = _interp_to_store(ang, p11)
    p44_store = _interp_to_store(ang, p44)
    p21_store = _interp_to_store(ang, p21)
    p34_store = _interp_to_store(ang, p34)
    recomp_store = {name: _interp_to_store(ang, recomp[name]) for name in _RECOMP_NAMES}

    # NOTE(review): no normalisation is actually applied, so "normed_ext_coeff"
    # is written with the same values as "extinction_coeff". Confirm intent.
    Cext_norm = Cext

    ############################
    # WRITE A RESULT FILE (HDF5)
    filename = dir_data+'opt_kokha.h5'

    # The context manager flushes and closes the file, even if a write fails.
    with h5py.File(filename, 'a') as f:
        grp = f.require_group(ptcle)

        axis = grp.require_group("axis")
        _put(axis, "wavelengths", wvl)
        _put(axis, "phase_angles", ang_store)
        _put(axis, "mu", np.cos(ang_store*np.pi/180.0))
        _put(axis, "nbetal", nbetal, dtype='i')
        _put(axis, "ibetal", np.arange(nbetalmax+1), dtype='i')

        data = grp.require_group("data")
        _put(data, "single_scattering_albedo", ssa, "wavelengths")
        _put(data, "extinction_coeff", Cext, "wavelengths")
        _put(data, "normed_ext_coeff", Cext_norm, "wavelengths")
        _put(data, "integrate_p11", p11_integrate, "wavelengths")
        _put(data, "p11_phase_function", p11_store, "mu,wavelengths")
        _put(data, "p21_phase_function", p21_store, "mu,wavelengths")
        _put(data, "p34_phase_function", p34_store, "mu,wavelengths")
        _put(data, "p44_phase_function", p44_store, "mu,wavelengths")
        _put(data, "truncation_coefficient", trunc_coeff,
             "nbetal,wavelengths", truncation="dm")
        _put(data, "integrate_p11_recomp", p11_recomp_integrate,
             "nbetal,wavelengths")
        for name in _BETAL_NAMES:
            _put(data, name+"_betal", coeffs[name],
                 "ibetal,nbetal,wavelengths", truncation="dm")
        for name in _RECOMP_NAMES:
            _put(data, "recomposed_"+name+"_phase_function", recomp_store[name],
                 "mu,nbetal,wavelengths", truncation="dm")



if __name__ == '__main__':
    # Build and store the tables for each particle type, aerosol first.
    for particle in ('kokha_aer', 'kokha_cl'):
        doit(particle)



