
import os
import sys
import numpy as np
import time
import matplotlib
# Force the non-interactive Agg backend so no X Window display is needed
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import h5py

########################################

sys.path.append("../f2py_utils")
import pyartdeco_runlib as pyartdeco
from f2py_utils import run_artdeco

dir_artdeco = "/home/mathieu/work/RTM/artdeco_test_th/fortran/"
dir_kdis = "/rfs/proj/pykdis/output/"

########################################

def approx_equal(a, b, tol):
    """Tell whether two values agree to within a tolerance.

    The comparison is strict: a difference exactly equal to `tol` is
    *not* considered equal.
    """
    gap = abs(a - b)
    return gap < tol

#########################

def plot_polar_contour(values, azimuths, zeniths, barname='Pixel reflectance', noneg=False, cut=False):
    """Plot a filled polar contour map of `values` over (azimuth, zenith).

    The polar axes put 0 degrees of azimuth on the East (right-hand side)
    and increase azimuth clockwise; the radial coordinate is the zenith.

    Arguments:

     * `values` -- A list (or other iterable, e.g. a NumPy array) of the
       values to plot (the `z` values). It must hold
       ``len(azimuths) * len(zeniths)`` elements ordered azimuth-major:
       all zeniths for the first azimuth, then all zeniths for the second
       azimuth, etc. (an array already shaped
       ``(len(azimuths), len(zeniths))`` also works).
     * `azimuths` -- A list of azimuths (in degrees).
     * `zeniths` -- A list of zeniths (that is, radii).
     * `barname` -- Label of the colour bar.
     * `noneg` -- Only used when `cut` is True: raise a negative lower
       colour limit to 0.
     * `cut` -- If True, limit the colour range to the 0.5th-99.5th
       percentile of the non-NaN values; values outside are drawn with the
       colour-bar extension colours. If that range is empty or degenerate
       (e.g. a constant field), automatic levels are used instead.

    This layout matches data produced by a loop such as::

        values = []
        for azimuth in azimuths:
            for zenith in zeniths:
                # Do something and get a result
                values.append(result)

    Returns:
        (fig, ax, cax): the new figure, its polar axes and the contour set.
    """
    zeniths = np.array(zeniths)
    val = np.array(values).reshape(len(azimuths), len(zeniths))

    cut_min = 0.5     # lower percentile of the clipped colour range
    cut_max = 99.5    # upper percentile of the clipped colour range
    ncontour = 50

    # Default: let contourf pick `ncontour` levels spanning the data.
    levels = ncontour
    extend = "neither"
    if cut:
        valid = val[~np.isnan(val)]
        if valid.size > 0:
            minval = np.percentile(valid, cut_min)
            maxval = np.percentile(valid, cut_max)
            if noneg and minval < 0:
                minval = 0.0
            print(minval)
            print(maxval)
            # contourf requires strictly increasing levels; a constant
            # field (or noneg clamping above a negative max) would crash.
            if maxval > minval:
                levels = np.linspace(minval, maxval, num=ncontour, endpoint=True)
                extend = "both"

    r, theta = np.meshgrid(zeniths, np.radians(azimuths))
    fig, ax = plt.subplots(subplot_kw=dict(projection='polar'))
    ax.set_theta_zero_location("E")
    ax.set_theta_direction(-1)   # azimuth increases clockwise
    cax = ax.contourf(theta, r, val, levels, extend=extend)
    cb = fig.colorbar(cax)
    cb.set_label(barname)

    return fig, ax, cax

##############################################################



     


