#!/usr/bin/env python
from __future__ import print_function, division, absolute_import
from builtins import range

from netCDF4 import Dataset
from numpy import ma, NaN
from os import remove
from os.path import exists

def read_nc(filename, variable):
    """Read one variable from a netCDF file and return it as a plain array.

    If the variable has a 'scaling equation' attribute equal to one of the
    two recognised forms, the stored values are converted using the
    variable's 'slope' and 'intercept' attributes:

    * 'value = slope*data + intercept'
    * 'value = 10^(slope*data + intercept)'

    Any other variable is returned as read.

    Args:
        filename: path of the netCDF file, opened read-only.
        variable: name of the variable to read.

    Returns:
        The variable's data. If the read yields a masked array, the
        underlying data array is returned instead; for float32/float64
        data the masked entries are set to NaN, for other dtypes they keep
        whatever value sits underneath the mask. Returns None (after
        printing the available variables) when `variable` is not present
        in the file.
    """
    linear_eq = 'value = slope*data + intercept'
    power10_eq = 'value = 10^(slope*data + intercept)'

    # The context manager closes the file on every exit path, including the
    # early return for an unknown variable and any exception while reading.
    with Dataset(filename, 'r', format='NETCDF4') as ncId:
        if variable not in ncId.variables:
            print('Can not find variable "%s" in "%s"' % (variable, filename))
            print('Available variables:')
            for name, candidate in ncId.variables.items():
                print(' ->', name, '-', candidate.shape, candidate.dtype)
            return None

        var = ncId.variables[variable]

        # The attribute name contains a space, so it cannot be reached with
        # dotted access; ncattrs()/getncattr() are the explicit accessors.
        equation = None
        if 'scaling equation' in var.ncattrs():
            equation = var.getncattr('scaling equation')

        if equation == linear_eq:
            slope = var.getncattr('slope')
            intercept = var.getncattr('intercept')
            dat = slope * var[:] + intercept
        elif equation == power10_eq:
            slope = var.getncattr('slope')
            intercept = var.getncattr('intercept')
            dat = 10**(slope * var[:] + intercept)
        else:
            dat = var[:]

    if not isinstance(dat, ma.masked_array):
        return dat

    arr = ma.getdata(dat)
    if dat.dtype in ['float32', 'float64']:
        # Only floating-point arrays can represent NaN. float('nan') is used
        # rather than numpy's `NaN` alias, which NumPy 2.0 removed.
        arr[dat.mask] = float('nan')
    return arr



def nc_info(filename):
    """Print the name, shape and dtype of every variable in a netCDF file.

    Args:
        filename: path of the netCDF file, opened read-only.
    """
    # `with` guarantees the file is closed even if listing a variable raises.
    with Dataset(filename, 'r', format='NETCDF4') as ncId:
        for name, var in ncId.variables.items():
            print(name, var.shape, var.dtype)



class NCCreator(object):
    """Writer that creates a new NETCDF4 file and stores arrays in it.

    Typical use::

        nc = NCCreator('out.nc', overwrite=True)
        nc.write('temperature', data, ('y', 'x'))
        nc.close()

    The object can also be used as a context manager, in which case the
    file is closed when the `with` block exits.
    """

    # Class-level default so that close() and __del__ are safe on an instance
    # whose __init__ raised before the dataset was opened (e.g. the
    # "File exists" error below).
    __rootgrp = None

    def __init__(self, filename, compress=True, overwrite=False):
        """Create `filename` for writing.

        Args:
            filename: path of the netCDF file to create.
            compress: passed as `zlib` to every variable created by write().
            overwrite: if True, an existing file is removed first; if False,
                an existing file causes an Exception.

        Raises:
            Exception: if `filename` exists and `overwrite` is False.
        """
        if exists(filename):
            if not overwrite:
                raise Exception('File {} exists'.format(filename))
            print('removing', filename)
            remove(filename)
        self.__rootgrp = Dataset(filename, 'w', format='NETCDF4')
        self.__compress = compress

    def __enter__(self):
        """Return self so the creator can be used in a `with` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the file on leaving the `with` block; never swallows errors."""
        self.close()
        return False

    def write(self, name, data, dimensions):
        """Create variable `name` and fill it with `data`.

        Args:
            name: name of the variable to create.
            data: array-like exposing `ndim`, `shape` and `dtype`.
            dimensions: sequence of dimension names, one per axis of `data`.
                A dimension not yet defined in the file is created with the
                size of the matching axis; an already defined dimension is
                reused as is.
        """
        assert data.ndim == len(dimensions)

        # Create the dimensions that do not exist yet.
        for dim, size in zip(dimensions, data.shape):
            if dim not in self.__rootgrp.dimensions:
                print('creating dimension {} of size {}'.format(dim, size))
                self.__rootgrp.createDimension(dim, size)

        var = self.__rootgrp.createVariable(name, data.dtype, dimensions, zlib=self.__compress)
        var[:] = data[:]

    def __del__(self):
        self.close()

    def close(self):
        """Close the underlying file; calling it more than once is harmless."""
        if self.__rootgrp is not None:
            self.__rootgrp.close()
            self.__rootgrp = None

