#!/usr/bin/env python

# This was developed by Mathieu Compiegne at HYGEOS

# This contains routines that are used
# to run the ARTDECO f2py version
# NB : not acting on the artdeco.so variables
#      but rather loading/setting variables
#      that will further be used by it

from __future__ import print_function, division, absolute_import
from builtins import range

import numpy as np
import string
import os
from scipy.interpolate import interp1d
from ncutils import read_nc
from luts import MLUT, LUT, Idx, read_mlut_hdf5
import h5py

import scipy.constants as cste

# Molar masses
mh2o   = 18.01528 # water vapour, g mol-1
mair   = 28.9644  # air, g mol-1 
mo3    = 47.9982  # ozone, g mol-1



# Avogadro constant (mol^-1)
na = 6.022140857e23
# Acceleration due to gravity (m/s^2, exact)
gravity = 9.80665
# Universal gas constant (J/mol/K)
rgc = 8.3144598
# Boltzmann constant (J K-1)  http://physics.nist.gov February 2017
kb = 1.38064852e-23

# Quantities derived from the constants above
rdry      = rgc / mair # Rspecific J g-1 K-1 (mair is in g mol-1)
eps       = mh2o / mair  # ratio of the molar masses of water vapour and air
scale_eps = 1.0 - eps
# Earth parameters used by calc_hgpl
omega     = 7292115E-11   # presumably the Earth's angular velocity in rad/s -- confirm
eqrad  = 6378.137         # equatorial radius in km (calc_hgpl multiplies it by 1000.)
flatt  = 3.3528107e-3     # flattening used in the ellipsoid formula of calc_hgpl
grave  = 9.7803267715     # reference gravity scaled by the latitude terms in calc_hgpl
                          # (presumably the equatorial value in m/s^2 -- confirm)
 

# Interpolation kinds; the code consuming them is not visible in this section
#phasemat_interp = 'cubic'
phasemat_interp = 'linear'

surf_interp     = 'linear'
depol_interp    = 'linear'
atm_interp      = 'linear'

# presumably the upper bound accepted for a relative humidity in % -- confirm
limit_rh = 110.

##############################################

def skipcomment(f):
    """
    Move the file cursor past the leading comment lines.

    Lines are consumed as long as they start with '#'. On return the
    cursor is located at the beginning of the first line that is not a
    comment, so that the next readline() returns that line.
    """
    while True:
        # remember where the line about to be read starts
        line_start = f.tell()
        line = f.readline()
        if not line.startswith('#'):
            break
    # rewind to the beginning of the first non-comment line
    f.seek(line_start)

##############################################

def find_nearest(array, value):
    """
    Return the index of the element of -array- that is the closest to -value-.

    In case of a tie the first (lowest) index is returned, following
    the behaviour of argmin.
    """
    distance = np.abs(np.asarray(array) - value)
    return distance.argmin()

##############################################

def get_dens( d_kgperkg, temp, pap, hum ):
    """
    Convert a mass ratio over moist air into a mass density.

    d_kgperkg : kg/kg over moist air
    temp      : K
    pap       : Pa
    hum       : kg/kg over moist air

    Returns the density in g m-3.
    """
    # specific gas constant of moist air, obtained from the dry air one
    humidity_factor = 1 +  (1-eps)/eps*hum
    r_moist_air = rdry * humidity_factor
    return d_kgperkg * pap / r_moist_air / temp # g m-3

def get_massratio( wc, temp, pap, hum ):
    """
    Convert a mass density into a mass ratio over moist air.

    wc   : g m-3
    temp : K
    pap  : Pa
    hum  : kg/kg over moist air

    Returns the mass ratio in kg/kg over moist air.
    """
    # specific gas constant of moist air, obtained from the dry air one
    humidity_factor = 1 +  (1-eps) / eps * hum
    r_moist_air = rdry * humidity_factor
    return wc * r_moist_air * temp / pap # kg/kg over moist air

##############################################


def get_rh(p,t,uair,uh2o,phase):
    """
    Relative humidity (%) from the water vapour and air amounts.

    p     : pressure, in the same unit as the saturation pressure (mb or hPa)
    t     : temperature (K)
    uair  : amount of air
    uh2o  : amount of water vapour, same unit as uair (only the ratio
            uh2o / uair is used)
    phase : "liquid" (or "liq") or "ice", selects the coefficients of the
            saturation vapour pressure formula

    Scalars as well as arrays are accepted.

    Raises ValueError if -phase- is not one of the supported values.
    """
    if phase in ["liquid","liq"]:
        coef, t_offset = 17.27, 237.3
    elif phase == "ice":
        coef, t_offset = 22.42, 272.4
    else:
        raise ValueError("(get_rh) phase must be 'liquid', 'liq' or 'ice', got: %r" % (phase,))

    # asanyarray keeps array subclasses untouched while accepting scalars
    p    = np.asanyarray(p)
    t    = np.asanyarray(t)
    uair = np.asanyarray(uair)
    uh2o = np.asanyarray(uh2o)

    t_celsius = t - 273.15
    pps = 6.1078 * np.exp( coef * t_celsius / ( t_celsius + t_offset ) ) # saturation vapour pressure (mb or hPa)
    ppv = uh2o / uair * p                                                # water vapour partial pressure
    rh  = ppv / pps * 100.
    return rh


def get_u_from_rh(p,t,uair,rh,phase):
    """
    Amount of water vapour corresponding to a relative humidity.

    p     : pressure, in the same unit as the saturation pressure (mb or hPa)
    t     : temperature (K)
    uair  : amount of air
    rh    : relative humidity (%)
    phase : "liquid" (or "liq") or "ice", selects the coefficients of the
            saturation vapour pressure formula

    Returns the amount of water vapour in the unit of -uair-.
    Scalars as well as arrays are accepted.

    Raises ValueError if -phase- is not one of the supported values.
    """
    if phase in ["liquid","liq"]:
        coef, t_offset = 17.27, 237.3
    elif phase == "ice":
        coef, t_offset = 22.42, 272.4
    else:
        raise ValueError("(get_u_from_rh) phase must be 'liquid', 'liq' or 'ice', got: %r" % (phase,))

    # asanyarray keeps array subclasses untouched while accepting scalars
    p    = np.asanyarray(p)
    t    = np.asanyarray(t)
    uair = np.asanyarray(uair)

    t_celsius = t - 273.15
    pps = 6.1078 * np.exp( coef * t_celsius / ( t_celsius + t_offset ) ) # saturation vapour pressure (mb or hPa)
    ppv = rh * pps / 100.0                                               # water vapour partial pressure
    u   = uair * ppv / p
    return u


def get_t_from_rh_u(p,rh,uair,uh2o,phase):
    """
    Temperature (K) for which the given water vapour amount yields -rh-.

    This inverts the saturation vapour pressure formula used by get_rh.

    p     : pressure, in the same unit as the saturation pressure (mb or hPa)
    rh    : relative humidity (%)
    uair  : amount of air
    uh2o  : amount of water vapour, same unit as uair (only the ratio
            uh2o / uair is used)
    phase : "liquid" (or "liq") or "ice", selects the coefficients of the
            saturation vapour pressure formula

    Scalars as well as arrays are accepted.

    Raises ValueError if -phase- is not one of the supported values.
    """
    if phase in ["liquid","liq"]:
        coef, t_offset = 17.27, 237.3
    elif phase == "ice":
        coef, t_offset = 22.42, 272.4
    else:
        raise ValueError("(get_t_from_rh_u) phase must be 'liquid', 'liq' or 'ice', got: %r" % (phase,))

    # asanyarray keeps array subclasses untouched while accepting scalars
    p    = np.asanyarray(p)
    uair = np.asanyarray(uair)
    uh2o = np.asanyarray(uh2o)

    ppv = uh2o / uair * p   # water vapour partial pressure
    pps = ppv / rh * 100.   # saturation vapour pressure giving -rh-
    log_ratio = np.log(pps/6.1078)
    t   = ( log_ratio*t_offset) / (coef - log_ratio) + 273.15
    return t
    
##############################################


def moist_to_dry_ppmv(ppmv,h2o_ppmv_moist) :
    """
    Convert a concentration in ppmv over moist air into ppmv over dry air.

    h2o_ppmv_moist is the water vapour concentration in ppmv over moist air.
    """
    # fraction of the moist air that is not water vapour
    dry_air_fraction = 1.0 - 1e-6*h2o_ppmv_moist
    return ppmv / dry_air_fraction

def dry_to_moist_ppmv(ppmv,h2o_ppmv_moist) :
    """
    Convert a concentration in ppmv over dry air into ppmv over moist air.

    h2o_ppmv_moist is the water vapour concentration in ppmv over moist air.
    """
    # fraction of the moist air that is not water vapour
    dry_air_fraction = 1.0 - 1e-6*h2o_ppmv_moist
    return ppmv * dry_air_fraction


def get_hum(h2o_, h2o_moist=False, hum_moist=True):
    """
    Convert a water vapour concentration in ppmv into a mass ratio in kg/kg.

    h2o_      : water vapour concentration (ppmv)
    h2o_moist : True if h2o_ is given over moist air, False (default) if it
                is given over dry air
    hum_moist : True (default) to return kg/kg over moist air,
                False to return kg/kg over dry air

    The mass mixing ratio is computed directly rather than through its
    inverse, so a null concentration returns 0 instead of dividing by zero.
    """
    if h2o_moist:
        # ppmv over moist air -> ppmv over dry air
        h2o = h2o_ / (1.0 - 1e-6*h2o_)
    else:
        h2o = h2o_
    # mass of water vapour per mass of dry air (kg/kg)
    mixing_ratio = h2o * mh2o / (1e6 * mair)
    if hum_moist:
        # mass of water vapour per mass of moist air (kg/kg)
        return mixing_ratio / (1.0 + mixing_ratio)
    return mixing_ratio


def get_ozo(o3_, hum, o3_moist=False, ozo_moist=True):
    """
    Convert an ozone concentration in ppmv into a mass ratio in kg/kg.

    o3_       : ozone concentration (ppmv)
    hum       : water vapour mass ratio, kg/kg over moist air
    o3_moist  : True if o3_ is given over moist air, False (default) if it
                is given over dry air
    ozo_moist : True (default) to return kg/kg over moist air,
                False to return kg/kg over dry air
    """
    # bring the input to ppmv over dry air
    o3_dry = o3_ / (1.0 - 1e-6 * get_h2o_ppmv_moist(hum) ) if o3_moist else o3_
    ozo_over_moist = o3_dry / 1e6 * (1.0-hum) * mo3 / mair
    if ozo_moist:
        return ozo_over_moist
    return ozo_over_moist / ( 1.0 - hum)
    

def get_h2o_ppmv_moist(hum, dry=False):
    """
    Convert a water vapour mass ratio into a concentration in ppmv.

    hum : kg/kg over moist air
    dry : False (default) to return ppmv over moist air,
          True to return ppmv over dry air
    """
    over_dry_air = 1e6 *  hum / ( 1.0 - hum)  * mair / mh2o      # ppmv over dry air
    over_moist_air = over_dry_air / (1.0 + 1e-6 * over_dry_air)  # ppmv over moist air
    return over_dry_air if dry else over_moist_air
 

def get_o3_ppmv_moist(ozo, hum, h2o_ppmv_moist, dry=False):
    """
    Convert an ozone mass ratio into a concentration in ppmv.

    ozo            : ozone mass ratio, kg/kg over moist air
    hum            : water vapour mass ratio, used as 1/(1-hum) to scale
                     ozo to dry air
    h2o_ppmv_moist : water vapour concentration, ppmv over moist air
    dry            : False (default) to return ppmv over moist air,
                     True to return ppmv over dry air
    """
    over_dry_air = 1e6 *  ozo / ( 1.0 - hum)  * mair / mo3            # ppmv over dry air
    over_moist_air = over_dry_air  * (1.0 - 1e-6 * h2o_ppmv_moist )   # ppmv over moist air
    return over_dry_air if dry else over_moist_air

##############################################

def calc_hgpl(lat, temp, pap, psurf, elevation, q):
    """
    Compute the height of the pressure levels and the density of moist air.

    The hydrostatic equation is integrated from the surface upward: the
    height of the last level of the profile is obtained from
    (psurf, elevation), then the loop walks from index len(pap)-2 down to 0,
    each level being derived from the one of index ilev+1.

    lat       : latitude; it is multiplied by pi/180 before being passed to
                sin/cos, so it is taken in degrees
    temp      : temperature profile, indexed like pap (presumably K -- confirm)
    pap       : pressure profile of moist air; its last element is the level
                adjacent to the surface
                (presumably Pa, since dmair/100. is returned as kg/m3 -- confirm)
    psurf     : surface pressure, same unit as pap (scalar)
    elevation : surface elevation (scalar); it is combined with rlh in the
                same way as 1.e3*hgpl, so presumably in m -- confirm
    q         : water vapour concentration in ppmv over moist air (unused)

    Returns (hgpl*1.0e3, dmair/100.) : the level heights and the moist air
    density, labelled m and kg/m3 respectively.

    NOTE(review): when two successive densities are equal (e.g. psurf equal
    to pap[-1]) the layer integral divides by zero; with NumPy floats this
    looks like it yields nan, in which case the test ztemp > 0.0 is False and
    the height falls back to rlh*1.e-3 -- confirm whether this case can occur.
    """
    
    # Note :  pap is a pressure of moist air    
    #         q is a concentration of water vapour in ppmv over moist air
    #           (It is unused in the present version of the routine 
    #             because we compute dmair of moist air)        

    # ------- Compute height of pressure levels:levels above surface. ------------
    #        The height of pressure levels H is obtained by integrating             |
    #        the hydrostatic equation dPRES=-GRAVL(H)*DMAIR*dH between              |
    #        two adjacent pressure levels                                           |
    #                                                                               |
    #            -P2         - H2                                                   |
    #           |           |                                                       |
    #           | dP/DMAIR= | GRAVL(H)*dH                                           |
    #           |           |                                                       |
    #          -  P1       -   H1                                                   |
    #                                                                               |
    #        The integration of 1/DMAIR is carried out assuming DMAIR               |
    #        varies linearly with pressure. The integral is then                    |
    #        computed analytically.                                                 |
    #                                                                               |
    #        The value of the gravity as a function of altitude H can be            |
    #        expressed using the inverse-square law of gravitation:                 |
    #                                                                               |
    #        GRAVL(H)=GRAVL*(REARTH/(H+REARTH))**2=                                 |
    #                 GRAVL*(1-2*H/REARTH+3*H**2/REARTH**2+terms of higher order)   |
    #                                                                               |
    #        If we eliminate the second and higher order terms we can write:        |
    #                                                                               |
    #        GRAVL(H)=GRAVL-2*GRAVL*H/REARTH=GRAVL-GRAVH*H                          |
    #                                                                               |
    #        Note that RLH = GRAVL / GRAVH in the equations below                   |
    # ------------------------------------------------------------------------------

    # Calculate the earth's radius at latitude lat assuming the
    # earth is an ellipsoid of revolution
    dflat  = (1.0 - flatt) ** 2.0
    fac    = (omega ** 2.0 * (eqrad * 1000.)) / (grave)
    beta   = 5. * fac / 2. - flatt - 17. * fac * flatt / 14.
    eta    = flatt * (5. * fac - flatt) / 8.
    rearth = np.sqrt( (eqrad ** 2.0 * dflat) / ( np.sin(lat*np.pi/180.0) ** 2.0 + dflat * np.cos(lat*np.pi/180.0)** 2.0)) # km
    
    # rearth (km) * 5.e2 is half of the earth's radius expressed in m
    rlh = rearth * 5.e2 # DAR: used to be gravl / gravh

    # The value of earth's gravity at surface at latitude lat is
    # computed using the international gravity formula.

    gravl = grave * (1.0 + beta * np.sin(lat*np.pi/180.0) ** 2.0 + eta * (2.0 * np.sin(lat*np.pi/180.0) * np.cos(lat*np.pi/180.0)) ** 2.0)
    
    gravh_r = rearth / (2.e-3 * gravl) # computationally useful form

    int_dmair = np.zeros_like(pap)  # Integrated layer values for dmair (hPa/(kg/m3))
    dmair     = np.zeros_like(pap)  # Density of moist air kg/m3
    ppw       = np.zeros_like(pap)  # Partial pressure of water vapour (hPa); unused in this version
    ztemp     = np.zeros_like(pap)  # placeholder, rebound to a scalar below
    hgpl      = np.zeros_like(pap)  # Level geopotential height (km)

    c    = 1000. * rgc / (100. * mair)

    # Calculate partial pressure of water vapour (use ppmv wet)
    #ppw[:]   = pap[:] * q[:] * 1e-6
    #dmair[:] = (pap[:] - ppw[:] * scale_eps) / (c * temp[:])
    # Density from the perfect gas law applied to the moist air pressure
    dmair[:] = pap[:] / (c * temp[:])
    
    #ppw_surf   = psurf* q[len(pap)-1] * 1e-6
    #dmair_surf = (psurf - ppw_surf * scale_eps) / (c * temp[len(pap)-1]) 
    # Surface density: surface pressure with the temperature of the last level
    dmair_surf = psurf / (c * temp[len(pap)-1])
    
    # Pressure difference between each level and the one below it
    # (the surface for the last level)
    dp = np.zeros_like(pap)
    for ilev in range(len(pap)-1):
        dp[ilev] =  pap[ilev] - pap[ilev+1]
    dp[len(pap)-1] = pap[len(pap)-1] - psurf

    # Altitude of the first level (the last index, adjacent to the surface),
    # integrated from the surface elevation
    ilev = len(pap)-1    
    int_dmair[ilev] = -dp[ilev] / (dmair_surf - dmair[ilev] ) * np.log(dmair_surf / dmair[ilev])
    ztemp = rlh**2.0 - elevation * (2. * rlh - elevation ) - 2. * int_dmair[ilev] * 1.e2 * gravh_r
    if ztemp > 0.0:
        hgpl[ilev] = (rlh - np.sqrt(ztemp))* 1.e-3
    else:
        # no real root: the height is set to rlh (converted to km)
        ztemp = 1e-100
        hgpl[ilev] = (rlh )* 1.e-3
    
    # Remaining levels, from the bottom to the top of the profile;
    # each height is derived from the one of the level below (ilev+1)
    for ilev in range(len(pap)-2, -1, -1):        
        int_dmair[ilev] = -dp[ilev] / (dmair[ilev+1] - dmair[ilev] ) * np.log(dmair[ilev+1] / dmair[ilev])
        ztemp = rlh**2.0 - 1.e3 * hgpl[ilev+1] * (2. * rlh - 1.e3 * hgpl[ilev+1] ) - 2. * int_dmair[ilev] * 1.e2 * gravh_r
        if ztemp > 0.0:
            hgpl[ilev] = (rlh - np.sqrt(ztemp))* 1.e-3
        else:
            # no real root: the height is set to rlh (converted to km)
            ztemp = 1e-100
            hgpl[ilev] = (rlh )* 1.e-3

    return hgpl*1.0e3, dmair/100. # m, kg/m3