def polar_rad(nmat, vza, vaa, rad, root_name, dir_save, cut=False, delta=False):
    """Save polar contour maps of the Stokes components of a radiance field.

    One PNG per component is written to
    ``dir_save + 'radiance_<X>_' + root_name + '.png'``: X = I always;
    Q, U and the polarized radiance sqrt(Q**2 + U**2) (X = pol) when
    ``nmat > 1``; V when ``nmat == 4``.

    Arguments:

     * `nmat` -- number of Stokes components stored in `rad`.
     * `vza`, `vaa` -- view zenith / view azimuth angles (degrees).
     * `rad` -- radiances indexed ``[ivza, ivaa]`` when ``nmat == 1`` or
       ``[ivza, ivaa, istokes]`` when ``nmat > 1``.
     * `root_name` -- suffix of the output file names.
     * `dir_save` -- output prefix, concatenated as-is (so it should end
       with a path separator).
     * `cut` -- forwarded to plot_polar_contour (percentile-clipped colours).
     * `delta` -- the field is a difference: prefix colour-bar labels with
       a delta symbol and do not clip negative I values to 0.
    """
    # Set the default colormap via rcParams: plt.set_cmap would also create
    # an empty figure when none is open, and that figure would never close.
    plt.rc('image', cmap='rainbow')

    label_prefix = r"$\delta$" if delta else ""
    # Absolute intensities are non-negative, differences need not be.
    noneg = cut and not delta

    def save(component, field, label, **options):
        # Plot one component, save it, then close the figure: pyplot keeps
        # every figure alive until it is closed explicitly.
        fig, ax, cax = plot_polar_contour(field, vaa, vza, barname=label_prefix + label, **options)
        fig.savefig(dir_save + 'radiance_' + component + '_' + root_name + '.png')
        plt.close(fig)

    # plot_polar_contour expects azimuth-major data, hence the transposes.
    intensity = rad[:, :, 0] if nmat > 1 else rad[:, :]
    save('I', np.transpose(intensity), 'I', cut=cut, noneg=noneg)

    if nmat > 1:
        q = np.transpose(rad[:, :, 1])
        u = np.transpose(rad[:, :, 2])
        save('Q', q, 'Q', cut=cut)
        save('U', u, 'U', cut=cut)
        save('pol', np.sqrt(u ** 2.0 + q ** 2.0), '$I_{pol}$', cut=cut)

    if nmat == 4:
        save('V', np.transpose(rad[:, :, 3]), 'V', cut=cut)
          













def _make_surface(surface_kind, ocean, land):
    """Build the pyartdeco surface object selected by run_rad.

    `surface_kind` is "ocean", "brdf" (land BRDF) or "lambert" (Lambertian
    land), as resolved by run_rad from its `ocean` / `land` arguments.
    """
    if surface_kind == "ocean":
        return pyartdeco.surface("ocean", "brdf", "ocean",
                                 wdspd=ocean['wdspd'], temp=ocean['temp'],
                                 sdistr=ocean["sdistr"], azw=ocean["azw"],
                                 shadow=ocean["shadow"], _6s_glitter=ocean["_6s_glitter"])
    if surface_kind == "brdf":
        if "nr" in land:
            # A refractive index (nr, ni) is given: enable the bpdf term.
            return pyartdeco.surface("land", "brdf", "land",
                                     bpdf=True,
                                     interp=True,
                                     wl=land['wvl'], iso=land['iso'], vol=land['vol'], geo=land['geo'],
                                     cv=land['cv'], nr=land['nr'], ni=land['ni'], temp=land['temp'])
        return pyartdeco.surface("land", "brdf", "land",
                                 bpdf=False,
                                 interp=True,
                                 wl=land['wvl'], iso=land['iso'], vol=land['vol'], geo=land['geo'],
                                 temp=land["temp"])
    return pyartdeco.surface("land", "lambert", "land",
                             interp=True,
                             wl=land['wvl'], alb=land['alb'], temp=land["temp"])


def _exit_if_multi_wavelength(rad):
    """Abort with exit status 1 if an ARTDECO result holds several wavelengths.

    run_rad keeps only index 0 of the last axis of `rad`, so additional
    wavelengths would otherwise be silently dropped.
    """
    if rad.shape[4] > 1:
        print(" WVL Pb")
        sys.exit(1)