##############################################################################################################

class solar_irradiance(object):
    '''
    Solar irradiance per channel for the k-distribution ("kdis") mode.

    Nothing but -solar_spec- is set unless artdeco_in.mode is 'kdis' and a
    -file_path- different from "none" is given. Otherwise the following
    attributes are loaded from the file:

        kdis_solrad          : True once the data have been read
        kdis_solrad_solcste  : first value of the ascii file (-1 for "h5_kdis")
        kdis_solrad_nwvl     : number of channels
        kdis_solrad_channels : channel numbers, starting at 1
        kdis_solrad_wvl      : wavelength of each channel
        kdis_solrad_F0       : solar irradiance of each channel
    '''
    def __init__(self, artdeco_in, file_path="none", file_format="ascii", channel_list=None, solar_spec="none"):
        '''
        artdeco_in   : object whose -mode- attribute is read
        file_path    : path of the solar irradiance file ("none" to load nothing)
        file_format  : "ascii" or "h5_kdis"
        channel_list : channel numbers (starting at 1) to keep, as a list or
                       an array; None or an empty sequence keeps all of them
        solar_spec   : stored in self.solar_spec (overwritten by the
                       "solspec" attribute of the file for "h5_kdis")

        Raises ValueError if -file_format- is unknown.
        Exits if the ascii file is missing or if the wavelengths are not
        sorted in increasing order.
        '''

        #print "Set solar_irradiance class..."
        self.solar_spec =solar_spec

        # nothing to load outside of the kdis mode or without a file
        if (artdeco_in.mode != 'kdis') or (file_path == 'none'):
            return

        self.kdis_solrad = False   # init value

        if file_format == 'ascii':
            self._read_ascii(file_path)
        elif file_format == "h5_kdis":
            self._read_h5_kdis(file_path)
        else:
            raise ValueError("(solar_irradiance) unknown file_format: %r" % (file_format,))

        self.kdis_solrad = True

        if channel_list is not None and len(channel_list) > 0:
            # channel numbers start at 1: convert them to array indices.
            # asarray makes plain lists usable as well as arrays.
            self._keep_channels(np.asarray(channel_list, dtype=int) - 1)

        wvl = self.kdis_solrad_wvl
        if not np.all(wvl[:-1] <= wvl[1:]):
            print(" solrad ERROR")
            print("          wavelengths must be sorted in increasing order")
            exit()


    def _read_ascii(self, file_path):
        # Read the solar constant, the number of channels and the
        # (wavelength, irradiance) table from an ascii file.

        # check if the kdis solrad exists already
        if not os.path.isfile(file_path):
            print("(solar_spectrum) ERROR")
            print("                 Missing file:", file_path) 
            exit()

        # the context manager closes the file even if the parsing fails
        with open(file_path,'r') as f:
            skipcomment(f)
            self.kdis_solrad_solcste = float(f.readline().split()[0])
            skipcomment(f)
            self.kdis_solrad_nwvl = int(f.readline().split()[0])
            self.kdis_solrad_wvl = np.zeros(self.kdis_solrad_nwvl)
            self.kdis_solrad_F0  = np.zeros(self.kdis_solrad_nwvl)
            skipcomment(f)
            for iwvl in range(self.kdis_solrad_nwvl):
                fields = f.readline().split()
                self.kdis_solrad_wvl[iwvl] = float(fields[0])
                self.kdis_solrad_F0[iwvl]  = float(fields[1])

        self.kdis_solrad_channels = np.arange(self.kdis_solrad_nwvl)+1


    def _read_h5_kdis(self, file_path):
        # Read the central wavelengths, the irradiances and the name of the
        # solar spectrum from an HDF5 k-distribution file.
        with h5py.File(file_path, "r") as ff:
            self.kdis_solrad_wvl = np.copy( ff["def"]["central_wvl"] )
            self.kdis_solrad_F0  = np.copy( ff["solrad"]["solrad"]   )
            self.solar_spec = ff["solrad"]["solrad"].attrs["solspec"]
        self.kdis_solrad_nwvl = len(self.kdis_solrad_wvl)
        self.kdis_solrad_channels = np.arange(self.kdis_solrad_nwvl)+1
        self.kdis_solrad_solcste = -1


    def _keep_channels(self, ind_chan):
        # Restrict the per-channel arrays to the entries of index -ind_chan-.
        self.kdis_solrad_nwvl     = len(ind_chan)
        self.kdis_solrad_channels = self.kdis_solrad_channels[ind_chan]
        self.kdis_solrad_wvl      = self.kdis_solrad_wvl[ind_chan]
        self.kdis_solrad_F0       = self.kdis_solrad_F0[ind_chan]


    def select_chan_solrad(self, list_chan):
        '''
        Keep only the channels whose number is listed in -list_chan-.

        Nothing is done if -1 is found in -list_chan-.

        Raises ValueError if a requested channel is not available.
        '''
        if -1 in list_chan:
            return
        ind_chan = []
        for chan in list_chan:
            matches = np.argwhere(chan == self.kdis_solrad_channels)
            if matches.size == 0:
                raise ValueError("(solar_irradiance) unknown channel: %r" % (chan,))
            ind_chan.append(matches.flat[0])
        self._keep_channels(np.array(ind_chan, dtype=int))
        return 



########################################################################################

class ptcle_optical_properties(object):
    '''
    Optical properties of a set of particle types read from HDF5 libraries.

    For each distinct particle type the following per-type lists are filled
    (one entry per type, -1 when not available):

        trunc_method           : truncation method of the library (if read_betal)
        cext_reflamb           : "Cext" at the reference wavelength
        Sext_o_wc_reflamb      : "extinction_coefficient_over_wc" at wlref
        Sext_o_dry_ac_reflamb  : "extinction_coefficient_over_dry_ac" at wlref
        Sext_o_rh80_ac_reflamb : "extinction_coefficient_over_rh80_ac" at wlref
        opt                    : datasets restricted to the wavelength interval
                                 (and to -nbetal- if read_betal)
    '''

    def __init__(self, ptcle_type_list, nbetal, wl, wlref, ind_ref_ang=0, opt_interp=False, verbose=False, read_betal=True):
        
        '''        
        ptcle_type_list : list of dictionaries
                          e.g. ptcle = [{"name":, "file_path":}
                                        {"name":, "file_path":}]
                          If it is empty, nothing is loaded and no attribute
                          is set.

        nbetal : value that must be sampled on the "nbetal" axis of the
                 library (only used if read_betal)

        wl  : is a wavelength interval
              all optical properties found in the given file 
              will be loaded for wavelengths falling in this interval
              e.g [0.70, 0.80]        

        wlref : the reference wavelength for the opacity

        ind_ref_ang : index of the particle type whose angular grid ("mu")
                      is used as the common grid for all the types

        opt_interp : whether an interpolation is allowed when wlref is not
                     sampled in the library

        verbose : print information on what is read

        read_betal : whether the datasets related to the truncated phase
                     matrix expansion are read
        '''

        if len(ptcle_type_list) == 0:
            return
        
        if len(wl) != 2:
            print("(optical_properties) ERROR")
            print("                     wl is a wavelength interval ")
            print("                     It must be [wvl_min, wvl_max]") 
            exit()

        # distinct particle types; the same name must come with the same file
        self.type_name = []
        self.file_path = []
        for ptcle_type in ptcle_type_list:
            name = ptcle_type['name']
            if name not in self.type_name:
                self.type_name.append(name)
                self.file_path.append(ptcle_type['file_path'])
            elif self.file_path[self.type_name.index(name)] != ptcle_type['file_path'] :
                print("(optical_properties) ERROR")
                print("                     Same ptcle_name with different opt_path ")
                print("                     ") 
                exit()
               

        self.ntype = len( self.type_name )
        self.wlref = wlref
        self.wl    = np.sort(wl)
        self.nbetal     = nbetal
        self.opt_interp = opt_interp
        self.trunc_method = [-1]*self.ntype 
        self.cext_reflamb = [-1]*self.ntype 
        self.opt          = [-1]*self.ntype 
        self.read_betal   = read_betal
        self.Sext_o_wc_reflamb      = [-1]*self.ntype 
        self.Sext_o_dry_ac_reflamb  = [-1]*self.ntype 
        self.Sext_o_rh80_ac_reflamb = [-1]*self.ntype 

        
        for i in range(self.ntype):
            self._load_type(i, verbose)
                                
        # We need to interpolate the phase matrix onto the same angular grid
        # we use the grid for the particle -ind_ref_ang-
        for i in range(self.ntype):
            if i == ind_ref_ang:
                continue
            if not np.array_equal(self.opt[ind_ref_ang].axes["mu"], self.opt[i].axes["mu"]):
                if verbose:
                    print ("\n\n(optical_propertie) angular interpolation for ",self.type_name[i], "\n")
                self.opt[i] = self.opt[i].sub({"mu":Idx(self.opt[ind_ref_ang].axes["mu"])}) 


    @staticmethod
    def _as_text(value):
        # HDF5 string attributes may be returned as bytes or as str
        # depending on how they were stored and on the h5py version.
        if isinstance(value, bytes):
            return value.decode()
        return value


    def _list_datasets(self, i):
        # Return the names of the datasets to read for the particle type -i-
        # and record its truncation method if the betal are read.
        name     = self.type_name[i]
        filename = self.file_path[i]

        ls_datasets = ["Cext",
                       "single_scattering_albedo",
                       "p11_phase_function",
                       "p44_phase_function",
                       "p21_phase_function",
                       "p34_phase_function"]

        # the context manager closes the file on every exit path
        with h5py.File(filename,"r") as ff:

            if name not in ff.keys():
                print("")
                print("(optical_propertie) ERROR")
                print("           Missing particle ", name, " in file ", filename)
                exit()

            data = ff[name]["data"]

            if self.read_betal:

                truncation = self._as_text(data["alpha1_betal"].attrs.get('truncation'))
                if truncation != "dm":
                    print("")
                    print("(optical_propertie) ERROR")
                    print("           pyARTDECO only suited for DeltaM truncated optical prop. library")
                    print("")
                    exit()
                self.trunc_method[i] = truncation

                ls_datasets = ls_datasets + ["alpha1_betal",
                                             "alpha2_betal",
                                             "alpha3_betal",
                                             "alpha4_betal",
                                             "beta1_betal",
                                             "beta2_betal",
                                             "recomposed_p11_phase_function",
                                             "recomposed_p22_phase_function",
                                             "recomposed_p33_phase_function",
                                             "recomposed_p44_phase_function",
                                             "recomposed_p21_phase_function",
                                             "recomposed_p34_phase_function",
                                             "truncation_coefficient",
                                             "integrate_p11_recomp"]

            # optional datasets, read only if present in the file
            for optional in ["extinction_coefficient_over_wc",
                             "extinction_coefficient_over_dry_ac",
                             "extinction_coefficient_over_rh80_ac"]:
                if optional in data.keys():
                    ls_datasets.append(optional)

            if "p22_phase_function" in data.keys():
                ls_datasets.append("p22_phase_function")
                ls_datasets.append("p33_phase_function")

        return ls_datasets


    def _load_type(self, i, verbose):
        # Read the optical properties of the particle type -i- and store
        # them in the per-type lists.
        name     = self.type_name[i]
        filename = self.file_path[i]

        if verbose:
            print("\n\n(optical_propertie) read opt for ", name, "\n")

        if not (filename.split(".")[-1] in ["h5", "hdf5"]):
            print("(optical_propertie) format must be HDF5")
            exit()

        if not os.path.isfile(filename):
            print("")
            print("(optical_propertie) ERROR")
            print("            Missing file:", filename)
            exit()

        ls_datasets = self._list_datasets(i)

        m = read_mlut_hdf5(filename, ls_datasets, group=name, lazy=True)

        wvl_axis = m.axes["wavelengths"]

        # the axis is unsorted as soon as one element differs from the
        # sorted version
        if (wvl_axis != np.sort(wvl_axis)).any():
            print("")
            print("(optical_propertie) ERROR")
            print("           Wavelength axe should be sorted in increasing order in file ", filename)
            exit()

        # Indices enclosing the requested interval. The sorted bounds
        # (self.wl) are used so that the order given by the caller
        # does not matter.
        iwvl1 = find_nearest(wvl_axis, self.wl[0] )
        iwvl2 = find_nearest(wvl_axis, self.wl[1] )
        if wvl_axis[iwvl1] > self.wl[0] and iwvl1!=0:
            iwvl1 = iwvl1-1
        if wvl_axis[iwvl2] < self.wl[1] and iwvl2!=len(wvl_axis)-1:
            iwvl2 = iwvl2+1

        sub_lut_dic = {"wavelengths":slice(iwvl1,iwvl2+1) }

        wlref_sampled = self.wlref in wvl_axis
        if (not self.opt_interp) and not wlref_sampled:
            print("")
            print("(optical_propertie) ERROR")
            print("          Required reference wavelength value is not sampled in ", filename)
            print("          and no interpolation allowed ")
            exit()
        elif wlref_sampled:
            sub_reflamb_dic = {"wavelengths":Idx(self.wlref, round=True)}
        else:
            sub_reflamb_dic = {"wavelengths":Idx(self.wlref)}

        if self.read_betal:
            if not self.nbetal in m.axes["nbetal"]:
                print("")
                print("(optical_propertie) ERROR")
                print("          Required -nbetal- value is not sampled in ", filename)
                print("          ")
                exit() 
            else:
                sub_lut_dic.update( {"nbetal" : Idx(self.nbetal, round=True)} )

        self.cext_reflamb[i] = m["Cext"].sub(sub_reflamb_dic)

        if verbose:
            print("\n")
            self.cext_reflamb[i].print_info()

        if "extinction_coefficient_over_wc" in  ls_datasets:
            self.Sext_o_wc_reflamb[i] = m["extinction_coefficient_over_wc"].sub(sub_reflamb_dic)
        if "extinction_coefficient_over_dry_ac" in  ls_datasets:
            self.Sext_o_dry_ac_reflamb[i] = m["extinction_coefficient_over_dry_ac"].sub(sub_reflamb_dic)
        if "extinction_coefficient_over_rh80_ac" in  ls_datasets:
            self.Sext_o_rh80_ac_reflamb[i] = m["extinction_coefficient_over_rh80_ac"].sub(sub_reflamb_dic)

        self.opt[i] = m.sub(sub_lut_dic)

        if verbose:
            print("\n")
            self.opt[i].print_info()
                
        

########################################################################################

class particle(object):
    
    def __init__(self, artdeco_in, ptcle, optical_prop):

        '''        

        wlref is the reference wavelength for the opacity

        ptcle is a list of dictionnary. One dictionnary per particle to use.
              e.g. [ {"file":string, "name":string, "tau":float, "alt_distrib":string, "z":float, "dz":float, "user_vdist":np.array, "user_alt":np.array, "reff":float, "veff":float, "humidity":float} ]
              If alt_distrib is "layer"         z (z top) must be given 
                             is  "homogeneous"  z (z top) and dz must be given
                             is  "gauss"        z (z middle) and dz must be given
                             is  "sh"           z (scale height) must be given
                             is  "user"         "user_vdist" and "user_alt" must be given

        optical_prop :  optical properties class

        '''

        #print "Set particle class..."

        if len(ptcle) == 0:
            #print "(particle) no particle"
            self.nptcle = 0
            return        
            
        self.nptcle = len(ptcle)
        
        self.type   = [x["name"] for x in ptcle] 

        self.tau_reflamb = np.array([x["tau"] for x in ptcle])

        # HG approx not implemented
        self.hg   = [False]*self.nptcle
        # whether wavelength interpolation will be allowed 
        self.opt_interp = optical_prop.opt_interp
        self.wlref      = optical_prop.wlref
        
        #######################
        # vertical distribution
        self.vdist_type     = [x["alt_distrib"] for x in ptcle] 
        self.h_vdist        = np.zeros_like( self.tau_reflamb)
        self.sd_gauss_vdist = np.zeros_like( self.tau_reflamb)
        
        for i in range(self.nptcle):

            if self.vdist_type[i] not in ['layer','homogeneous','sh','gauss','user']: 
                print("(particle) ERROR")
                print("           -alt_distrib- must be layer, homogeneous, sh, gauss or user")
                print("           required:",  self.vdist_type[i])
                print("")
                exit()

            if self.vdist_type[i] in ['layer','homogeneous','sh','gauss']: 
                self.h_vdist[i] = ptcle[i]["z"]
            else:
                self.h_vdist[i] = -1.0
            if self.vdist_type[i] in ['homogeneous','gauss']: 
                if not "dz" in ptcle[i].keys():
                    print("")
                    print("(particle) ERROR")
                    print("           dz must be provided for -homogeneous- and -gauss- vertical distribution")
                    print("           for particle ",  ptcle[i]["name"])
                    print("")
                    exit()
                self.sd_gauss_vdist[i] = ptcle[i]["dz"]
            else:
                self.sd_gauss_vdist[i] = -1.0

        if 'user' in self.vdist_type:
            naltmax = 0
            for x in ptcle:
                if x["alt_distrib"] == 'user':
                    if len(x["user_alt"]) > naltmax: 
                        naltmax = len(x["user_alt"])
            self.user_vdist_nalt = np.zeros(self.nptcle, dtype='int')
            self.user_vdist_alt = np.zeros((self.nptcle, naltmax))
            self.user_vdist     = np.zeros((self.nptcle, naltmax))
            for i in range(self.nptcle):
                if ptcle[i]["alt_distrib"] == 'user':
                    self.user_vdist_nalt[i]  = len(ptcle[i]["user_alt"]) 
                    self.user_vdist_alt[i,:] = ptcle[i]["user_alt"] 
                    self.user_vdist[i,:]     = ptcle[i]["user_vdist"] 

        
        ####################            
        # optical properties   

        #print(optical_prop.__dict__)
        
        self.optical_properties = []

        for i in range(self.nptcle):

            if self.type[i] not in optical_prop.type_name:
                print("")
                print("(particle) ERROR")
                print("            Missing optical properties for :", self.type[i] )
                exit()
            
            ii = np.squeeze(np.argwhere(self.type[i] == np.array(optical_prop.type_name)))
                        
            opt_wvl    = optical_prop.opt[ii].axes["wavelengths"]
            nopt_wvl   = len(opt_wvl)
            u_phasemat    = np.transpose(np.array([ optical_prop.opt[ii].axes["mu"] for iii in range(nopt_wvl) ]))
            nang_phasemat = np.full(nopt_wvl, len(optical_prop.opt[ii].axes["mu"]), dtype=int)
            opt           = np.zeros((3, nopt_wvl))
            phasemat      = np.zeros((6, np.max(nang_phasemat), nopt_wvl))

            sub_lut_dic     = {}
            sub_reflamb_dic = {}
            
            prop_list = ["humidity", "reff", "veff", "temp", "logiwc"]
            for prop in prop_list:
                if prop in optical_prop.opt[ii]["p11_phase_function"].names:
                    if not prop in ptcle[i].keys():
                        print("")
                        print("(particle) ERROR")
                        print("          ",prop," should be provided in the description (dictionnary) for ", ptcle[i]["name"]) 
                        exit()
                    if (not self.opt_interp) and not (ptcle[i][prop] in optical_prop.opt[ii].axes[prop]):
                        print("")
                        print("(particle) ERROR")
                        print("          Required -",prop,"- value is not sampled")
                        print("          and no interpolation allowed ")
                        print(ptcle[i][prop])
                        print(optical_prop.opt[ii].axes[prop])
                        exit()                        
                    elif (ptcle[i][prop] in optical_prop.opt[ii].axes[prop]):                        
                        # Idx will round and return an int rather
                        # than an float.  There will then be no interpolation  (more efficient)
                        sub_lut_dic.update( { prop: Idx(ptcle[i][prop], round=True) } )
                        sub_reflamb_dic.update( { prop: Idx(ptcle[i][prop], round=True) } )
                    else:
                        sub_lut_dic.update( { prop: Idx(ptcle[i][prop]) } )
                        sub_reflamb_dic.update( { prop: Idx(ptcle[i][prop]) } )

                    

            datasets = [x[0] for x in optical_prop.opt[ii].data]


            if len(sub_reflamb_dic.keys()) > 0:
                cext_reflamb = optical_prop.cext_reflamb[ii].sub(sub_reflamb_dic).data
            else:
                cext_reflamb = optical_prop.cext_reflamb[ii].data

            if len(sub_lut_dic.keys()) > 0:
                opt[0,:] = optical_prop.opt[ii]["Cext"].sub(sub_lut_dic).data
                opt[1,:] = optical_prop.opt[ii]["single_scattering_albedo"].sub(sub_lut_dic).data
                opt[2,:] = -32768
                phasemat[0, :, :] = optical_prop.opt[ii]["p11_phase_function"].sub(sub_lut_dic).data  # F11
                phasemat[3, :, :] = optical_prop.opt[ii]["p44_phase_function"].sub(sub_lut_dic).data  # F44
                phasemat[4, :, :] = optical_prop.opt[ii]["p21_phase_function"].sub(sub_lut_dic).data  # F21
                phasemat[5, :, :] = optical_prop.opt[ii]["p34_phase_function"].sub(sub_lut_dic).data  # F34
                if "p22_phase_function" in datasets:
                    nelem = 6
                    phasemat[1, :, :] =  optical_prop.opt[ii]["p22_phase_function"].sub(sub_lut_dic).data # F22
                    phasemat[2, :, :] =  optical_prop.opt[ii]["p33_phase_function"].sub(sub_lut_dic).data # F33
                else:
                    nelem = 4
                    phasemat[1, :, :] = phasemat[0, :, :] # F22
                    phasemat[2, :, :] = phasemat[3, :, :] # F33
            else:
                opt[0,:] = optical_prop.opt[ii]["Cext"].data
                opt[1,:] = optical_prop.opt[ii]["single_scattering_albedo"].data
                opt[2,:] = -32768
                phasemat[0, :, :] = optical_prop.opt[ii]["p11_phase_function"].data  # F11
                phasemat[3, :, :] = optical_prop.opt[ii]["p44_phase_function"].data  # F44
                phasemat[4, :, :] = optical_prop.opt[ii]["p21_phase_function"].data  # F21
                phasemat[5, :, :] = optical_prop.opt[ii]["p34_phase_function"].data  # F34
                if "p22_phase_function" in datasets:
                    nelem = 6
                    phasemat[1, :, :] =  optical_prop.opt[ii]["p22_phase_function"].data # F22
                    phasemat[2, :, :] =  optical_prop.opt[ii]["p33_phase_function"].data # F33
                else:
                    nelem = 4
                    phasemat[1, :, :] = phasemat[0, :, :] # F22
                    phasemat[2, :, :] = phasemat[3, :, :] # F33



            if optical_prop.read_betal:
                
                if (optical_prop.trunc_method[ii]!="dm") or (artdeco_in.trunc_method != "dm"):                    
                    print("")
                    print("(particle) ERROR")
                    print("           pyARTDECO only suited for DeltaM truncated optical prop. library")
                    print("")
                    exit()

                trunc_method = optical_prop.trunc_method[ii]

                if artdeco_in.nstreams != optical_prop.nbetal:
                    print("")
                    print("(particle) ERROR")
                    print("          Required -nbetal- value is not sampled in optical properties")
                    print("          ")
                    exit() 

                betal          = np.zeros((6, artdeco_in.nstreams+1, nopt_wvl))
                trunc_coeff    = np.zeros(nopt_wvl)
                trunc_normphasemat = np.zeros(nopt_wvl)
                trunc_phasemat = np.zeros((6, np.max(nang_phasemat), nopt_wvl)) 


                if len(sub_lut_dic.keys()) > 0:
                    
                    betal[0,:,:] = optical_prop.opt[ii]["alpha1_betal"].sub(sub_lut_dic).data[0:artdeco_in.nstreams+1,:]
                    betal[1,:,:] = optical_prop.opt[ii]["alpha2_betal"].sub(sub_lut_dic).data[0:artdeco_in.nstreams+1,:]
                    betal[2,:,:] = optical_prop.opt[ii]["alpha3_betal"].sub(sub_lut_dic).data[0:artdeco_in.nstreams+1,:]
                    betal[3,:,:] = optical_prop.opt[ii]["alpha4_betal"].sub(sub_lut_dic).data[0:artdeco_in.nstreams+1,:]
                    betal[4,:,:] = optical_prop.opt[ii]["beta1_betal"].sub(sub_lut_dic).data[0:artdeco_in.nstreams+1,:]
                    betal[5,:,:] = optical_prop.opt[ii]["beta2_betal"].sub(sub_lut_dic).data[0:artdeco_in.nstreams+1,:]
                    trunc_phasemat[0, :, :] = optical_prop.opt[ii]["recomposed_p11_phase_function"].sub(sub_lut_dic).data  # F11
                    trunc_phasemat[1, :, :] = optical_prop.opt[ii]["recomposed_p22_phase_function"].sub(sub_lut_dic).data  # F22
                    trunc_phasemat[2, :, :] = optical_prop.opt[ii]["recomposed_p33_phase_function"].sub(sub_lut_dic).data  # F33
                    trunc_phasemat[3, :, :] = optical_prop.opt[ii]["recomposed_p44_phase_function"].sub(sub_lut_dic).data  # F44
                    trunc_phasemat[4, :, :] = optical_prop.opt[ii]["recomposed_p21_phase_function"].sub(sub_lut_dic).data  # F21
                    trunc_phasemat[5, :, :] = optical_prop.opt[ii]["recomposed_p34_phase_function"].sub(sub_lut_dic).data  # F34
                    trunc_coeff[:] = optical_prop.opt[ii]["truncation_coefficient"].sub(sub_lut_dic).data
                    trunc_normphasemat[:] = optical_prop.opt[ii]["integrate_p11_recomp"].sub(sub_lut_dic).data

                else:

                    betal[0,:,:] = optical_prop.opt[ii]["alpha1_betal"].data[0:artdeco_in.nstreams+1,:]
                    betal[1,:,:] = optical_prop.opt[ii]["alpha2_betal"].data[0:artdeco_in.nstreams+1,:]
                    betal[2,:,:] = optical_prop.opt[ii]["alpha3_betal"].data[0:artdeco_in.nstreams+1,:]
                    betal[3,:,:] = optical_prop.opt[ii]["alpha4_betal"].data[0:artdeco_in.nstreams+1,:]
                    betal[4,:,:] = optical_prop.opt[ii]["beta1_betal"].data[0:artdeco_in.nstreams+1,:]
                    betal[5,:,:] = optical_prop.opt[ii]["beta2_betal"].data[0:artdeco_in.nstreams+1,:]
                    trunc_phasemat[0, :, :] = optical_prop.opt[ii]["recomposed_p11_phase_function"].data  # F11
                    trunc_phasemat[1, :, :] = optical_prop.opt[ii]["recomposed_p22_phase_function"].data  # F22
                    trunc_phasemat[2, :, :] = optical_prop.opt[ii]["recomposed_p33_phase_function"].data  # F33
                    trunc_phasemat[3, :, :] = optical_prop.opt[ii]["recomposed_p44_phase_function"].data  # F44
                    trunc_phasemat[4, :, :] = optical_prop.opt[ii]["recomposed_p21_phase_function"].data  # F21
                    trunc_phasemat[5, :, :] = optical_prop.opt[ii]["recomposed_p34_phase_function"].data  # F34
                    trunc_coeff[:] = optical_prop.opt[ii]["truncation_coefficient"].data
                    trunc_normphasemat[:] = optical_prop.opt[ii]["integrate_p11_recomp"].data

                
                self.optical_properties.append({"nelem":nelem, 
                                                "cext_reflamb":cext_reflamb, 
                                                "opt_wvl":opt_wvl, 
                                                "nang_phasemat":nang_phasemat, 
                                                "opt":opt, 
                                                "u_phasemat":u_phasemat,
                                                "phasemat":phasemat,
                                                "betal_flag":True,
                                                "nbetal":artdeco_in.nstreams,
                                                "trunc_method":trunc_method,
                                                "betal":betal,
                                                "trunc_coeff":trunc_coeff,
                                                "trunc_phasemat":trunc_phasemat,
                                                "trunc_normphasemat":trunc_normphasemat})        

            else:


                self.optical_properties.append({"nelem":nelem, 
                                                "cext_reflamb":cext_reflamb, 
                                                "opt_wvl":opt_wvl, 
                                                "nang_phasemat":nang_phasemat, 
                                                "opt":opt, 
                                                "u_phasemat":u_phasemat,
                                                "phasemat":phasemat,
                                                "betal_flag":False})


                    
            self.nang_phasemat =  len(optical_prop.opt[ii].axes["mu"])
            self.u_phasemat    = optical_prop.opt[ii].axes["mu"]




    def get_opt(self, wl, nbetal, trunc_method, ascii_save=False, dir_ascii_save="/tmp/"):
        """Return the optical properties of every particle at the wavelengths -wl-.

        For each particle the stored tables (``self.optical_properties``) are
        wrapped in LUT objects and sampled along their "wvl" axis.  When a
        requested wavelength is one of the tabulated ones, a rounded index is
        used so that no interpolation takes place.

        Parameters
        ----------
        wl : sequence of wavelengths to sample (same unit as the stored
            "opt_wvl" axis; the ASCII header labels them as microns).
        nbetal : number of Legendre expansion coefficients requested.  If it is
            not > 0, the four expansion/truncation outputs are returned as None.
        trunc_method : truncation method requested; the stored expansion of a
            particle is used only if its own "trunc_method" and "nbetal" match.
        ascii_save : if True, write one ``opt_<type>.dat`` file per particle.
        dir_ascii_save : prefix prepended as-is to the ASCII file name (it must
            end with a path separator to designate a directory).

        Returns
        -------
        tuple
            ``(opt, u_phasemat, phasemat, betal, betal_flag, trunccoeff,
            trunc_phasemat, trunc_normphasemat)`` with

            * opt                (nptcle, 3, nwl)  columns written as Cext, SSA, g
            * u_phasemat         ``self.u_phasemat`` (returned as is)
            * phasemat           (nptcle, 6, nang, nwl) written as F11, F22, F33,
                                 F44, F21, F34
            * betal              (nptcle, 6, nbetal+1, nwl) or None
            * betal_flag         (nptcle, nwl) bool, True where the stored
                                 expansion was used
            * trunccoeff         (nptcle, nwl) or None
            * trunc_phasemat     (nptcle, 6, nang, nwl) or None
            * trunc_normphasemat (nptcle, nwl) or None

        The process is terminated through exit() if a wavelength is not
        tabulated while ``self.opt_interp`` is False.
        """

        nwl    = len(wl)
        nptcle = self.nptcle
        nang   = self.nang_phasemat

        phasemat   = np.zeros((nptcle, 6, nang, nwl))
        opt        = np.zeros((nptcle, 3, nwl))
        betal_flag = np.full((nptcle, nwl), False, dtype=bool)

        if nbetal > 0:
            betal              = np.zeros((nptcle, 6, nbetal+1, nwl))
            trunccoeff         = np.zeros((nptcle, nwl))
            trunc_phasemat     = np.zeros((nptcle, 6, nang, nwl))
            trunc_normphasemat = np.zeros((nptcle, nwl))
        else:
            betal              = None
            trunccoeff         = None
            trunc_phasemat     = None
            trunc_normphasemat = None

        for iptcle in range(nptcle):

            prop    = self.optical_properties[iptcle]
            opt_wvl = prop["opt_wvl"]

            lut_phasemat = LUT(prop["phasemat"],
                               axes =[np.arange(6.0), self.u_phasemat, opt_wvl],
                               names=["imat","mu","wvl"])

            lut_opt = LUT(prop["opt"],
                          axes =[np.arange(3.0), opt_wvl],
                          names=["iopt","wvl"])

            # The stored expansion is usable only if it was computed for the
            # requested truncation method and number of coefficients.
            # NB: the "and" chain must short-circuit, "trunc_method" and "nbetal"
            #     keys only exist when "betal_flag" is True.
            use_betal = (nbetal > 0) and prop["betal_flag"] \
                        and (trunc_method == prop["trunc_method"]) \
                        and (nbetal == prop["nbetal"])

            if use_betal:
                betal_flag[iptcle,:] = True

                lut_betal = LUT(prop["betal"],
                                axes =[np.arange(6.0), np.arange(prop["nbetal"]+1), opt_wvl],
                                names=["imat","l","wvl"])
                lut_trunccoeff = LUT(prop["trunc_coeff"],
                                     axes =[opt_wvl],
                                     names=["wvl"])
                lut_trunc_phasemat = LUT(prop["trunc_phasemat"],
                                         axes =[np.arange(6.0), self.u_phasemat, opt_wvl],
                                         names=["imat","mu","wvl"])
                lut_trunc_normphasemat = LUT(prop["trunc_normphasemat"],
                                             axes =[opt_wvl],
                                             names=["wvl"])

            for iwl in range(nwl):

                # membership is tested once and reused below
                tabulated = wl[iwl] in opt_wvl

                if (not self.opt_interp) and (not tabulated):
                    print(' (particle.get_it) ERROR')
                    print('                   No wavelength interpolation allowed in the class definition')
                    exit()

                if tabulated:
                    # rounded index : no interpolation needed
                    tmp = {"wvl" : Idx(wl[iwl], round=True)}
                else:
                    tmp = {"wvl" : Idx(wl[iwl])}

                opt[iptcle,:,iwl]        = lut_opt.sub(tmp).data
                phasemat[iptcle,:,:,iwl] = lut_phasemat.sub(tmp).data

                if betal_flag[iptcle,iwl]:
                    betal[iptcle,:,:,iwl]          = lut_betal.sub(tmp).data
                    trunccoeff[iptcle,iwl]         = lut_trunccoeff.sub(tmp).data
                    trunc_phasemat[iptcle,:,:,iwl] = lut_trunc_phasemat.sub(tmp).data
                    trunc_normphasemat[iptcle,iwl] = lut_trunc_normphasemat.sub(tmp).data

        if ascii_save:

            # one angle per line : u followed by the 6 phase matrix elements
            row_fmt = '   %.17e'*7 + ' \n'

            for iptcle in range(nptcle):

                # Write the opt_ file to be used for the betal expansion.
                # The context manager closes the file even if a write fails.
                with open(dir_ascii_save+'opt_'+self.type[iptcle]+'.dat', 'w') as fopt:
                    fopt.write('# Optical properties to be used by ARTDECO \n')
                    fopt.write('# \n')
                    fopt.write('# Used model to obtain that properties is: \n')
                    fopt.write(' unknown \n')
                    fopt.write('# Used material is: \n')
                    fopt.write(' unknown \n')
                    fopt.write('# Number of phase matrix elements : \n')
                    fopt.write(' 6 \n')
                    fopt.write('# number of wavelengths \n')
                    fopt.write('%i  \n' % nwl)
                    fopt.write('#  lambda(microns)   nteta   Cext (microns^2)       SSA              g  \n')
                    for iwvl in range(nwl):
                        fopt.write('   %.9e   %i   %.9e   %.9e   %.9e \n'
                                   % (wl[iwvl], nang,
                                      opt[iptcle,0,iwvl], opt[iptcle,1,iwvl], opt[iptcle,2,iwvl]))
                    fopt.write('# Phase matrix  \n')
                    for iwvl in range(nwl):
                        fopt.write('# lambda = %.4f \n' % wl[iwvl])
                        fopt.write('#         u              F11            F22             F33            F44               F21               F34      \n')
                        for iang in range(nang):
                            values = (self.u_phasemat[iang],) + tuple(phasemat[iptcle,:,iang,iwvl])
                            fopt.write(row_fmt % values)

        return opt, self.u_phasemat, phasemat, betal, betal_flag, trunccoeff, trunc_phasemat, trunc_normphasemat 

    