def run_rad(rt_model, thermal, kdis_dir, kdis_model, ich, nmat, sza, nstr_in, nfou_in, ptcle_def, taur,
            ocean=None, land=None, corint_in=False, corint_brdf_in=False, split_za=True):
    """Run ARTDECO for one k-distribution channel and return TOA radiances.

    Arguments:

     * `rt_model`, `thermal`, `nmat` -- forwarded to pyartdeco.artdeco_in
       (`nmat` is the number of Stokes components).
     * `kdis_dir` -- directory of the k-distribution files; concatenated
       as-is with "kdis_<kdis_model>.h5", so it must end with a separator.
     * `kdis_model` -- k-distribution model name; `ich` -- channel index.
     * `sza` -- iterable of solar zenith angles (degrees).
     * `nstr_in` -- initial number of streams; `nfou_in` -- value passed as
       `doad_ncoef_min_brdf`.
     * `ptcle_def` -- particle definitions for pyartdeco.
     * `taur` -- Rayleigh optical depth, used at both ends of the
       0.1-200 micron range (i.e. wavelength independent).
     * `ocean` -- dict with keys wdspd, temp, sdistr, azw, shadow,
       _6s_glitter. Takes precedence over `land` when non-empty.
     * `land` -- dict with wvl, temp plus either alb (Lambertian; optional
       family="lambert") or family="brdf" with iso, vol, geo (and
       optionally cv, nr, ni).
     * `corint_in`, `corint_brdf_in` -- correction flags; corint_brdf
       requires corint.
     * `split_za` -- if True, run one ARTDECO call per (sza, vza) pair.

    Returns:
        (sza, vza, vaa, rad_out), where rad_out has shape
        (len(sza), len(vza), len(vaa), artdeco_in.nmat).

    Raises:
        ValueError: no surface is described, or the land family is unknown.
        SystemExit (status 1): corint_brdf without corint, or ARTDECO
        returned more than one wavelength.
    """
    # None defaults avoid the shared-mutable-default pitfall; {} still works.
    ocean = {} if ocean is None else ocean
    land = {} if land is None else land

    if corint_brdf_in and not corint_in:
        print(" If corint_brdf==True then you should have corint=True")
        sys.exit(1)

    # Resolve and validate the surface model before any expensive loading.
    if ocean:
        surface_kind = "ocean"
    elif land:
        surface_kind = land.get("family", "lambert")
        if surface_kind not in ("brdf", "lambert"):
            raise ValueError("run_rad: unsupported land family %r (expected 'brdf' or 'lambert')"
                             % (surface_kind,))
    else:
        raise ValueError("run_rad: either `ocean` or `land` must describe the surface")

    ###########################################
    # artdeco_in technical parameters structure
    mode = 'kdis ' + kdis_model
    keywords = ['nowarn', 'od_no_check']
    wavel = [-1, -1]
    filters = ['none']
    trunc_method = 'dm'
    nstreams_init = nstr_in   # for the entire hemisphere

    artdeco_in = pyartdeco.artdeco_in(keywords, mode, filters, wavel,
                                      trunc_method, nstreams_init, rt_model,
                                      corint_in, thermal, nmat, doad_ncoef_min_brdf=nfou_in,
                                      corint_brdf=corint_brdf_in)

    ####################
    # read kdis coeff
    kdis = pyartdeco.kdis_coeff(artdeco_in, kdis_dir, 'h5', channel_list=np.array([ich]))

    ######################
    # load solar TOA flux
    solrad_path = kdis_dir + "kdis_" + kdis_model + ".h5"
    solrad = pyartdeco.solar_irradiance(artdeco_in, solrad_path, channel_list=np.array([ich]),
                                        file_format="h5_kdis")

    ####################
    #  load atmosphere
    gas = ["o2"]
    # NOTE(review): 2.095e+01 looks like a percentage, while an older
    # configuration used 2.095e+05 ppmv for O2 -- confirm the unit expected.
    ppmv = [2.095e+01]
    atm = 'atm_afgl_midsum_red.dat'
    dir_data = "/rfs/proj/artdeco_lib/atm/"
    wldepol = np.array([0.1, 200.0])  # microns
    depol = np.array([0.0279, 0.0279])

    atmos = pyartdeco.atmosphere(artdeco_in, dir_data, atm, gas, ppmv, 'ascii', wldepol, depol,
                                 interp=True, kdis=kdis,
                                 tau_ray={"wvl": np.array([0.1, 200.]), "tau": np.array([taur, taur])})

    ###############################
    #   set ptcle structure
    wlref = 0.55
    wlptcle = [np.min(kdis.wvlband), np.max(kdis.wvlband)]
    ptcle_opt = pyartdeco.ptcle_optical_properties(ptcle_def, artdeco_in.nstreams, wlptcle, wlref,
                                                   opt_interp=False)
    ptcle = pyartdeco.particle(artdeco_in, ptcle_def, ptcle_opt)

    ########################
    # set surface structure
    surface = _make_surface(surface_kind, ocean, land)

    ########################
    #      Geometry
    vza = np.linspace(0, 89.99, num=52)[::-1]   # decreasing view zenith
    vaa = np.linspace(0, 360, num=360)

    ###############
    # run ARTDECO
    rad_out = np.full((len(sza), len(vza), len(vaa), artdeco_in.nmat), np.nan)

    if split_za:
        t0 = time.time()
        for isza, sza_val in enumerate(sza):
            for ivza, vza_val in enumerate(vza):
                geom = pyartdeco.geometry(np.array([sza_val]), np.array([vza_val]), vaa)
                _lamb, rad, _rad_levels, _flux, _alt = run_artdeco(artdeco_in, atmos, surface, solrad,
                                                                   ptcle, geom, kdis=kdis, verbose=False)
                _exit_if_multi_wavelength(rad)
                rad_out[isza, ivza, :, :] = rad[0, 0, :, :, 0]
        t1 = time.time() - t0
    else:
        geom = pyartdeco.geometry(sza, vza, vaa)
        t0 = time.time()
        _lamb, rad, _rad_levels, _flux, _alt = run_artdeco(artdeco_in, atmos, surface, solrad,
                                                           ptcle, geom, kdis=kdis, verbose=False)
        _exit_if_multi_wavelength(rad)
        rad_out[:, :, :, :] = rad[:, :, :, :, 0]
        t1 = time.time() - t0

    print("tps = ", t1)

    # Radiances are returned as computed (no pi / cos(sza) normalisation).
    return sza, vza, vaa, rad_out



     