###############################################################################################


class kdis_coeff(object):
    """K-distribution coefficients read from a ``kdis_<model>`` library.

    The whole K-distribution definition is read from file(s) at construction.
    Selecting the desired bands or absorbing gases must be done later, while
    setting up the artdeco variables (see also ``select_chan_kdis``).

    Species are split in two groups, each with its own set of arrays:

    * regular species (no density dependence):
        species, fcont, nai (nsp, nwvl), ai (nsp, nwvl, nmaxai),
        ki (nsp, nwvl, nmaxai, np, nt), xsect (nsp, nwvl)
    * density dependent species (suffix ``_c``):
        species_c, fcont_c, nai_c, ai_c, xsect_c as above and
        ki_c (nsp_c, nwvl, nmaxai, np, nt, nc)

    The arrays of a group exist only if the group is not empty (nsp > 0,
    respectively nsp_c > 0).

    Other attributes: model, nmaxai, nsp_tot, nsp, nsp_c, nwvl,
    wvlband (3, nwvl) = central / min / max wavelength, channels (1-based
    channel numbers), p, t (and c when nsp_c > 0) grids with their lengths
    np, nt (nc), and c_desc.

    If ``artdeco_in.mode`` is not 'kdis', nothing is read and the instance is
    left without any of these attributes.
    """

    def __init__(self, artdeco_in, dir_data, format, channel_list=None):
        """Read the K-distribution library.

        artdeco_in   : object providing -mode- and -kdis_model-
        dir_data     : prefix of the library files (concatenated as-is with
                       the file name, so it must end with a path separator)
        format       : "h5" or 'ascii' (anything else reads nothing)
        channel_list : optional 1-based channel numbers to keep (any
                       integer sequence or array); None or empty keeps all
        """

        if artdeco_in.mode != 'kdis':
            return

        self.model = artdeco_in.kdis_model

        # Channels are numbered from 1 in the library: convert them once to
        # 0-based array indices (this also accepts plain Python lists).
        if channel_list is None:
            chan_idx = np.zeros(0, dtype=int)
        else:
            chan_idx = np.atleast_1d(np.asarray(channel_list, dtype=int)) - 1

        if format == "h5":
            self._load_h5(dir_data, chan_idx)
        elif format == 'ascii':
            self._load_ascii(dir_data, chan_idx)

        return

    # ------------------------------------------------------------------
    # helpers shared by both formats

    @staticmethod
    def _is_sorted(a):
        # True if -a- is in non-decreasing order
        return np.all(a[:-1] <= a[1:])

    @staticmethod
    def _require_file(filename):
        # Stop the program if -filename- does not exist
        if not os.path.isfile(filename):
            print("(kdis_coef) ERROR")
            print("            Missing file:", filename)
            exit()

    def _alloc(self, nsp, continuum):
        # Zero-filled (nai, ki, ai, xsect) arrays for a group of -nsp- species;
        # -continuum- adds the concentration axis to ki
        grid = (self.np, self.nt, self.nc) if continuum else (self.np, self.nt)
        nai   = np.zeros((nsp, self.nwvl), dtype='int')
        ki    = np.zeros((nsp, self.nwvl, self.nmaxai) + grid)
        ai    = np.zeros((nsp, self.nwvl, self.nmaxai))
        xsect = np.zeros((nsp, self.nwvl))
        return nai, ki, ai, xsect

    def _keep_channels(self, idx):
        # Restrict every per-channel array to the 0-based indices -idx-.
        # Arrays of an empty species group do not exist and are skipped.
        self.channels = self.channels[idx]
        self.wvlband  = self.wvlband[:,idx]
        if self.nsp > 0:
            self.nai   = self.nai[:,idx]
            self.ki    = self.ki[:,idx,:,:,:]
            self.ai    = self.ai[:,idx,:]
            self.xsect = self.xsect[:,idx]
        if self.nsp_c > 0:
            self.nai_c   = self.nai_c[:,idx]
            self.ki_c    = self.ki_c[:,idx,:,:,:,:]
            self.ai_c    = self.ai_c[:,idx,:]
            self.xsect_c = self.xsect_c[:,idx]
        self.nwvl = self.wvlband.shape[1]

    # ------------------------------------------------------------------
    # HDF5 format

    def _load_h5(self, dir_data, chan_idx):
        # Read <dir_data>kdis_<model>.h5 ; -chan_idx- (0-based) selects the
        # channels to keep, all of them if empty.

        filename = dir_data+'kdis_'+self.model+'.h5'

        # the context manager closes the file on every exit path
        with h5py.File(filename,"r") as f:

            self.nmaxai = np.copy(f["def"]["maxnai"])

            species_tot  = list(f["coeff"].keys())
            self.nsp_tot = len(species_tot)

            self.species   = []
            self.fcont     = []
            self.species_c = []
            self.fcont_c   = []

            for specie in species_tot:
                attrs = f["coeff"][specie].attrs
                # a missing "rho_dep" attribute means no density dependence
                rho_dep = ("rho_dep" in attrs) and bool(attrs['rho_dep'])
                if rho_dep:
                    self.species_c.append(specie)
                    self.fcont_c.append( attrs['add_continuum'] )
                else:
                    self.species.append(specie)
                    self.fcont.append( attrs['add_continuum'] )

            self.nsp     = len(self.species)
            self.nsp_c   = len(self.species_c)
            self.fcont   = np.array(self.fcont)
            self.fcont_c = np.array(self.fcont_c)

            self.nwvl = len(f['def']['central_wvl'])
            self.wvlband = np.zeros((3, self.nwvl))
            self.wvlband[0,:] = np.copy(f['def']['central_wvl'])
            self.wvlband[1,:] = np.copy(f['def']['min_wvl'])
            self.wvlband[2,:] = np.copy(f['def']['max_wvl'])
            self.channels     = np.arange(self.nwvl)+1
            self.p  = np.copy(f['def']['pressure'])
            self.t  = np.copy(f['def']['temperature'])
            self.np = len(self.p)
            self.nt = len(self.t)

            if self.nsp_c > 0:
                self.c = np.copy(f['def']['rho'])
                # depending on how it was stored (and on the h5py version),
                # the attribute comes back as bytes or as str
                desc = f['def']['rho'].attrs["desc"]
                if isinstance(desc, bytes):
                    desc = desc.decode()
                self.c_desc = desc
                self.nc = len(self.c)
                if not self._is_sorted(self.c):
                    print(" kdis_coeff ERROR")
                    print("            concentration must be sorted in increasing order")
                    exit()
            else:
                self.c_desc = "none"

            if not self._is_sorted(self.wvlband[0,:]):
                print(" kdis_coeff ERROR")
                print("            (h5 format) wavelengths must be sorted in increasing order")
                exit()
            if not self._is_sorted(self.p):
                print(" kdis_coeff ERROR")
                print("            pressure must be sorted in increasing order")
                exit()
            if not self._is_sorted(self.t):
                print(" kdis_coeff ERROR")
                print("            temperature must be sorted in increasing order")
                exit()

            if len(chan_idx) > 0:
                self.nwvl     = len(chan_idx)
                self.wvlband  = self.wvlband[:,chan_idx]
                self.channels = self.channels[chan_idx]

            if self.nsp > 0:
                self.nai, self.ki, self.ai, self.xsect = \
                    self._read_h5_species(f, self.species, chan_idx, False)
            if self.nsp_c > 0:
                self.nai_c, self.ki_c, self.ai_c, self.xsect_c = \
                    self._read_h5_species(f, self.species_c, chan_idx, True)

        self.species   = [specie.lower() for specie in self.species]
        self.species_c = [specie.lower() for specie in self.species_c]

    def _read_h5_species(self, f, species, chan_idx, continuum):
        # Read nai / ki / ai / xsect of every species of one group from the
        # open HDF5 file -f-. The trailing "..." stands for the (p, t) axes,
        # plus the concentration axis for density dependent species.

        nai, ki, ai, xsect = self._alloc(len(species), continuum)

        for isp, specie in enumerate(species):
            grp = f["coeff"][specie]

            if len(chan_idx) > 0:
                # only the selected channels, one at a time
                for iwvl, iwvl_file in enumerate(chan_idx):
                    nai[isp,iwvl] = np.copy(grp["nai"][iwvl_file])
                    n = nai[isp,iwvl]
                    ki[isp,iwvl,0:n,...] = np.copy(grp["ki"][iwvl_file,0:n,...])
                    ai[isp,iwvl,0:n]     = np.copy(grp["ai"][iwvl_file,0:n])
                    xsect[isp,iwvl]      = np.copy(grp["xsect"][iwvl_file])
            else:
                # all channels at once, up to the largest number of terms
                nai[isp,:] = np.copy(grp["nai"])
                n = np.nanmax(nai[isp,:])
                ki[isp,:,0:n,...] = np.copy(grp["ki"][:,0:n,...])
                ai[isp,:,0:n]     = np.copy(grp["ai"][:,0:n])
                xsect[isp,:]      = np.copy(grp["xsect"])

        return nai, ki, ai, xsect

    # ------------------------------------------------------------------
    # ASCII format

    def _load_ascii(self, dir_data, chan_idx):
        # Read <dir_data>kdis_<model>_def.dat and one
        # <dir_data>kdis_<model>_<specie>.dat file per species, then keep the
        # channels -chan_idx- (0-based), all of them if empty.

        prefix = dir_data+'kdis_'+self.model

        filename = prefix+'_def.dat'
        self._require_file(filename)

        with open(filename,'r') as fdef:

            skipcomment(fdef)
            self.nmaxai = int(fdef.readline().split()[0])
            skipcomment(fdef)
            self.nsp_tot = int(fdef.readline().split()[0])

            self.species   = []
            self.fcont     = []
            self.species_c = []
            self.fcont_c   = []

            # one line per species : name, density dependence flag, continuum
            skipcomment(fdef)
            for i in range(self.nsp_tot):
                fields  = fdef.readline().split()
                rho_dep = int(fields[1])
                if rho_dep == 0:
                    self.species.append(fields[0])
                    self.fcont.append( float(fields[2]) )
                elif rho_dep == 1:
                    self.species_c.append(fields[0])
                    self.fcont_c.append( float(fields[2]) )

            self.nsp     = len(self.species)
            self.nsp_c   = len(self.species_c)
            self.fcont   = np.array(self.fcont)
            self.fcont_c = np.array(self.fcont_c)

            skipcomment(fdef)
            self.nwvl     = int(fdef.readline().split()[0])
            self.wvlband  = np.zeros((3, self.nwvl))
            self.channels = np.arange(self.nwvl)+1
            skipcomment(fdef)
            for i in range(self.nwvl):
                fields = fdef.readline().split()
                self.wvlband[0,i] = float(fields[1])
                self.wvlband[1,i] = float(fields[2])
                self.wvlband[2,i] = float(fields[3])
                if i > 0 and self.wvlband[0,i] < self.wvlband[0,i-1]:
                    print(" kdis_coeff ERROR")
                    print("            wavelengths must be sorted in increasing order")
                    exit()

            self.p  = self._read_ascii_axis(fdef, "pressure")
            self.np = len(self.p)
            self.t  = self._read_ascii_axis(fdef, "temperature")
            self.nt = len(self.t)
            if self.nsp_c > 0:
                self.c  = self._read_ascii_axis(fdef, "concentration")
                self.nc = len(self.c)

        if self.nsp > 0:
            self.nai, self.ki, self.ai, self.xsect = self._alloc(self.nsp, False)
        if self.nsp_c > 0:
            self.nai_c, self.ki_c, self.ai_c, self.xsect_c = self._alloc(self.nsp_c, True)

        for isp in range(self.nsp):
            self._read_ascii_species(prefix+'_'+self.species[isp]+'.dat', isp,
                                     self.nai, self.xsect, self.ai, self.ki, False)

        if self.nsp_c > 0:
            self.c_desc = "density"
        else:
            self.c_desc = "none"

        for isp in range(self.nsp_c):
            self._read_ascii_species(prefix+'_'+self.species_c[isp]+'.dat', isp,
                                     self.nai_c, self.xsect_c, self.ai_c, self.ki_c, True)

        if len(chan_idx) > 0:
            self._keep_channels(chan_idx)

    @staticmethod
    def _read_ascii_axis(fdef, label):
        # Read "number of points" then one value per line from -fdef- and
        # check that the values are in increasing order
        skipcomment(fdef)
        npts = int(fdef.readline().split()[0])
        skipcomment(fdef)
        values = np.zeros(npts)
        for i in range(npts):
            values[i] = float(fdef.readline().split()[0])
            if i > 0 and values[i] < values[i-1]:
                print(" kdis_coeff ERROR")
                print("            "+label+" must be sorted in increasing order")
                exit()
        return values

    def _read_ascii_species(self, filename, isp, nai, xsect, ai, ki, continuum):
        # Fill row -isp- of the given arrays from one species file.
        # Layout : one "channel nai xsect" line per channel, then for each
        # channel having nai > 1 a line of weights followed by one line of
        # coefficients per (pressure, temperature[, concentration]) node,
        # pressure varying fastest.

        self._require_file(filename)

        with open(filename,'r') as f:
            skipcomment(f)
            for iwvl in range(self.nwvl):
                fields = f.readline().split()
                nai[isp,iwvl]   = int(fields[1])
                xsect[isp,iwvl] = float(fields[2])

            for iwvl in range(self.nwvl):
                n = nai[isp,iwvl]
                if n > 1:
                    skipcomment(f)
                    fields = f.readline().split()
                    for iai in range(n):
                        ai[isp,iwvl,iai] = float(fields[iai])

                    # (nmaxai, np, nt) writable views, one per concentration
                    if continuum:
                        layers = [ki[isp,iwvl,...,ic] for ic in range(self.nc)]
                    else:
                        layers = [ki[isp,iwvl]]

                    for layer in layers:
                        for it in range(self.nt):
                            for ip in range(self.np):
                                fields = f.readline().split()
                                for iai in range(n):
                                    layer[iai,ip,it] = float(fields[iai])

    # ------------------------------------------------------------------

    def select_chan_kdis(self, list_chan):
        """Keep only the channels whose number is in -list_chan-, in that order.

        Nothing is done if -list_chan- contains -1. A ValueError is raised if
        a requested channel number is not in ``self.channels``.
        """

        if -1 in list_chan:
            return

        ind_chan = []
        for chan in list_chan:
            found = np.flatnonzero(self.channels == chan)
            if found.size == 0:
                raise ValueError("(kdis_coeff) channel %s is not available" % (chan,))
            ind_chan.append(found[0])

        self._keep_channels(np.array(ind_chan, dtype=int))

        return




########################################################################################

class surface(object):
    """
    Surface definition.

    Two families are handled:
      'lambert' : albedo spectrum (wl, alb)
      'brdf'    : kind 'ocean' (wdspd, sdistr, azw, shadow, xsal, pcl, Rsw, wvl_Rsw)
                  or kind 'land' (iso, vol, geo spectra, plus cv, nr, ni if bpdf)

    Every tabulated spectrum is stored sorted by increasing wavelength.
    """

    def __init__(self, name, family, kind,
                 interp = False,
                 bpdf = False,
                 temp = 0.0,
                 wl=None, alb=None, iso=None, vol=None, geo=None, cv=None, nr=None, ni=None,
                 test_glitter    = False,
                 ocean_whitecaps = True,
                 _6s_glitter     = False,
                 shadow          = False,
                 wdspd=5.0, sdistr=3, azw = 0.0, xsal=34.3, pcl=0.0, wvl_Rsw=-1,  Rsw=-1.0):

        """
        name    : stored as self.name
        family  : 'brdf' or 'lambert'
        kind    : 'ocean' or 'land' (only checked for the brdf family)
        interp  : if True get_it() interpolates the spectra in wavelength,
                  if False only tabulated wavelengths can be requested
        bpdf    : if True and kind is 'land', cv, nr and ni must be provided
        temp    : stored as self.temp (presumably the surface temperature - TODO confirm)

        wl            : wavelengths of the tabulated spectra (numpy array)
        alb           : albedo at wl (lambert family)
        iso, vol, geo : BRDF coefficients at wl (brdf family, land)
        cv, nr, ni    : BPDF coefficients at wl (brdf family, land, bpdf=True)

        test_glitter, ocean_whitecaps : flags stored as they are given
        shadow  : ocean only, requires sdistr = 3
        wdspd, xsal, pcl : ocean parameters stored as they are given
                  (presumably wind speed, salinity and pigment concentration - TODO confirm)

        sdistr  : 3  = isotropic surface
                : 2  = Anisotropic Gaussian distribution
                : 1  = Anisotropic Gaussian distribution with Gram Charlier series correction

        azw     :  phi_sun - phi_wind (for sdistr = 2 or 1)

        _6s_glitter :   True  = 6SV glitter
                        False = M. I. Mishchenko and L. D. Travis, J. Geophys. Res. 102, 16989-17013 (1997).

        Rsw is an irradiance reflectance, either a scalar or an array
        defined at the wavelengths wvl_Rsw (see get_Rsw)

        The process is stopped (exit) when the options are inconsistent.
        """

        self.name     = name
        self.temp     = temp
        self.interp   = interp
        self.bpdf     = bpdf
        self.test_glitter = test_glitter
        self.ocean_whitecaps = ocean_whitecaps
        self._6s_glitter = _6s_glitter

        if (not self._6s_glitter) and sdistr!=3:
            print("")
            print("")
            print("ERROR : ")
            print("  6S glitter must be used (_6s_glitter = True) if anisotropic surface required")
            print("")
            print("")
            exit()

        if shadow and sdistr!=3:
            print("")
            print("")
            print("ERROR : ")
            print("  Shadow possible only if isotropic ocean surface  ")
            print("")
            print("")
            exit()

        self.family = family
        if self.family not in ['brdf', 'lambert']:
            print("ERROR")
            print("  Surface family must be -brdf- or -lambert-")
            exit()

        self.kind = kind

        if self.family == 'lambert':
            # the sorting permutation is computed once and applied to every spectrum
            order = np.argsort(wl)
            self.wl   = wl[order]
            self.alb  = alb[order]
            return

        # brdf family
        if kind not in ['ocean', 'land']:
            print("ERROR")
            print("  Surface kind must be -ocean- or -land-")
            exit()

        if kind == 'ocean':
            self.wdspd  = wdspd
            self.sdistr = sdistr
            self.azw    = azw
            self.shadow = shadow
            self.xsal   = xsal
            self.pcl    = pcl
            self.Rsw     = Rsw
            self.wvl_Rsw = wvl_Rsw
        else:
            order = np.argsort(wl)
            self.iso = iso[order]
            self.vol = vol[order]
            self.geo = geo[order]
            self.wl  = wl[order]
            if bpdf:
                self.cv = cv[order]
                self.nr = nr[order]
                self.ni = ni[order]

    def get_Rsw(self, wl):
        """
        Return Rsw at the wavelengths wl.

        - scalar Rsw            : the same value for every wavelength
        - array Rsw holding -1  : -1 for every wavelength (-1 is the default
                                  value of Rsw, presumably meaning "not provided")
        - array Rsw otherwise   : interpolation from (wvl_Rsw, Rsw), NaN is
                                  returned outside the wvl_Rsw range
        """
        # np.ndim also recognises numpy scalar types (e.g. np.float32)
        if np.ndim(self.Rsw) == 0:
            return np.full_like(wl, self.Rsw)

        # the -1 flag must not be interpolated, it is propagated as it is
        if -1 in self.Rsw:
            return np.full_like(wl, -1)

        return interp1d(self.wvl_Rsw, self.Rsw, kind=surf_interp, bounds_error=False, fill_value=np.nan )(wl)

    def get_it(self, wl, bpdf):
        """
        Return the surface spectra at the wavelengths wl.

        wl    : wavelengths (sequence or numpy array)
        bpdf  : if True also return cv, nr, ni (brdf family only, the surface
                must have been defined with bpdf=True)

        Returns
          lambert family : alb
          brdf family    : (iso, vol, geo) or (iso, vol, geo, cv, nr, ni)
          ocean kind     : None (a message is printed)

        If self.interp is False every requested wavelength must be one of
        the tabulated wavelengths, otherwise the process is stopped.
        """

        if self.kind == 'ocean':
            print("surface.get_it() is useless for ocean")
            return

        is_brdf = (self.family == 'brdf')
        nwl = len(wl)

        if is_brdf and bpdf and not self.bpdf:
            print("(surface.get_it) ERROR")
            print("                 BPDF was not defined in surface")
            exit()

        # attributes to deliver, in the order they are returned
        if not is_brdf:
            names = ['alb']
        elif bpdf:
            names = ['iso', 'vol', 'geo', 'cv', 'nr', 'ni']
        else:
            names = ['iso', 'vol', 'geo']

        if self.interp:
            spectra = [interp1d(self.wl, getattr(self, nm), kind=surf_interp)(wl) for nm in names]
        else:
            spectra = [np.zeros(nwl) for nm in names]
            for iwl in range(nwl):
                found = np.flatnonzero(self.wl == wl[iwl])
                if found.size == 0:
                    print("(surface.get_it) : No wavelength interpolation allowed in surface definition")
                    exit()
                # a scalar is picked with the first matching index: assigning a
                # size-1 array to one element is deprecated in recent numpy
                for values, nm in zip(spectra, names):
                    values[iwl] = getattr(self, nm)[found[0]]

        if not is_brdf:
            return spectra[0]
        return tuple(spectra)

########################################################################################

class ptqo_atmosphere(object):
    """
    Atmosphere defined from pressure, temperature, humidity and ozone profiles.

    The constructor derives altitudes and number densities from the profiles
    (attributes with the _orig suffix). regrid() gives the profiles on the
    requested vertical grid and get_depol() the Rayleigh depolarization.
    """

    def __init__(self, artdeco_in, p, t, hum, o3, wldepol, depol, gas, ppmv, surfalt, psurf, interp=True,
                 lat=45.0, newgrid=None, kdis=None, tau_ray=None, rh_phase="liq", verbose=False, save_ascii=None):
        """
         artdeco_in only its -mode- attribute is read here
         p          in Pa
         t          in K
         hum        in kg/kg over moist air
         o3         in kg/kg over moist air
         gas        list of gas to be accounted for in absorption
                    (the list given by the caller is not modified)
         ppmv       ppmv profile over dry air
                         dictionnary of interp1d( np.log10(p (hPa), ppmv_over_dry_air )
                         for several gases
         surfalt    surface altitude in m
         psurf      surface pressure in Pa
         wldepol    Rayleigh depol coeff wavelengths
         depol      Rayleigh depol coeff
         interp     whether we can interpolate depol
         lat        passed to calc_hgpl() (presumably the latitude in degrees - TODO confirm)
         newgrid    vertical grid specification used by regrid(), None keeps the original levels
         kdis       k-distribution definition (species, species_c, model), needed in kdis mode
         tau_ray    stored as self.tau_ray, a new empty dictionary if not given
         rh_phase   "liq" (or "liquid") or "ice", passed to get_rh()
         verbose    print the profiles
         save_ascii if not None, name of the file where the density table is written

         Profiles must be ordered with strictly decreasing altitude,
         otherwise the process is stopped.
        """

        self.interp  = interp
        order = np.argsort(wldepol)
        self.wldepol = wldepol[order]
        self.depol   = depol[order]

        self.hum_orig  = hum
        self.o3_orig   = o3

        # one dictionary per instance: a mutable default would be shared by all atmospheres
        self.tau_ray = {} if tau_ray is None else tau_ray

        # NOTE(review): any other value leaves self.rh_phase undefined
        #               (it is only needed when h2o is in the gas list)
        if rh_phase in ["liq","liquid"]:
            self.rh_phase = "liq"
        elif rh_phase in ["ice"]:
            self.rh_phase = "ice"

        self.newgrid = newgrid

        # gases as requested by the caller, before h2o is possibly added below
        self.gasabs = gas.copy()

        # do some check for kdis mode
        if artdeco_in.mode == 'kdis':
            if 'none' in gas:
                print("(atmopshere) ERROR")
                print("             kdis mode without absorbing gas !")
                exit()

            for g in gas:
                if (not g in kdis.species) and (not g in kdis.species_c):
                    print("(atmosphere) ERROR")
                    print("             Required gas ", g," for the atmsophere")
                    print("             is not defined in KDIS ", kdis.model)
                    exit()

            # h2o required to compute o2 continuum
            if ('o2' in gas) or ('n2' in gas):
                if not 'h2o' in gas:
                    # a new list is built so that the caller's list is left untouched
                    gas = list(gas) + ['h2o']

        if 'none' in gas:
            self.ngas  = 0
        else:
            self.ngas  = len(gas)

        if self.ngas>0:
            self.gas  = gas
        else:
            self.gas  = 'none'

        h2o_ppmv_moist = get_h2o_ppmv_moist(hum)

        altitude, dmair  = calc_hgpl(lat, t, p, psurf, surfalt, h2o_ppmv_moist)
        # alt   in  meters
        # dmair density of moist air in kg/m3

        nalt = len(p)
        self.alt_orig   = np.copy(altitude) * 1e-3 # km
        self.t_orig     = np.copy(t)               # K
        self.p_orig     = np.copy(p) / 100.        # hPa

        for ialt in range(1, nalt):
            if self.alt_orig[ialt-1] <= self.alt_orig[ialt]:
                # print the levels up to the one following the faulty one,
                # without running past the last level
                for ialt1 in range(min(ialt+2, nalt)):
                    print("  %i"%ialt1+"   %.2f"%self.p_orig[ialt1] +"   %.2f"%self.alt_orig[ialt1] )
                print("(atmospher) ERROR ")
                print("            Altitude must be strictly decreasing ")
                exit()

        wv_gpercubicmeter = hum * dmair * 1000.                  #  g/m3
        ddair              = (dmair * 1000.) - wv_gpercubicmeter #  g/m3 density of dry air
        self.u_air_orig = (wv_gpercubicmeter / mh2o * na) + (ddair / mair * na) # particule density of moist air in m-3
        self.u_air_orig = self.u_air_orig / 100.**3. # particule density of moist air in cm-3

        if self.ngas>0:
            self.u_orig = np.zeros((self.ngas, nalt))
            for igas in range(self.ngas):
                gas_name = self.gas[igas]
                if gas_name == 'h2o':
                    self.u_orig[igas, :] = (wv_gpercubicmeter / mh2o * na) / 100.**3.
                    self.rh_orig = get_rh(self.p_orig, self.t_orig, self.u_air_orig, self.u_orig[igas, :], self.rh_phase)
                elif gas_name == 'o3':
                    self.u_orig[igas, :] = (o3 * dmair * 1000. / mo3 * na) / 100.**3.
                else:

                    if gas_name not in ppmv.keys():
                        print("(atmosphere) ERROR")
                        print("             you must provide a ppmv profile")
                        print("             for gas ", gas_name)
                        print("             ")
                        exit()

                    # ppmv is given over dry air: the last factor converts it to moist air
                    self.u_orig[igas, :] = self.u_air_orig * 1e-6 * ppmv[gas_name](np.log10 (self.p_orig)) * (1.0 - 1e-6 * h2o_ppmv_moist)

        else:
            self.u_orig = None


        if verbose or (save_ascii is not None):

            if verbose:
                # mixing ratios in ppmv, from the last level to the first one
                for ialt in range(nalt-1,-1,-1):
                    s = "   %.3f"%self.alt_orig[ialt]+"   %.2f"%self.p_orig[ialt]+"   %.2f"%self.t_orig[ialt]+"   %.5e"%self.u_air_orig[ialt]
                    for igas in range(self.ngas):
                        s = s + "   %.5e"%(self.u_orig[igas,ialt] / self.u_air_orig[ialt] * 1e6)
                    print(s)
                s ="   alt      p         T          uair    "
                for igas in range(self.ngas):
                    s = s + "          %s"%self.gas[igas]
                print(s)
                print(" column of WV (kg/m**2) = ", -np.trapz(wv_gpercubicmeter, x=self.alt_orig*1000.) / 1000.)
                column_air = -np.trapz(self.u_air_orig, x=self.alt_orig*1000.)
                # loop on the gas index: self.gas is the string 'none' when there is no gas
                for igas in range(self.ngas):
                    print(   "   ppmv of %5s is %18.9f"%(self.gas[igas],  -np.trapz(self.u_orig[igas,:], x=self.alt_orig*1000.)/ column_air * 1e6))

                print("\n")
                print("densities")
                print("\n")

            # density table, printed and (if requested) written to save_ascii
            table = []
            for ialt in range(nalt):
                s = "   %8.3f"%self.alt_orig[ialt]+"   %15.6f"%self.p_orig[ialt]+"   %.2f"%self.t_orig[ialt]+"   %.5e"%self.u_air_orig[ialt]
                for igas in range(self.ngas):
                    s = s + "   %.5e"%(self.u_orig[igas,ialt])
                table.append(s)
            s ="      alt              p          T          uair "
            for igas in range(self.ngas):
                s = s + "           %s"%self.gas[igas]
            table.append(s)

            for s in table:
                print(s)

            if save_ascii is not None:
                # the context manager closes the file even if a write fails
                with open(save_ascii,"w") as f:
                    for s in table:
                        f.write(s+" \n")

        return


    def regrid(self, debug=False):

        """
        Set up (altitude) grid specified in format 'start[step1]stop1[step2]stop' or similar.

        Returns alt, t, p, u_air, u on the new grid (altitude decreasing).
        If self.newgrid is None, copies of the original profiles are returned.
        u is None when there is no gas. The debug argument is not used.

        Temperature is interpolated linearly in altitude (kind=atm_interp),
        pressure and densities through their logarithm.
        """

        if self.newgrid is None:
            alt    = np.copy(self.alt_orig)
            t      = np.copy(self.t_orig)
            p      = np.copy(self.p_orig)
            u_air  = np.copy(self.u_air_orig)
            if self.ngas > 0:
                u      = np.copy(self.u_orig)
            else:
                u = None

        else:

            alt = parseGridSpec(self.newgrid)
            alt = np.sort(alt)[::-1]

            ft    = interp1d(self.alt_orig, self.t_orig,             kind=atm_interp, assume_sorted=False)
            fp    = interp1d(self.alt_orig, np.log(self.p_orig),     kind=atm_interp, assume_sorted=False)
            fuair = interp1d(self.alt_orig, np.log(self.u_air_orig), kind=atm_interp, assume_sorted=False)

            t     = ft(alt)
            p     = np.exp(fp(alt))
            u_air = np.exp(fuair(alt))
            if self.ngas > 0:
                u = np.zeros( (self.ngas, len(alt)) )
                for i in range(self.ngas):
                    fu = interp1d(self.alt_orig, np.log(self.u_orig[i,:]), kind=atm_interp, assume_sorted=False)
                    u[i, :] = np.exp(fu(alt))
                    # cases where density is null will give a NaN because of the log(0.0)
                    u[i,:][np.isnan(u[i,:])] = 0.0
            else:
                u = None

        return  alt, t, p, u_air, u


    def get_depol(self, wl):
        """
        Return the Rayleigh depolarization coefficient at the wavelengths wl.

        If self.interp is False every wavelength must be one of self.wldepol,
        otherwise the process is stopped. If self.interp is True the
        coefficient is interpolated (kind=depol_interp).
        """

        if self.interp:
            fdepol = interp1d(self.wldepol, self.depol, kind=depol_interp)
            return fdepol(wl)

        nwl = len(wl)
        depol = np.zeros(nwl)
        for iwl in range(nwl):
            found = np.flatnonzero(self.wldepol == wl[iwl])
            if found.size == 0:
                print("")
                print("(atmo.get_depol) : No wavelength interpolation allowed in depolarization definition")
                print("")
                exit()
            # scalar taken at the first match: assigning a size-1 array
            # to one element is deprecated in recent numpy
            depol[iwl] = self.depol[found[0]]

        return depol
    
    