if __name__ == '__main__':

    # ---- radiative-transfer configuration ----
    nfou, nstr, nmat = 20, 8, 3
    sza = np.array([30.0])

    thermal = False
    rt_model = 'doad'
    split_za = False

    # ---- k-distribution / optical-property inputs ----
    kdis_model = 'parasol'
    kdis_dir = dir_kdis + "/" + kdis_model + "/"
    print(dir_kdis + "/" + kdis_model + "/kdis_" + kdis_model + ".h5")
    ich = 1
    # NOTE(review): opt_path is computed but not used below.
    if kdis_model != "none":
        opt_path = "/rfs/proj/artdeco_lib/opt/opt_opac_" + kdis_model + ".h5"
    else:
        opt_path = "/rfs/proj/artdeco_lib/opt/opt_opac.h5"

    ptcle_def = []    # no particles
    taur = 1.0e-20    # Rayleigh optical depth passed to run_rad

    # ---- surface / correction settings ----
    shadow = False
    wdspd = 3.0
    corint = True
    corint_brdf = True

    base_ocean = {"wdspd": wdspd, "temp": 285.0, "sdistr": 3, "azw": 45.0,
                  "shadow": shadow, "_6s_glitter": True}

    # Each entry: ocean for run 1, ocean for run 2, and the output-name
    # templates for run 1, run 2 and their difference (run 1 - run 2).
    comparisons = [
        (dict(base_ocean, _6s_glitter=True), dict(base_ocean, _6s_glitter=False),
         ("6S_sza%.2f", "Mishchenko_sza%.2f", "diff_6S-Mishchenko_sza%.2f")),
        (dict(base_ocean, sdistr=2), dict(base_ocean, sdistr=1),
         ("sdistr2_sza%.2f", "sdistr1_sza%.2f", "diff_sdistr2-sdistr1_sza%.2f")),
    ]

    for ocean_1, ocean_2, (name_1, name_2, name_diff) in comparisons:
        sza, vza, vaa, rad1 = run_rad(rt_model, thermal, kdis_dir, kdis_model, ich, nmat, sza, nstr, nfou,
                                      ptcle_def, taur,
                                      corint_in=corint, corint_brdf_in=corint_brdf,
                                      ocean=ocean_1, land={}, split_za=split_za)
        sza, vza, vaa, rad2 = run_rad(rt_model, thermal, kdis_dir, kdis_model, ich, nmat, sza, nstr, nfou,
                                      ptcle_def, taur,
                                      corint_in=corint, corint_brdf_in=corint_brdf,
                                      ocean=ocean_2, land={}, split_za=split_za)

        for isza, angle in enumerate(sza):
            first = np.squeeze(rad1[isza, :, :, :])
            second = np.squeeze(rad2[isza, :, :, :])
            difference = np.squeeze(rad1[isza, :, :, :] - rad2[isza, :, :, :])
            polar_rad(nmat, vza, vaa, first, name_1 % angle, "./", cut=True)
            polar_rad(nmat, vza, vaa, second, name_2 % angle, "./", cut=True)
            polar_rad(nmat, vza, vaa, difference, name_diff % angle, "./", cut=True)