########################################################################################


class atmosphere(object):    
    def __init__(self, artdeco_in, dir_data, atm_file, gas, ppmv, fmt_atm, wldepol, depol, interp=True, kdis=None, newgrid=None, DU_o3 = -1, P0 = -1, water = -1, tau_ray = {}, rh_phase="liq", warn=False, scale_t_rh_cste=True):
         
        # water must be given in g/cm2
        # DU_O3 is Dobson unit
              
        #print "Set atmosphere class..."

        self.warn=warn

        self.tau_ray = tau_ray
        
        if newgrid == None:
            self.newgrid = None
        else:
            self.newgrid = newgrid

        self.interp  = interp
        self.wldepol = wldepol[np.argsort(wldepol)]
        self.depol   = depol[np.argsort(wldepol)]

        self.scale_t_rh_cste = scale_t_rh_cste
        
        if rh_phase in ["liq","liquid"]:
            self.rh_phase = "liq" 
        elif rh_phase in ["ice"]:
            self.rh_phase = "ice" 

        self.lbl_atm = False
        if artdeco_in.mode == 'lbl':
            if (not "lbl" in fmt_atm):
                print("(atmosphere)  ERROR ")
                print("              If artdeco mode is lbl ")
                print("              you must provide an atm file with lbl format: -lbl_forum- or -lbl_liradtran-")
                exit()
            else:
                self.lbl_atm = True

        self.gasabs = ['']*len(gas)
        self.gasabs[:] = gas[:]

        # do some check for kdis mode
        if artdeco_in.mode == 'kdis':
            if 'none' in gas:
                print("(atmopshere) ERROR")
                print("             kdis mode without absorbing gas !")
                exit()

            for g in gas:
                if (not g in kdis.species) and (not g in kdis.species_c):
                    print("(atmosphere) ERROR")
                    print("             Required gas ", g," for the atmsophere")  
                    print("             is not defined in KDIS ", kdis.model)  
                    exit()
            # h2o required to compute o2 continuum
            if ('o2' in gas) or ('n2' in gas):
                if not 'h2o' in gas:
                    gas.append('h2o')
                    ppmv.append(-1)
                    
        if 'none' in gas:
            nppmv = 0
            self.ngas  = 0
        else:
            nppmv = len(ppmv)
            self.ngas  = len(gas)
                       
        if nppmv != self.ngas : 
            print("(atmosphere) nppmv != ngas")
            exit()
            
        if self.ngas>0:     
            self.gas  = gas
            gas = None
            self.ppmv = ppmv
            ppmv = None
        else:
            self.gas  = 'none'


        if self.lbl_atm :

            self.ngas = 0
            
            if len(artdeco_in.wavel)!=2:
                print ("(atmosphere)  " )
                print ("(atmosphere)  " )
                print ("(atmosphere)  in lbl mode two boundary wavelengths must be provided in -artdeco_in.wavel- " )
                print ("(atmosphere)  " )
                print ("(atmosphere)  " )
                exit()

                
            if fmt_atm == "lbl_iprt":
                
                print(" (atmosphere) read ", dir_data+atm_file)
                
                filename = dir_data+atm_file
                if not os.path.isfile(filename):
                    print("(atmosphere) ERROR")
                    print("            Missing file:", filename)
                    exit()
                f = open(filename,'r')
                skipcomment(f)
                tmp = f.readline()
                nalt = int(tmp.split()[0])
                skipcomment(f)
                tmp = f.readline()
                nwvl = int(tmp.split()[0])
                skipcomment(f)
                
                self.alt_orig   = np.zeros(nalt)
                self.t_orig     = np.zeros(nalt)
                self.p_orig     = np.zeros(nalt)
                self.u_air_orig = np.full(len(self.alt_orig), np.nan)

                wvl_tauabsgas = np.zeros(nwvl)
                dtauabsgas     = np.zeros((nalt-1,nwvl))

                tmp = f.readline()
                for iwvl in range(nwvl):
                    wvl_tauabsgas[iwvl] = float(tmp.split()[iwvl])
                skipcomment(f)
                
                for ialt in range(nalt):
                    tmp = f.readline()                
                    self.alt_orig[ialt]   = float(tmp.split()[0])
                    self.p_orig[ialt]     = float(tmp.split()[1])
                    self.t_orig[ialt]     = float(tmp.split()[2])
                    if ialt>0:
                        if  self.alt_orig[ialt] >= self.alt_orig[ialt-1]:
                            print("(atmosphere) ERROR")
                            print("             with -lbl_iprt- fmt, alt must be sorted decreasing ")
                            print("")
                            exit()
                    for iwvl in range(nwvl):
                        dtauabsgas[ialt-1,iwvl] = float(tmp.split()[3+iwvl])

                f.close()

                
                if artdeco_in.wavel[0] == -1:
                    imin = 0
                else:
                    imin = find_nearest(wvl_tauabsgas, artdeco_in.wavel[0])

                if artdeco_in.wavel[1] == -1:
                    imax = 0
                else:
                    imax = find_nearest(wvl_tauabsgas, artdeco_in.wavel[1])
             
                self.gas =[]
                
                self.wvl_tauabsgas = wvl_tauabsgas[imin:imax+1]
                self.tauabsgas     = dtauabsgas[:,imin:imax+1]
                
                self.alt_tauabsgas = self.alt_orig


                

            if fmt_atm == "lbl_libradtran":

                print("This must evolve to include, P,T,alt in the file without modifiyng the libRatran format basis...")
                print("   file mustthen include at the / : wvl, z, tau, wvlmin, wvlmax")
                exit()
                
                # read line-by-line opacities

                print(" (atmosphere) read ", dir_data+atm_file)
                self.tauabsgas     = read_nc(dir_data+atm_file, 'tau')
                self.alt_tauabsgas = read_nc(dir_data+atm_file, 'z')
                self.wvl_tauabsgas = read_nc(dir_data+atm_file, 'wvl') * 1e-3 # nm to microns


            elif fmt_atm == "lbl_forum":

                
                f = h5py.File(dir_data+atm_file,"r")
                self.alt_orig   = np.copy(f["Atmos"]["Altitude"])
                self.t_orig     = np.copy(f["Atmos"]["Temperature"])
                self.p_orig     = np.copy(f["Atmos"]["Pressure"])
                self.u_air_orig = np.full(len(self.alt_orig), np.nan)
                
                self.alt_tauabsgas = self.alt_orig

                wn_tauabsgas = np.copy(f["Spectral_Informations"]["Wavenumber"]) 

                if f["Spectral_Informations"]["Wavenumber"][0] > f["Spectral_Informations"]["Wavenumber"][1]:
                    print ("(atmosphere)  " )
                    print ("(atmosphere)  " )
                    print ("(atmosphere)  In  ", dir_data+atm_file)
                    print ("(atmosphere)   Wavenumber should be sorted in increasing order" )
                    print ("(atmosphere)  " )
                    exit()
    
                if artdeco_in.wavel[0] == -1:
                    imax = find_nearest(wn_tauabsgas, f["Spectral_Informations"]["Wavmax"]) 
                else:
                    imax = find_nearest(wn_tauabsgas, 1e4 / artdeco_in.wavel[0])
                    
                if artdeco_in.wavel[1] == -1:
                    imin = find_nearest(wn_tauabsgas, f["Spectral_Informations"]["Wavmin"]) 
                else:
                    imin = find_nearest(wn_tauabsgas, 1e4 / artdeco_in.wavel[1]) 
                
                self.tauabsgas = np.zeros((imax-imin+1, len(self.alt_orig)-1))

                gas_list = list(f["Mol_Optical_Depth"].keys())

                if "all" in self.gas:
                    self.gas   = gas_list

                for g in self.gas:
                    if not g in gas_list:
                        print("(atmosphere) ERROR")
                        print("             Required gas ", g," for the atmsophere")  
                        print("             is not defined in ", dir_data+atm_file)  
                        exit()        
                
                for gas_name in gas_list:
                    if gas_name in self.gas:
                        self.tauabsgas[:,:] = self.tauabsgas[:,:] + f["Mol_Optical_Depth"][gas_name][imin:imax+1,:]
                                
                f.close()


                self.gas =[]
                
                # sort in wvl increasing order
                self.wvl_tauabsgas = (1e4 / wn_tauabsgas[imin:imax+1])[::-1]
                self.tauabsgas     = np.transpose(self.tauabsgas[::-1,:])

                # import matplotlib.pyplot as plt
                # plt.plot(1e4/self.wvl_tauabsgas, np.exp(-np.sum(self.tauabsgas, axis=0)) )
                # plt.show()
                # exit()

                
                
        else:
            
            if fmt_atm=='ascii':

                filename = dir_data+atm_file
                if not os.path.isfile(filename):
                    print("(atmosphere) ERROR")
                    print("            Missing file:", filename)
                    exit()
                f = open(filename,'r')
                skipcomment(f)
                tmp = f.readline()
                nalt = int(tmp.split()[0])
                skipcomment(f)
                tmp = f.readline()
                nspecies = int(tmp.split()[0])
                species=[]
                skipcomment(f)
                if nspecies > 0:
                    for i in range(nspecies):
                        tmp = f.readline()
                        species.append(tmp.split()[0])
                else:
                    tmp = f.readline()
                #names = ['z', 'p', 't', 'uair']+species

                #print(filename)

                skipcomment(f)
                self.alt_orig   = np.zeros(nalt)
                self.t_orig     = np.zeros(nalt)
                self.p_orig     = np.zeros(nalt)
                self.u_air_orig = np.zeros(nalt)                           

                u_species       = {}
                for isp in range(nspecies):
                    u_species[species[isp]] = np.zeros(nalt)

                for ialt in range(nalt):
                    tmp = f.readline()                
                    self.alt_orig[ialt]   = float(tmp.split()[0])
                    if ialt>0:
                        if self.alt_orig[ialt-1] <= self.alt_orig[ialt]:
                            print("(atmospher) ERROR ")
                            print("            Altitude must be strictly decreasing ")
                            exit()
                    self.p_orig[ialt]     = float(tmp.split()[1])
                    self.t_orig[ialt]     = float(tmp.split()[2])
                    self.u_air_orig[ialt] = float(tmp.split()[3])
                    for isp in range(nspecies):
                        u_species[species[isp]][ialt] = float(tmp.split()[4+isp])

                f.close()           

                if self.ngas>0:    
                    self.u_orig = np.zeros((self.ngas, nalt))
                    for igas in range(self.ngas):
                        if self.gas[igas] in species:
                            self.u_orig[igas, :] = u_species[self.gas[igas]]
                            if self.ppmv[igas] != -1:
                                print("(atmosphere) ERROR")
                                print("             ppmv must be -1 for ",self.gas[igas])
                                print("             cause it is already defined in file :")
                                print("             ",atm_path)
                                exit()

                        else:
                            if self.ppmv[igas] == -1:
                                print("(atmosphere) ERROR")
                                print("             you must provide a ppmv not = -1")
                                print("             for gas ", self.gas[igas])
                                print("             cause it is not defined in file")
                                print("             ",atm_path)
                                exit()
                            self.u_orig[igas, :] = self.u_air_orig * 1e-6 * self.ppmv[igas]
                else:
                    self.u_orig = None

            else:
                print("(atmosphere) ERROR : atmosphere file format not supported") 
            

        # if rh100_bounds is set, the relative humidity in cloud will be set to 100% when calling self.regrid()
        # Doing that, the WVC is either kept constant or not depending on self.rh100_wvc_cste 
        self.rh100_bounds = []
        self.rh100_wvc_cste = True 
                
        self.DU_o3 = 0
        self.water = 0        
        # compute DU and water vapor content
        for igas in range(self.ngas):
            if self.gas[igas] == 'o3':
                column = np.abs(np.trapz(self.u_orig[igas, :], x = self.alt_orig)) * 1e5                
                self.DU_o3 = column / 2.69e16             # DU
            if self.gas[igas] == 'h2o':
                column = np.abs(np.trapz(self.u_orig[igas, :], x = self.alt_orig)) * 1e5
                self.water = column * mh2o / cste.N_A   # g/cm2


                
        # eventually scale ozone integrated quantity 
        if ('o3' in self.gas) and (DU_o3!=-1):
            for igas in range(self.ngas):
                if self.gas[igas] == 'o3':
                    self.scale_o3 = DU_o3 / self.DU_o3
                    self.DU_o3    = DU_o3
                    self.u_orig[igas, :] = self.u_orig[igas, :] * self.scale_o3  
                    break

        # eventually scale water vapor integrated quantity 
        if ('h2o' in self.gas) and (water!=-1):
            for igas in range(self.ngas):
                if self.gas[igas] == 'h2o':

                    if water == 0.0:

                        self.water           = 0.0
                        self.u_orig[igas, :] = 0.0

                    else:
                        rh_before = get_rh(self.p_orig, self.t_orig, self.u_air_orig, self.u_orig[igas,:], self.rh_phase)

                        #t_orig_before = np.copy(self.t_orig)

                        scale_h2o = water / self.water 
                        self.water     = water
                        self.u_orig[igas, :] = self.u_orig[igas, :] * scale_h2o

                        if self.scale_t_rh_cste:
                            # we want to keep Rh constant, we modify the profile temperature according to the N_H2O modification
                            self.t_orig = get_t_from_rh_u(self.p_orig,rh_before,self.u_air_orig, self.u_orig[igas,:], self.rh_phase)

                        # import matplotlib.pyplot as plt
                        # plt.figure("Temperature")
                        # plt.plot(t_orig_before-273.15 ,  self.alt_orig, label="before")
                        # plt.plot( self.t_orig-273.15,  self.alt_orig, "--", label="after")
                        # plt.legend(loc="best")
                        # plt.figure("Relative humidity")
                        # plt.plot(rh_before,  self.alt_orig)
                        # plt.plot(get_rh(self.p_orig, self.t_orig, self.u_air_orig, self.u_orig[igas,:]),  self.alt_orig, "--")
                        # plt.show()

                    
                    break

        # get relative humidity profile        
        if ('h2o' in self.gas):
            for igas in range(self.ngas):
                if self.gas[igas] == 'h2o':
                    self.rh_orig  = get_rh(self.p_orig, self.t_orig, self.u_air_orig, self.u_orig[igas,:], self.rh_phase)
                    if self.warn and len(self.rh_orig[self.rh_orig>limit_rh])>0:
                        print("(atmosphere) WARNING ")
                        print("             Relative humidity superior to %.2f percent"%limit_rh)
                        for ialt in range(len(self.alt_orig)):
                            s = "%3i"%ialt+'   %10.3f'%self.alt_orig[ialt]+'   %10.3f'%self.p_orig[ialt]+'   %10.3f'%self.t_orig[ialt]+'   %10.2f'%self.rh_orig[ialt]
                            print(s)
                    break
                        
        self.cloudy_water = -1.0
        
        # scale P0
        # Note : This is a brut force scaling and may not be adapted in all situations
        #        The gases concentration will be scaled too (exept for the ozone and H2O)
        #        Their relative contents remain unchanged 
        self.P0 = P0    
        if P0 != -1.0:
            if self.lbl_atm:
                print("(atm regird) ERROR")
                print("             No P0 scale is allowed if lbl mode")
            fact = P0 / np.amax(self.p_orig)            
            self.p_orig       = self.p_orig * fact
            self.u_air_orig   = self.u_air_orig * fact
            if self.ngas > 0:
                for igas in range(self.ngas):
                    if (self.gas[igas] != 'o3') and (self.gas[igas] != 'h2o'):
                        self.u_orig[igas, :]   = self.u_orig[igas, :] * fact 
       

        if self.alt_orig[0]<self.alt_orig[1]:
            print("(atm regird) ERROR")
            print("             alt_orig must be sorted in decreasing order")
            exit()


            
    def set_P0(self, P0):
        """Rescale the original profiles so that the maximum pressure equals P0.

        This is a brute-force scaling that may not be adapted to every
        situation: the pressure and air density profiles are multiplied by
        ``P0 / max(p_orig)``, and so is the density of every gas except
        'o3' and 'h2o' (the relative contents of the scaled gases remain
        unchanged).

        P0 is always stored in ``self.P0``; the value -1.0 disables the scaling.
        """
        self.P0 = P0
        if P0 == -1.0:
            return

        if self.lbl_atm:
            # Only reported: the scaling below is still applied.
            print("(atm regird) ERROR")
            print("             No P0 scale is allowed if lbl mode")

        ratio = P0 / np.amax(self.p_orig)
        self.p_orig     = self.p_orig * ratio
        self.u_air_orig = self.u_air_orig * ratio

        if self.ngas > 0:
            not_scaled = ('o3', 'h2o')
            for igas in range(self.ngas):
                if self.gas[igas] in not_scaled:
                    continue
                self.u_orig[igas, :] = self.u_orig[igas, :] * ratio

        return
        
            
    def get_alt(self, pressure):
        """Return the altitude matching pressure, interpolated in log-pressure.

        The interpolation uses the original profile (p_orig, alt_orig);
        pressures outside its range yield the fill value -32768.
        """
        log_p_orig  = np.log(self.p_orig)
        alt_of_logp = interp1d(log_p_orig, self.alt_orig,
                               bounds_error=False, fill_value=-32768)
        log_p = np.log(pressure)
        return alt_of_logp(log_p)

    def get_t(self, alt):
        """Return the temperature of the original profile interpolated at alt.

        Altitudes outside the range of alt_orig yield the fill value -32768.
        """
        t_of_alt = interp1d(self.alt_orig, self.t_orig,
                            bounds_error=False, fill_value=-32768)
        return t_of_alt(alt)
          
    def get_press(self, alt):
        """Return the pressure of the original profile interpolated at alt.

        The interpolation is done on log(p_orig) and exponentiated back;
        altitudes outside the range of alt_orig yield NaN.
        """
        log_p_orig  = np.log(self.p_orig)
        logp_of_alt = interp1d(self.alt_orig, log_p_orig,
                               bounds_error=False, fill_value=np.nan)
        log_p = logp_of_alt(alt)
        return np.exp(log_p)


    def get_depol(self, wl):
        """Return the depolarization value for the wavelength(s) wl.

        When ``self.interp`` is true, the tabulated values (``self.wldepol``,
        ``self.depol``) are interpolated at wl (scalar or array) with the
        module-level ``depol_interp`` kind.

        Otherwise every wavelength of wl must be exactly one of the tabulated
        wavelengths; if one is missing, a message is printed and exit() is
        called. In that case wl must be a sequence and an array of the same
        length is returned.
        """
        if self.interp:
            fdepol = interp1d(self.wldepol, self.depol, kind=depol_interp)
            return fdepol(wl)

        wl_tab    = np.asarray(self.wldepol)
        depol_tab = np.asarray(self.depol)

        depol = np.zeros(len(wl))
        for iwl, wvl in enumerate(wl):
            match = np.flatnonzero(wl_tab == wvl)
            if match.size == 0:
                print("")
                print("(atmo.get_depol) : No wavelength interpolation allowed in depolarization definition")
                print("")
                exit()
            # Index a single element: assigning a size-1 array to a scalar
            # slot is deprecated by NumPy (first match wins if the table
            # holds duplicated wavelengths).
            depol[iwl] = depol_tab[match[0]]

        return depol


    
    def regrid(self, debug=False):
        """Return the atmospheric profiles on the output altitude grid.

        The output grid is ``self.alt_orig`` when ``self.newgrid`` is None;
        otherwise it is the grid described by ``self.newgrid`` (a
        specification parsed by parseGridSpec, e.g.
        'start[step1]stop1[step2]stop'), sorted in decreasing order, and the
        original profiles are interpolated on it. Regridding is refused
        (message + exit()) in lbl mode.

        When ``self.rh100_bounds`` is not empty, the bounds and one level
        1e-5 above and below each of them are added to the grid, the grid is
        clipped to the altitude range of the original profile, the relative
        humidity is forced to 100% (or to the profile maximum if larger)
        between the bounds and the water vapour density is recomputed
        accordingly. ``self.cloudy_water`` is then set to the resulting
        water vapour column.

        Parameters
        ----------
        debug : bool
            If True, print the profiles obtained after the Rh modification.

        Returns
        -------
        alt, t, p, u_air, u
            Altitude, temperature, pressure, air density and gas densities
            (array of shape (ngas, nalt)) on the output grid. ``u`` is None
            when there is no gas.
        """
        # Defaults so that the final humidity test and the return statement
        # never hit an unbound name (no gas at all, or no 'h2o' among them).
        u  = None
        rh = None

        if self.newgrid is None:

            # keep the original vertical grid
            alt = np.copy(self.alt_orig)
            if len(self.rh100_bounds) == 0:
                t     = self.t_orig
                p     = self.p_orig
                u_air = self.u_air_orig
                if self.ngas > 0:
                    u = self.u_orig
                    if "h2o" in self.gas:
                        rh = self.rh_orig

        else:

            if self.lbl_atm:
                print("(atm regird) ERROR")
                print("             No regrid is allowed if lbl mode")
                exit()

            alt = parseGridSpec(self.newgrid)
            alt = np.sort(alt)[::-1]

            if len(self.rh100_bounds) == 0:
                t, p, u_air, u, rh = self._regrid_interp_profiles(alt)

        if len(self.rh100_bounds) > 0:

            # Note: if rh100_wvc_cste=True,
            #       the Rh in cloud is set to 100% keeping the water vapour
            #       column (WVC) constant.
            #       We first set Rh to 100% in cloud, then:
            #             - If the amount of WVC below the cloud permits it,
            #               we simply scale it (i.e. we transfer the WV from
            #               below the cloud to within the cloud).
            #             - Otherwise, we set the WVC below the cloud to 0.0
            #               and scale the WVC in and above the cloud

            # first add levels corresponding to the boundaries, plus one level
            # right above and right below each of them for integration purpose
            # (1 cm according to the original comment)
            for bounds in self.rh100_bounds:
                alt = np.append(alt, bounds)
                alt = np.append(alt, bounds + 1e-5)
                alt = np.append(alt, bounds - 1e-5)

            alt = np.sort(np.unique(alt))[::-1]
            alt = alt[alt >= np.min(self.alt_orig)]
            alt = alt[alt <= np.max(self.alt_orig)]

            # redo all interpolations on the corresponding grid
            t, p, u_air, u, rh = self._regrid_interp_profiles(alt)

            flag_cloudy = np.full(len(alt), False)
            for igas in range(self.ngas):
                if self.gas[igas] != 'h2o':
                    continue

                rh_before    = np.copy(rh)
                u_h2o_before = np.copy(u[igas, :])

                # We set the relative humidity to 100% into the cloud.
                # If Rh (clear sky) was already greater than 100% somewhere in
                # the profile, we set Rh in cloud to that maximum value.
                for bounds in self.rh100_bounds:
                    rh_cloud = 100.0
                    if np.nanmax(rh) > rh_cloud:
                        rh_cloud = np.nanmax(rh)
                    in_cloud = (alt >= np.min(bounds)) & (alt <= np.max(bounds))
                    rh[in_cloud]          = rh_cloud
                    flag_cloudy[in_cloud] = True

                # water vapor density profile
                u_h2o = get_u_from_rh(p, t, u_air, rh, self.rh_phase)

                if self.rh100_wvc_cste:

                    max_alt_cloud = np.nanmax(alt[flag_cloudy])

                    # WVC into cloudy part
                    u_in = np.copy(u_h2o)
                    u_in[~flag_cloudy] = 0.0
                    water_in_cloud = self._regrid_water_column(u_in, alt)

                    # WVC in clear part and above the cloud (above highest cloudy level)
                    u_out_up = np.copy(u_h2o)
                    u_out_up[flag_cloudy | (alt <= max_alt_cloud)] = 0.0
                    water_out_up_cloud = self._regrid_water_column(u_out_up, alt)

                    # WVC in clear part, at or below the highest cloudy level
                    u_out_low = np.copy(u_h2o)
                    u_out_low[flag_cloudy | (alt > max_alt_cloud)] = 0.0
                    water_out_low_cloud = self._regrid_water_column(u_out_low, alt)

                    # WVC into cloudy part before the Rh modification
                    u_in_before = np.copy(u_h2o_before)
                    u_in_before[~flag_cloudy] = 0.0
                    water_in_before_cloud = self._regrid_water_column(u_in_before, alt)

                    delta_wv_in = water_in_cloud - water_in_before_cloud
                    if (delta_wv_in < 0.0):
                        print("(atm regird) ERROR")
                        print("             delta_wv_in < 0")
                        exit()

                    if (delta_wv_in < water_out_low_cloud):
                        # enough water below the cloud: take the excess from there
                        u_out_low = u_out_low * (water_out_low_cloud - delta_wv_in) / water_out_low_cloud
                    else:
                        # not enough: dry the lower part and scale the rest
                        # so that the total column equals self.water
                        u_out_low[:] = 0
                        scale        = self.water / (water_in_cloud + water_out_up_cloud)
                        u_out_up = u_out_up * scale
                        u_in     = u_in * scale

                    # rebuild the full profile from its three parts
                    u_h2o[:] = 0.0
                    u_h2o[u_out_up != 0.0]  = u_out_up[u_out_up != 0.0]
                    u_h2o[u_out_low != 0.0] = u_out_low[u_out_low != 0.0]
                    u_h2o[u_in != 0.0]      = u_in[u_in != 0.0]

                u[igas, :] = u_h2o
                rh = get_rh(p, t, u_air, u[igas, :], self.rh_phase)

                self.cloudy_water = self._regrid_water_column(u[igas, :], alt)   # g/cm2

                if debug:
                    print("(atmosphere regrid) set rh100 ")
                    for ialt in range(len(alt)):
                        s = "%3i"%ialt+'   %10.3f'%alt[ialt]+'   %10.3f'%p[ialt]+'   %10.3f'%t[ialt]+'   %10.2f'%rh[ialt]+'   %10.2f'%rh_before[ialt]+'   %10.2f'%u_h2o[ialt]+'   %10.2f'%u_h2o_before[ialt]
                        print(s)
                    print("WVC        =", self.cloudy_water)
                    print("WVC before =", self.water )

        # rh stays None when there is no water vapour profile to check
        if self.warn and (rh is not None) and len(rh[rh > limit_rh]) > 0:
            print("(atmosphere regrid) WARNING ")
            print("                    Relative humidity superior to %.2f percent"%limit_rh)
            for ialt in range(len(alt)):
                s = "%3i"%ialt+'   %10.3f'%alt[ialt]+'   %10.3f'%p[ialt]+'   %10.3f'%t[ialt]+'   %10.2f'%rh[ialt]
                print(s)

        return alt, t, p, u_air, u

    def _regrid_interp_profiles(self, alt):
        """Interpolate the original profiles on alt; return (t, p, u_air, u, rh)."""
        # pressure and densities are interpolated in log space
        ft    = interp1d(self.alt_orig, self.t_orig,             kind=atm_interp, assume_sorted=False)
        fp    = interp1d(self.alt_orig, np.log(self.p_orig),     kind=atm_interp, assume_sorted=False)
        fuair = interp1d(self.alt_orig, np.log(self.u_air_orig), kind=atm_interp, assume_sorted=False)

        t     = ft(alt)
        p     = np.exp(fp(alt))
        u_air = np.exp(fuair(alt))

        u  = None   # stays None when there is no gas
        rh = None   # stays None when 'h2o' is not among the gases
        if self.ngas > 0:
            u = np.zeros((self.ngas, len(alt)))
            for i in range(self.ngas):
                fu = interp1d(self.alt_orig, np.log(self.u_orig[i, :]), kind=atm_interp, assume_sorted=False)
                u[i, :] = np.exp(fu(alt))
                # cases where density is null will give a NaN because of the log(0.0)
                u[i, :][np.isnan(u[i, :])] = 0.0
                if self.gas[i] == 'h2o':
                    rh = get_rh(p, t, u_air, u[i, :], self.rh_phase)

        return t, p, u_air, u, rh

    @staticmethod
    def _regrid_water_column(u_h2o, alt):
        """Integrate a water vapour density profile over alt and convert it to a column (g/cm2)."""
        # np.trapz is deprecated in recent NumPy in favour of np.trapezoid
        integrate = getattr(np, "trapezoid", None) or np.trapz
        return np.abs(integrate(u_h2o, x=alt)) * 1e5 * mh2o / cste.N_A

    
    
    def apply_scale_O3(self, DU_o3):
        """Scale the 'o3' density profile so that its content becomes DU_o3.

        Nothing is done when 'o3' is not among the gases or when DU_o3 is -1.
        Otherwise the scaling factor ``DU_o3 / self.DU_o3`` is stored in
        ``self.scale_o3``, ``self.DU_o3`` is updated and the first 'o3'
        profile of ``self.u_orig`` is multiplied by that factor.
        """
        if ('o3' not in self.gas) or (DU_o3 == -1):
            return

        for igas in range(self.ngas):
            if self.gas[igas] != 'o3':
                continue
            factor = DU_o3 / self.DU_o3
            self.scale_o3 = factor
            self.DU_o3    = DU_o3
            self.u_orig[igas, :] = self.u_orig[igas, :] * self.scale_o3
            break

        return


    def apply_scale_h2o(self, water):
        """Scale the 'h2o' density profile so that the water content becomes water.

        Nothing is done when 'h2o' is not among the gases or when water is -1.
        With water == 0.0 the profile is simply set to zero. Otherwise the
        profile is multiplied by ``water / self.water`` and, if
        ``self.scale_t_rh_cste`` is set, the temperature profile is modified
        so that the relative humidity profile is unchanged.

        ``self.rh_orig`` is then recomputed from the updated profiles, so
        that the humidity warning (and later users of ``self.rh_orig``)
        reflect the scaled water vapour rather than the previous state.

        NOTE(review): a previous call with water == 0.0 leaves
        ``self.water`` at 0.0, so a later non-zero scaling divides by zero;
        confirm this sequence is never used by callers.
        """
        if ('h2o' not in self.gas) or (water == -1):
            return

        for igas in range(self.ngas):
            if self.gas[igas] != 'h2o':
                continue

            if water == 0.0:

                self.water           = 0.0
                self.u_orig[igas, :] = 0.0

            else:

                # Rh before the scaling, needed to keep it constant below
                rh_before = get_rh(self.p_orig, self.t_orig, self.u_air_orig, self.u_orig[igas,:], self.rh_phase)

                scale_h2o = water / self.water
                self.water     = water
                self.u_orig[igas, :] = self.u_orig[igas, :] * scale_h2o

                if self.scale_t_rh_cste:
                    # we want to keep Rh constant, we modify the profile temperature according to the N_H2O modification
                    self.t_orig = get_t_from_rh_u(self.p_orig,rh_before,self.u_air_orig, self.u_orig[igas,:], self.rh_phase)

            # refresh the relative humidity profile before checking it
            self.rh_orig = get_rh(self.p_orig, self.t_orig, self.u_air_orig, self.u_orig[igas,:], self.rh_phase)
            if self.warn and len(self.rh_orig[self.rh_orig>limit_rh])>0:
                print("(atmosphere) WARNING ")
                print("             Relative humidity superior to %.2f percent"%limit_rh)
                for ialt in range(len(self.alt_orig)):
                    s = "%3i"%ialt+'   %10.3f'%self.alt_orig[ialt]+'   %10.3f'%self.p_orig[ialt]+'   %10.3f'%self.t_orig[ialt]+'   %10.2f'%self.rh_orig[ialt]
                    print(s)
            break

        return

    



    
    def get_ray_od(self, p_prof, wvlarr):
        """Return an array of shape (len(p_prof), len(wvlarr)) built from self.tau_ray.

        For each wavelength, the tabulated ``self.tau_ray['tau']`` is
        interpolated in log space over ``self.tau_ray['wvl']`` and scaled
        by ``p_prof / max(p_prof)``. A wavelength outside the tabulated
        range raises an error (bounds_error=True). If ``self.tau_ray`` is
        empty, a message is printed and exit() is called.
        """
        if len(self.tau_ray.keys())==0:
            print("No tau_ray dict was set when initializing the atmosphere class")
            exit()

        n_lev = len(p_prof)
        n_wvl = len(wvlarr)
        tau_od = np.zeros( (n_lev, n_wvl) )

        log_tau = interp1d(self.tau_ray['wvl'], np.log(self.tau_ray['tau']), bounds_error=True)

        p_max = np.max(p_prof)
        for iwvl, wvl in enumerate(wvlarr):
            tau_full = np.exp(log_tau(wvl))
            tau_od[:, iwvl] = tau_full * p_prof / p_max

        return tau_od

 
    
    
    def set_rh100_limits(self, alt_bounds):
        """Store the altitude intervals used by regrid to force the relative humidity.

        alt_bounds must be a list of lists of altitude boundaries; each
        inner list is stored, sorted in increasing order, as an array in
        ``self.rh100_bounds``. A malformed argument prints a usage message
        and calls exit(). When 'h2o' is not among the gases, a message is
        printed and ``self.rh100_bounds`` is left untouched.
        """
        usage = ("",
                 "(set_h2o_in_cloud) alt_bounds should be a list of altitude boundaries ",
                 "                   e.g. alt_bounds = [ [alt_min1, alt_max_1], [[alt_min2, alt_max_2]] ]",
                 "",
                 "")

        if not isinstance(alt_bounds, list):
            for line in usage:
                print(line)
            exit()

        if 'h2o' not in self.gas:
            print("")
            print("(set_h2o_in_cloud) No water vapour in the atmosphere ")
            print("")
            return

        self.rh100_bounds = []
        for bounds in alt_bounds:
            if not isinstance(bounds, list):
                for line in usage:
                    print(line)
                exit()
            sorted_bounds = np.sort(np.array(bounds))
            self.rh100_bounds.append(sorted_bounds)

        return




    


########################################################################################

class geometry(object):
    '''
    Container for the angles sza, vza and vaa (presumably solar zenith,
    view zenith and view azimuth angles -- confirm against callers).

    sza is stored as given. vza must be sorted in decreasing order and vaa
    in increasing order; otherwise a message is printed and exit() is called.
    lon_lat_day_time is initialised to 'none'.
    '''
    def __init__(self, sza, vza, vaa):

        self.sza = sza

        vza_decreasing = np.sort(vza)[::-1]
        if not np.array_equal(vza, vza_decreasing):
            print("")
            print("(geometry) : vza must be sorted in decreasing order")
            print("")
            exit()
        self.vza = vza

        vaa_increasing = np.sort(vaa)
        if not np.array_equal(vaa, vaa_increasing):
            print("")
            print("(geometry) : vaa must be sorted in increasing order")
            print("")
            exit()
        self.vaa = vaa

        self.lon_lat_day_time = 'none'


########################################################################################
        
class artdeco_in(object):
    '''
    Container for the ARTDECO run options, validated at construction.

    Any invalid value or combination is reported on stdout and exit() is
    called.

    Parameters
    ----------
    keywords : list
        Stored as ['none'] as soon as it contains 'none', as given otherwise.
    mode : str
        'mono', 'lbl' or 'kdis <model name>'. The first word is stored in
        ``mode``; in kdis mode the second word is stored in ``kdis_model``.
    filters : list
        Must contain 'none' (filter use is not implemented).
    wavel : sequence
        Wavelengths. In mono mode they must be sorted in increasing order;
        in kdis mode 1 or 2 values are expected. Stored sorted, except when
        the sequence contains -1 (it is then kept as given, a list being
        copied).
    trunc_method : str
        'none', 'dm', 'dfit' or 'potter'. ``dfit_thetac``/``dfit_fitall``
        are only checked and stored for 'dfit', ``potter_theta_min``/
        ``potter_theta_max`` only for 'potter'.
    nstreams : int
        Stored as an int; forced to 0 with the 'sinsca' solver.
    rt_model : str
        'doad', 'disort' or 'sinsca'. The ``doad_*`` options are stored for
        'doad' (together with ``doad_nmug = int(nstreams/2)``), the ``od_*``
        options for 'disort' (``od_nstr`` defaults to nstreams when -1).
    corint : bool
        Single scattering correction (TMS); forced to False with 'sinsca'.
    thermal : bool
        Thermal emission; only allowed with 'disort'.
    nmat : int
        A value larger than 1 (polarization) requires 'doad' or 'sinsca'.
    corint_brdf
        Stored as given, or set to the value of ``corint`` when -1.
    '''
    def __init__(self,
                keywords=["none"],
                mode="kdis toto",
                filters=["none"],
                wavel=[-1,-1],
                trunc_method="dm",
                nstreams=8,
                rt_model="doad",
                corint=True,
                thermal=False,
                nmat=1,
                od_accur = 0.0,
                doad_eps            = 1e-5,
                doad_nfoumx         = 1000,
                doad_ncoef_min_brdf = 10,
                dfit_thetac = 0.0,
                dfit_fitall = False,
                potter_theta_min = 10.,
                potter_theta_max = 11.,
                corint_brdf = -1,
                od_nstr = -1,
                od_deltam = False,
                od_corint = False,
                od_do_secsca = True,
                od_use_fisot = False):

        # keywords
        if 'none' in keywords:
            self.keywords = ['none']
        else:
            self.keywords = keywords

        # mode
        mode_fields = mode.split()
        if (len(mode_fields) == 0) or (mode_fields[0] not in ['mono', 'kdis', 'lbl']):
            print("(artdeco_in) ERROR")
            print("             mode must be -mono-, -lbl- or -kdis-")
            exit()
        self.mode = mode_fields[0]
        if self.mode == 'kdis':
            # the kdis model name is the second word of mode
            if len(mode_fields) < 2:
                print("(artdeco_in) ERROR")
                print("             In kdis mode, mode must be given as -kdis <model name>-")
                exit()
            self.kdis_model = mode_fields[1]

        # filters
        if 'none' not in filters:
            print("(artdeco_in) ERROR")
            print("             filter use not implemented")
            exit()
        self.filters = 'none'

        # wavelengths
        if (self.mode == 'kdis') and (len(wavel) not in [1,2]):
            print("(artdeco_in) ERROR")
            print("             In kdis mode, you must provide 1 or 2 wavelenghts")
            exit()
        if self.mode == 'mono':
            if not np.array_equal(wavel, np.sort(wavel)):
                print("(artdeco_in) ERROR")
                print("             In mono mode, wavelengths must be sorted in increasing order")
                exit()
            self.wavel = np.sort(wavel)
        elif -1.0 not in wavel:
            self.wavel = np.sort(wavel)
        elif isinstance(wavel, list):
            # copy, so that the instance never shares the (mutable) default
            # argument, nor the caller's list
            self.wavel = list(wavel)
        else:
            self.wavel = wavel

        # truncation
        if trunc_method not in ['none','dm','dfit','potter']:
            print("(artdeco_in) ERROR")
            print("             Trunc_method must be -none-, -dfit-, -potter- or -dm-")
            exit()
        self.trunc_method = trunc_method
        if trunc_method == 'dfit':
            if dfit_thetac > 180.0 or dfit_thetac < 0.0:
                print("(artdeco_in) ERROR")
                print("             Angular cut for dfit must be >0 and <180")
                exit()
            self.dfit_thetac = dfit_thetac
            self.dfit_fitall = dfit_fitall
        if trunc_method == 'potter':
            if (potter_theta_min  > 180.0) or (potter_theta_min  < 0.0) or (potter_theta_max  > 180.0) or (potter_theta_max  < 0.0) :
                print("(artdeco_in) ERROR")
                print("             theta_min and theta_max for Potter truncation must be >0 and <180")
                exit()
            if potter_theta_min  >= potter_theta_max:
                print("(artdeco_in) ERROR")
                print("             theta_min must be smaller than theta_max for Potter truncation")
                exit()
            self.potter_theta_min = potter_theta_min
            self.potter_theta_max = potter_theta_max

        # nstreams
        self.nstreams = int(nstreams)

        # rt model
        if rt_model == 'doad':
            self.doad_ncoef_min_brdf = int(doad_ncoef_min_brdf)
            self.doad_eps            = doad_eps
            self.doad_nfoumx         = int(doad_nfoumx)
            self.doad_nmug           = int(self.nstreams/2)
        elif rt_model == 'disort':
            if od_nstr==-1:
                self.od_nstr  = self.nstreams
            else:
                self.od_nstr  = od_nstr

            self.od_deltam = od_deltam
            self.od_corint = od_corint
            self.od_do_secsca = od_do_secsca
            self.od_accur = od_accur
            self.od_use_fisot = od_use_fisot
        elif rt_model == "sinsca":
            # No single scattering correction in single scattering RTE
            if corint:
                corint = False
            if thermal:
                print("(artdeco_in) ERROR")
                print("             Thermal emission is not accounted for ")
                print("             with the -sinsca- RTE solver")
                exit()
            self.nstreams = 0
        else:
            print("(artdeco_in) ERROR")
            print("             rt_model not implemented for f2py:", rt_model)
            exit()
        self.rt_model = rt_model

        # TMS
        self.corint = corint

        # Thermal
        if (self.rt_model != 'disort') and (thermal):
            print("(artdeco_in) ERROR")
            print("             To account for thermal IR, rt model must be -disort-")
            exit()
        self.thermal = thermal

        # Polar ?
        if (nmat>1) and ( self.rt_model not in ['doad', 'sinsca']):
            print("(artdeco_in) ERROR")
            print("             To account for polarization, rt model must be -doad- or -sinsca-")
            exit()
        self.nmat = nmat

        # by default the BRDF correction follows corint
        if corint_brdf == -1:
            self.corint_brdf = self.corint
        else:
            self.corint_brdf = corint_brdf


########################################################################################
########################################################################################
########################################################################################
########################################################################################

def parseGridSpec(gridSpec):
    """ Set up (altitude) grid specified in format 'start[step1]stop1[step2]stop' or similar.

    Every 'start[step]stop' segment is turned into evenly spaced levels
    running from start towards stop, with
    round(abs((stop - start) / step)) intervals. The stop value of a segment
    is the start value of the next one; only the last segment includes its
    stop value. The sign of the step is ignored.

    Returns the levels as a 1-D array (empty if gridSpec holds no bracket).
    Prints a message and raises SystemExit when the brackets are not
    balanced, when a field is not a number or when a step is zero.
    """

    def _number(text, what):
        # convert one field of the specification, what names it in the message
        try:
            return float(text)
        except ValueError:
            print('cannot parse grid %s specification\nstring not a number!' % what)
            raise SystemExit

    # get indices of left and right brackets
    lp = [i for i, char in enumerate(gridSpec) if char == '[']
    rp = [i for i, char in enumerate(gridSpec) if char == ']']
    if len(lp) != len(rp):
        print('cannot parse grid specification\nnumber of opening and closing braces differs!\nUse format start[step]stop')
        raise SystemExit

    nseg = len(lp)

    # parse
    segments = []
    for i in range(nseg):
        # the start field begins after the previous closing bracket,
        # the stop field ends at the next opening bracket
        first = rp[i-1]+1 if i > 0 else 0
        last  = lp[i+1] if i < nseg-1 else len(gridSpec)

        gridStart = _number(gridSpec[first:lp[i]], 'start')
        gridStep  = _number(gridSpec[lp[i]+1:rp[i]], 'step')
        if gridStep == 0.0:
            # would otherwise end in an uncaught ZeroDivisionError below
            print('cannot parse grid step specification\nstep must not be zero!')
            raise SystemExit
        gridStop  = _number(gridSpec[rp[i]+1:last], 'stop')
        segments.append((gridStart, gridStep, gridStop))

    # create the new grid (piecewise linspace)
    newGrid = []
    for i, (gridStart, gridStep, gridStop) in enumerate(segments):
        n = int(round(abs((gridStop - gridStart)/gridStep)))
        endpoint = (i == nseg-1)
        if endpoint: n += 1
        newGrid.extend(list(np.linspace(gridStart, gridStop, n, endpoint=endpoint)))

    return np.array(newGrid)


# This file is meant to be imported; when run as a script it only prints a notice.
if __name__=='__main__':
    print("pyartdeco_runlib is a library")

