"""
/***************************************************************************
Name	   :  geoclim_gdal_utilities.py
Description:  GeoCLIM gdal type Utility Functions for FEWSTools plugin,
              split out from other utility files
copyright  :  (C) 2022 - 2023 by FEWS
email      :  minxuansun@contractor.usgs.gov
Created    :  01/05/2022 - cholen
Modified   :  03/16/2022 - cholen - Update get_data_cube, handle np invalids
              03/21/2022 - cholen - Add tiff support

 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""
import os
import subprocess
import numpy as np

try:
    from osgeo import gdal
    from osgeo import gdal_array

except ImportError:
    import gdal
    import gdal_array


from qgis.core import QgsMessageLog, Qgis

from fews_tools import fews_tools_config as config


# Have the GDAL Python bindings raise exceptions on errors.
gdal.UseExceptions()


# Raster band index used by this module when reading band metadata and when
# writing an array to a new file.
BAND_NUM = 1
# Data type names; presumably passed where GDAL expects a type name string
# - TODO confirm against callers.
GDAL_INT16 = "Int16"
GDAL_BYTE = "Byte"

# Lookup from a data type name to the matching GDAL type constant ("GDAL")
# and numpy dtype ("NP").
# NOTE(review): the "Uint16"/"Uint32" keys are capitalized differently from
# the "UInt16" default used by the rasterize functions in this module; a
# lookup with that spelling would raise KeyError - confirm which spelling
# callers use.
TYPE_DIC = {
    "Byte": {"GDAL": gdal.GDT_Byte, "NP":np.uint8},
    "Int16": {"GDAL": gdal.GDT_Int16, "NP":np.int16},
    "Uint16": {"GDAL": gdal.GDT_UInt16, "NP":np.uint16},
    "Int32": {"GDAL": gdal.GDT_Int32, "NP":np.int32},
    "Uint32": {"GDAL": gdal.GDT_UInt32, "NP":np.uint32},
    "Float32": {"GDAL": gdal.GDT_Float32, "NP":np.float32},
    "Float64": {"GDAL": gdal.GDT_Float64, "NP":np.float64}}


def calc_masked_array_diff(array_1, array_2, mask_array, nd_val=-9999):
    """
    Calculate the difference(array_1 - array_2) of two numpy arrays and
    apply the mask.

    Arguments:
    array_1 -- 2d numpy geotiff array.
    array_2 -- 2d numpy geotiff array.
    mask_array -- 2d numpy mask array, truthy cells are replaced by nd_val.
    nd_val -- Integer no data value.
    Returns:
    masked_diff_array -- 2d numpy masked difference array.
    Raises:
    IOError -- If the three arrays do not all share the same shape.
    """
    # All three shapes must agree.  A chained "a != b != c" test only fires
    # when BOTH neighbouring pairs differ, which lets a single mismatched
    # array through to be silently broadcast.
    if not mask_array.shape == array_1.shape == array_2.shape:
        raise IOError("Mask and/or array shape doesn't match")
    diff_array = array_1 - array_2
    masked_diff_array = np.ma.masked_array(
        diff_array, mask=mask_array, fill_value=nd_val).filled()
    return masked_diff_array


def calc_masked_data_cube_sum(data_cube, mask_array, nd_val=-9999):
    """
    Sum a data cube over axis 0 (date/time) and apply a mask to the result.
    The data cube is a 3d stack of geotiff arrays: axis 0 is date/time,
    axis 1 is latitude(rows) and axis 2 is longitude(cols).

    Arguments:
    data_cube -- 3d numpy array of geotiffs.
    mask_array -- 2d numpy mask array, its shape must equal
                  data_cube.shape[1:]; truthy cells are replaced by nd_val.
    nd_val -- Integer no data value.
    Returns:
    2d numpy array holding the per pixel sums with masked cells set to nd_val.
    Raises:
    IOError -- If the mask shape does not match a single layer of the cube.
    """
    layer_shape = data_cube.shape[1:]
    if mask_array.shape != layer_shape:
        raise IOError("Summing - Mask shape doesn't match data cube")
    totals = data_cube.sum(axis=0)
    masked_totals = np.ma.masked_array(
        totals, mask=mask_array, fill_value=nd_val)
    return masked_totals.filled()


def calc_masked_data_cube_np_stat(
        data_cube, mask_array, stat="Average", nd_val=-9999, percentile=None):
    """
    Function to calculate the statistics of a group of files. Statistics
    can be generated for average, median, standard deviation, variance and
    percentile on a per pixel basis (reduction along axis 0 of the cube).

    Arguments:
    data_cube -- 3d numpy array of geotiffs.
    mask_array -- 2d numpy mask array(matches the shape of data_cube[1:]),
                  truthy cells are replaced by nd_val.
    stat -- Statistic to calculate, one of "Average", "Median",
            "Standard Deviation", "Variance" or "Percentiles".
    nd_val -- Integer no data value (written to the result as a float).
    percentile -- Percentile value, only used when stat is "Percentiles".
    Returns:
    masked_stat_array -- 2d numpy masked stat array, or None if the
                         statistic is not supported or the calculation
                         failed (the failure is logged).
    """
    # np.xxx  stats calls return float type array
    stat_functions = {
        "Average": lambda cube: np.average(cube, axis=0),
        "Median": lambda cube: np.median(cube, axis=0),
        "Standard Deviation": lambda cube: np.std(cube, axis=0),
        "Variance": lambda cube: np.var(cube, axis=0),
        "Percentiles": lambda cube: np.percentile(cube, percentile, axis=0)}
    masked_stat_array = None
    try:
        stat_function = stat_functions.get(stat)
        if stat_function is None:
            # Report an unknown statistic explicitly rather than failing on
            # an unbound local further down.
            raise ValueError("unsupported statistic '" + str(stat) + "'")
        stat_array = stat_function(data_cube)
        masked_stat_array = np.ma.masked_array(
            stat_array, mask=mask_array, fill_value=float(nd_val)).filled()
    except Exception as error:
        # Exception (not BaseException) so KeyboardInterrupt/SystemExit are
        # not swallowed; keep the best-effort "log and return None" contract.
        QgsMessageLog.logMessage(
            "Exception - Unspecified error in statistics calculation: "
            + str(error),
            level=Qgis.Critical)
    return masked_stat_array


def calc_masked_pct_of_avg(data_array, avg_array, mask_array, nd_val=-9999):
    """
    Calculate the numpy percent of average array, then
    apply the mask.  Formula is result = data * 100 / avg, rounded, with
    values above 10000 capped at 10000.

    Arguments:
    data_array -- 2d numpy geotiff array.
    avg_array -- 2d numpy geotiff array.
    mask_array -- 2d numpy mask array, truthy cells are replaced by nd_val.
    nd_val -- Integer no data value.
    Returns:
    masked_pct_avg_array -- 2d numpy masked percent of average array.
    Raises:
    IOError -- If the three arrays do not all share the same shape.
    """
    # All three shapes must agree.  A chained "a != b != c" test only fires
    # when BOTH neighbouring pairs differ, which lets a single mismatched
    # array through to be silently broadcast.
    if not mask_array.shape == data_array.shape == avg_array.shape:
        raise IOError("Mask and/or array shape doesn't match")
    # do the math as float and then round, handle divide by zero and invalid
    with np.errstate(invalid='ignore', divide='ignore'):
        pct_avg_array = ((data_array * 100.0) / avg_array).round()
        masked_pct_avg_array = np.ma.masked_array(
            pct_avg_array, mask=mask_array, fill_value=nd_val).filled()
        # cap values at 10000 (this also catches +inf from a zero average
        # in an unmasked cell)
        masked_pct_avg_array[masked_pct_avg_array > 10000] = 10000
    return masked_pct_avg_array


def clip_raster_to_region(reg_dic, src_filename, dst_filename, fudge=0):
    """
    Clip a raster to the region extents with a gdal_translate subprocess
    call.  No nodata value is assigned, the output pixels keep whatever the
    source values are.  The fudge amount widens the window on every side,
    it is used when clipping a mask so the mask comes out slightly larger
    than the region.

    Arguments:
    reg_dic -- Region parameters, the "MinimumLongitude", "MaximumLatitude",
               "MaximumLongitude" and "MinimumLatitude" entries are used.
    src_filename -- Input raster filename.
    dst_filename -- Output raster filename.
    fudge -- Amount to fudge.
    """
    # output format follows the source file type
    out_format = get_gdal_driver_name(src_filename)
    # window corners: upper left x, y then lower right x, y
    window = [
        reg_dic["MinimumLongitude"] - fudge,
        reg_dic["MaximumLatitude"] + fudge,
        reg_dic["MaximumLongitude"] + fudge,
        reg_dic["MinimumLatitude"] - fudge]
    window_txt = " ".join(str(edge) for edge in window)
    cmd = ("gdal_translate -projwin " + window_txt + "  -of " + out_format
           + " " + handle_subprocess_call_path(src_filename)
           + " " + handle_subprocess_call_path(dst_filename))
    subprocess.call(cmd, shell=True)


def clip_raster_to_bbox(src_filename, dst_filename, bbox, nd_val, fudge=0):
    """
    Clip a raster to a bounding box with a gdal_translate subprocess call,
    assigning a no data value to the output.  The fudge amount widens the
    window on every side, it is used when clipping a mask so the mask comes
    out slightly larger than the region.

    Arguments:
    src_filename -- Input raster filename.
    dst_filename -- Output raster filename.
    bbox -- Bounding box dictionary, the "MinLongitude", "MaxLatitude",
            "MaxLongitude" and "MinLatitude" entries are used.
    nd_val -- No data value(as string)
    fudge -- Amount to fudge.
    """
    # output format follows the source file type
    out_format = get_gdal_driver_name(src_filename)
    # window corners are separated by spaces, not commas
    corners = (
        bbox["MinLongitude"] - fudge,
        bbox["MaxLatitude"] + fudge,
        bbox["MaxLongitude"] + fudge,
        bbox["MinLatitude"] - fudge)
    window_txt = " ".join(str(corner) for corner in corners)
    cmd = "".join((
        "gdal_translate -a_nodata ", nd_val,
        " -projwin ", window_txt,
        " -of ", out_format, " ",
        handle_subprocess_call_path(src_filename), " ",
        handle_subprocess_call_path(dst_filename)))
    subprocess.call(cmd, shell=True)


def extract_raster_array(src_filename, d_type=np.int16):
    """
    Extract np array from file.  The last band of the file is read (files
    generally only have one band).

    Arguments:
    src_filename -- Raster data filename.
    d_type -- Numpy data type the band array is cast to.
    Returns:
    src_array - 2d numpy raster array, or None if the file is missing or
                could not be read (the failure is logged).
    """
    # Pre-set the result so a failure below returns None instead of raising
    # UnboundLocalError on the return statement.
    src_array = None
    try:
        if not os.path.exists(src_filename):
            raise IOError(src_filename + " does not exist!!")
        # get driver to work with src filename type
        drvr = get_gdal_driver_object(src_filename)
        drvr.Register()

        # open file and read the band data
        src_ds = gdal.Open(src_filename, gdal.GA_ReadOnly)
        try:
            # generally only one band, but just in case use the last one
            bands = src_ds.RasterCount
            src_rst_band = src_ds.GetRasterBand(bands)
            src_array = src_rst_band.ReadAsArray().astype(d_type)
        finally:
            # drop the references so GDAL can close the dataset
            src_rst_band = None
            src_ds = None
    except Exception as error:
        # Exception (not BaseException) so KeyboardInterrupt/SystemExit are
        # not swallowed.
        QgsMessageLog.logMessage(
            "Exception - Unspecified error when extracting raster array: "
            + str(error),
            level=Qgis.Critical)
    return src_array


def get_data_cube(filenames_list, d_type=np.int16):
    """
    Construct a data cube from a list of geotiffs for use in numpy
    operations.  It assumes that the listed geotiffs have been
    resampled/clipped so they all match and are the desired output cell size
    and extents.  It can run with a single element filenames list.

    Arguments:
    filenames_list -- List of the raster filenames.
    d_type -- Numpy data type.
    Returns:
    data_cube - 3d numpy raster array(axis 0 = time,
                                      axis 1 = latitude,
                                      axis 2 = longitude),
                or None if any file is missing or the cube could not be
                built (the failure is logged).  A partially built cube is
                never returned.
    """
    data_cube = None
    try:
        # get the input rasters into a list, then stack them once; this
        # avoids re-copying the growing cube for every file
        layers = []
        for entry in filenames_list:
            if not os.path.exists(entry):
                QgsMessageLog.logMessage("Missing file: " + str(entry),
                                         level=Qgis.Critical)
                return None
            src_rst_array = extract_raster_array(entry, d_type)
            if src_rst_array is None:
                raise ValueError("no raster data read from " + str(entry))
            layers.append(src_rst_array)
        # np.stack needs identical layer shapes and puts the new (time)
        # axis first
        data_cube = np.stack(layers, axis=0)
    except Exception as error:
        # Exception (not BaseException) so KeyboardInterrupt/SystemExit are
        # not swallowed.
        data_cube = None
        QgsMessageLog.logMessage(
            "Exception - Unspecified error getting data cube: " + str(error),
            level=Qgis.Critical)
    return data_cube


def get_data_cube_mask(src_data_cube, nd_val):
    """
    Function to get a numpy mask from a data_cube(3d array of geotiffs, where
    axis 0 is date/time, axis 1 is latitude(rows), and axis 2 is longitude(cols),
    For each row, column, if the nd_value is contained in any geotiff, the mask
    is set to 1, else it is 0. NOTE: this is opposite of the mask files for
    GeoCLIM!!

    Arguments:
    src_data_cube -- 3d numpy array of raster data.
    nd_val -- No data value
    Returns:
    mask_array -- 2d numpy integer mask array (1 = masked, 0 = not masked).
    Raises:
    IOError -- If src_data_cube is not 3 dimensional.
    """
    if len(src_data_cube.shape) != 3:
        raise IOError("Data cube dimensions are incorrect")
    # Vectorized "nd_val appears anywhere along axis 0" test; gives the same
    # 0/1 integers as calling a Python function once per pixel, without the
    # per pixel call overhead.
    mask_array = np.any(src_data_cube == nd_val, axis=0).astype(int)
    return mask_array


def get_data_cube_and_region_mask(
        src_data_cube, mask_filename, nd_val, invert_reg_mask=True):
    """
    Build one mask that is set wherever either the data cube mask or the
    region mask is set.

    Arguments:
    src_data_cube -- 3d numpy array of raster data, a single layer should
                     match the size of the region mask array.
    mask_filename -- Absolute region mask filename.
    nd_val -- No data value.
    invert_reg_mask -- Boolean to flag inversion, always true at this point.
    Returns:
    2d numpy boolean mask, same size as region mask.
    Raises:
    IOError -- If the data cube mask and region mask shapes differ.
    """
    cube_mask = get_data_cube_mask(src_data_cube, nd_val)
    # For np masking the fews_tools region mask currently has to be inverted.
    # If mask files ever get made with the values switched, additional work
    # is needed so users can specify in the region definition which value is
    # the masked value.
    if invert_reg_mask is True:
        load_region_mask = get_inverted_mask_array
    else:
        load_region_mask = extract_raster_array
    region_mask = load_region_mask(mask_filename)

    if cube_mask.shape != region_mask.shape:
        raise IOError("Mismatched region and data cube mask shapes!")
    return np.logical_or(cube_mask, region_mask)


def get_gdal_driver_name(src_filename):
    """
    Pick the gdal driver name from the filename extension. Used for gdal
    subprocess calls.  The extension is compared as-is (case sensitive)
    against the configured tiff suffixes.

    Arguments:
    src_filename -- Input raster filename.
    Returns:
    GDAL driver name, config.GTIFF for the tiff suffixes, otherwise
    config.EHDR (the original format).
    """
    suffix = os.path.splitext(src_filename)[1]
    if suffix in (config.TIFF_SUFFIX, config.TIF_SUFFIX):
        return config.GTIFF
    return config.EHDR


def get_gdal_driver_object(src_filename):
    """
    Look up the gdal driver object that matches the filename.

    Arguments:
    src_filename -- Input raster filename.
    Returns:
    GDAL driver object for the driver name chosen by get_gdal_driver_name.
    """
    return gdal.GetDriverByName(get_gdal_driver_name(src_filename))


def get_geotiff_info(src_filename):
    """
    Read the basic parameters of a raster file.

    Arguments:
    src_filename -- Input raster filename.
    Returns:
    tuple - nd_val, row_ct, col_ct, geoxfrm, data_type (data_type is the
            GDAL data type name of band BAND_NUM)
    Raises:
    IOError -- If the file does not exist.
    Exception -- If gdal.Open returns None for the file.
    """
    if not os.path.exists(src_filename):
        raise IOError("Missing file: " + src_filename)
    dataset = gdal.Open(src_filename, gdal.GA_ReadOnly)
    if dataset is None:
        raise Exception("Corrupt file: " + src_filename)
    band = dataset.GetRasterBand(BAND_NUM)
    type_name = gdal.GetDataTypeName(band.DataType)
    return (
        band.GetNoDataValue(),
        dataset.RasterYSize,
        dataset.RasterXSize,
        dataset.GetGeoTransform(),
        type_name)


def get_inverted_mask_array(mask_filename):
    """
    Read a mask file and return its logical inverse for use with numpy
    operations.  For np masking the fews_tools region mask currently has to
    be inverted.  If mask files ever get made with the values switched,
    additional work is needed so users can specify in the region definition
    which value is the masked value.

    Arguments:
    mask_filename -- Mask file path.
    Returns:
    2d np boolean mask array, True where the file value is zero.
    """
    return np.logical_not(extract_raster_array(mask_filename))


def handle_subprocess_call_path(file_path):
    """
    Wrap a path in double quotes for use in a subprocess command string,
    otherwise the command will fail if any spaces are involved(ie in
    Windows).

    Arguments:
    file_path -- File path string.
    Returns:
    The quoted file path
    """
    quote = '"'
    return "".join((quote, file_path, quote))


def mask_function(src_array, nd_val):
    """
    Flag whether the nodata value occurs in an array.

    Arguments:
    src_array -- 1d slice of a 3d numpy data cube(axis=0 in this case)
    nd_val -- Nodata value
    Returns:
    Integer 1 if nd_val is in src_array (masked), else 0
    """
    return 1 if nd_val in src_array else 0


def rasterize_vector(src_filename, dst_filename,
                     cell_size, vector_extent,
                     init_val, burn_val, fudge,
                     d_type="UInt16"):
    """
    Rasterize a vector with a fixed burn value, using a gdal_rasterize
    subprocess call.

    Arguments:
    src_filename -- Input vector filename, the layer name is taken from its
                    basename without the extension.
    dst_filename -- Output raster filename, also selects the output format.
    cell_size -- Output cell size, used for both x and y resolution.
    vector_extent -- Extent object providing xMinimum(), yMinimum(),
                     xMaximum() and yMaximum().
    init_val -- Value the output raster is initialized with.
    burn_val -- Value burned in for the vector features.
    fudge -- Amount the extent is widened on every side.
    d_type -- Output data type name.
    """
    o_fmt = get_gdal_driver_name(dst_filename)
    # Strip only the final extension so dotted names ("my.layer.shp") keep
    # their full layer name.
    lyr_name = os.path.splitext(os.path.basename(src_filename))[0]

    # no commas separate these!!
    extents_str = "{} {} {} {}".format(
        str(vector_extent.xMinimum() - fudge),
        str(vector_extent.yMinimum() - fudge),
        str(vector_extent.xMaximum() + fudge),
        str(vector_extent.yMaximum() + fudge))

    # Build the command in one pass so inserted values are never re-scanned
    # for placeholders; the layer name is quoted like the paths so names
    # containing spaces survive the shell.
    cmd = ("gdal_rasterize -l {lyr} -burn {burn} "
           "-tr {cell} {cell} -init {init} "
           "-te {extents} -ot {dtype} -of {fmt} "
           "{src} {dst}").format(
               lyr=handle_subprocess_call_path(lyr_name),
               burn=burn_val,
               cell=cell_size,
               init=init_val,
               extents=extents_str,
               dtype=d_type,
               fmt=o_fmt,
               src=handle_subprocess_call_path(src_filename),
               dst=handle_subprocess_call_path(dst_filename))

    subprocess.call(cmd, shell=True)


def rasterize_vector_by_attribute(src_filename, dst_filename,
                     cell_size, region_extent,
                     init_val, attribute, fudge,
                     d_type="UInt16"):
    """
    Rasterize a vector based on an attribute, using a gdal_rasterize
    subprocess call.

    Arguments:
    src_filename -- Input vector filename, the layer name is taken from its
                    basename without the extension.
    dst_filename -- Output raster filename, also selects the output format.
    cell_size -- Output cell size, used for both x and y resolution.
    region_extent -- Extent object providing xMinimum(), yMinimum(),
                     xMaximum() and yMaximum().
    init_val -- Value the output raster is initialized with.
    attribute -- Name of the attribute field whose values are burned in.
    fudge -- Amount the extent is widened on every side.
    d_type -- Output data type name.
    """
    o_fmt = get_gdal_driver_name(dst_filename)
    # Strip only the final extension so dotted names ("my.layer.shp") keep
    # their full layer name.
    lyr_name = os.path.splitext(os.path.basename(src_filename))[0]

    # no commas separate these!!
    extents_str = "{} {} {} {}".format(
        str(region_extent.xMinimum() - fudge),
        str(region_extent.yMinimum() - fudge),
        str(region_extent.xMaximum() + fudge),
        str(region_extent.yMaximum() + fudge))

    # Build the command in one pass so inserted values are never re-scanned
    # for placeholders; the layer name is quoted like the paths so names
    # containing spaces survive the shell.
    cmd = ("gdal_rasterize -l {lyr} -a {attr} "
           "-tr {cell} {cell} -init {init} "
           "-te {extents} -ot {dtype} -of {fmt} "
           "{src} {dst}").format(
               lyr=handle_subprocess_call_path(lyr_name),
               attr=attribute,
               cell=cell_size,
               init=init_val,
               extents=extents_str,
               dtype=d_type,
               fmt=o_fmt,
               src=handle_subprocess_call_path(src_filename),
               dst=handle_subprocess_call_path(dst_filename))

    subprocess.call(cmd, shell=True)


def replace_value(geoxform, reclassify_list, src_file, dst_file, data_type_name):
    """
    Given a list of intervals, replace the pixel values that fall in each
    interval with a new value, and write the result to a new file.  A pixel
    belongs to an interval when old_value_from < pixel <= old_value_to.
    Every interval is tested against the ORIGINAL pixel values, so a value
    written for one interval is never reclassified again by a later one.
    If intervals overlap, the later interval in the list wins.

    Arguments:
    geoxform -- Dst file geoxform
    reclassify_list -- A 2D array. Which looks like:
                    [[old_value_from, old_value_to, new_value],
                     [old_value_from, old_value_to, new_value],
                     [old_value_from, old_value_to, new_value]]
    src_file -- Input raster filename.
    dst_file -- Output filename.
    data_type_name -- GDAL data type name of original file, e.g. 'Float32'
    Returns:
    err - True indicates an error occurred (including a failed write),
          else False.
    """
    err = True
    try:
        input_datatype_gdal = gdal.GetDataTypeByName(data_type_name)
        np_data_type = gdal_array.GDALTypeCodeToNumericTypeCode(
            input_datatype_gdal)
        src_rst_array = extract_raster_array(src_file, np_data_type)
        # write into a copy so the interval tests always see source values
        dst_rst_array = src_rst_array.copy()
        for interval in reclassify_list:
            old_value_from = interval[0]
            old_value_to = interval[1]
            new_value = interval[2]

            dst_rst_array[
                (src_rst_array > old_value_from) &
                (src_rst_array <= old_value_to)
            ] = new_value
        row_count, col_count = dst_rst_array.shape
        # write_file reports failure through its return value, pass it on
        err = bool(write_file(dst_file, dst_rst_array, col_count, row_count,
                              geoxform, data_type=input_datatype_gdal))
    except Exception as error:
        # Exception (not BaseException) so KeyboardInterrupt/SystemExit are
        # not swallowed; log instead of failing silently.
        err = True
        QgsMessageLog.logMessage(
            "Exception - Error replacing raster values: " + str(error),
            level=Qgis.Critical)
    return err


def resample_input_files(dst_filename, mask_filename, src_files_list, reg_dic):
    """
    Resample the mask file and the data files to the region definition.
    The resampled copies are written to a "temp" folder under the output
    path, keeping their original base names; an existing copy with the same
    name is removed first.

    Arguments:
    dst_filename -- Output path.
    mask_filename -- Mask filename, may be empty/None for no mask.
    src_files_list -- Input file list.
    reg_dic -- Region parameters.
    Returns:
    Tuple of (resampled mask filename or None, list of resampled data
    filenames in input order).
    """
    # make sure the temp folder under the output folder exists
    temp_dir = os.path.join(dst_filename, "temp")
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)

    def _resample_into_temp(src_path):
        # resample one file into the temp folder, replacing any old copy
        out_path = os.path.join(temp_dir, os.path.basename(src_path))
        if os.path.exists(out_path):
            os.remove(out_path)
        resample_raster(reg_dic, src_path, out_path)
        return out_path

    # mask first, then the data files
    resampled_mask = None
    if mask_filename:
        resampled_mask = _resample_into_temp(mask_filename)
    resampled_files = [
        _resample_into_temp(src_path) for src_path in src_files_list]
    return resampled_mask, resampled_files


def resample_raster(reg_dic, src_filename, dst_filename):
    """
    Resample a raster file to the region definition with a gdalwarp
    subprocess call.

    Arguments:
    reg_dic -- Region parameters, the "MinimumLongitude", "MinimumLatitude",
               "MaximumLongitude", "MaximumLatitude" and "Height" entries
               are used.
    src_filename -- Input raster filename.
    dst_filename -- Output raster filename, also selects the output format.
    """
    out_format = get_gdal_driver_name(dst_filename)
    # target extent: xmin ymin xmax ymax
    extent_txt = " ".join(
        str(reg_dic[key]) for key in (
            "MinimumLongitude", "MinimumLatitude",
            "MaximumLongitude", "MaximumLatitude"))
    # target resolution is built from the region "Height" for both axes
    resolution_txt = (
        str(reg_dic["Height"]) + " " + str(-reg_dic["Height"]))
    cmd = ("gdalwarp -te " + extent_txt
           + "  -tr " + resolution_txt
           + "  -of " + out_format
           + " -tap  -t_srs " + config.DEFAULT_CRS
           + " " + handle_subprocess_call_path(src_filename)
           + " " + handle_subprocess_call_path(dst_filename))
    subprocess.call(cmd, shell=True)


def set_nodata(src_filename, dst_filename, nd_val):
    """
    Write a copy of a file with the nodata value assigned, using a
    gdal_translate subprocess call.

    Arguments:
    src_filename -- Input file name.
    dst_filename -- Output file name, also selects the output format.
    nd_val -- No data value as a string.
    """
    out_format = get_gdal_driver_name(dst_filename)
    cmd = "".join((
        "gdal_translate -a_nodata ", nd_val,
        " -of ", out_format, "  ",
        handle_subprocess_call_path(src_filename), " ",
        handle_subprocess_call_path(dst_filename)))
    subprocess.call(cmd, shell=True)


def translate_datatype(src_filename, dst_filename, dst_data_type):
    """
    Write a copy of a file converted to the specified data type, using a
    gdal_translate subprocess call.  The output has no nodata value
    assigned (-a_nodata none).

    Arguments:
    src_filename -- Input file name.
    dst_filename -- Output file name.
    dst_data_type -- Destination file data type name (string).
    """
    # handle_subprocess_call_path already adds the double quotes; wrapping
    # its result in a second pair cancels the quoting and breaks paths that
    # contain spaces.
    cmd = ("gdal_translate " +
           handle_subprocess_call_path(src_filename) + " " +
           handle_subprocess_call_path(dst_filename) +
           " -ot " + dst_data_type +
           " -a_nodata none")
    subprocess.call(cmd, shell=True)


def write_file(dst_filename, src_array, col_ct, row_ct,
               geoxform, data_type=gdal.GDT_Int16):
    """
    Function to write an array to a single band output file.  The output
    folder is created if needed and an existing output file is replaced.

    Arguments:
    dst_filename - The name of the output file, also selects the driver.
    src_array - The array to write out to the file.
    col_ct - Number of columns in the array.
    row_ct - Number of rows in the array.
    geoxform - Geotransform for the output file.
    data_type - gdal constant for data type.
    Returns:
    err -- True if error, else false.  Errors are logged, not raised.
    """
    err = True
    try:
        # A bare filename has no folder part and os.makedirs("") raises, so
        # only create the folder when there is one.
        dst_folder = os.path.dirname(dst_filename)
        if dst_folder:
            os.makedirs(dst_folder, exist_ok=True)
        drvr = get_gdal_driver_object(dst_filename)
        drvr.Register()
        try:
            if os.path.exists(dst_filename):
                os.remove(dst_filename)
        except OSError as error:
            QgsMessageLog.logMessage(
                "Exception - Destination file is open: " + str(dst_filename),
                level=Qgis.Critical)
            raise OSError from error

        dst_ds = drvr.Create(dst_filename, col_ct, row_ct, 1, data_type)
        dst_ds.SetGeoTransform(geoxform)
        dst_ds.GetRasterBand(BAND_NUM).WriteArray(src_array)
        # drop the reference so GDAL flushes and closes the file
        dst_ds = None

        # check for successful creation
        check_dst_file = gdal.Open(dst_filename, gdal.GA_ReadOnly)
        if check_dst_file is None:
            raise IOError(u"Error writing file:  " + dst_filename)
        # release the verification handle as well
        check_dst_file = None
        err = False

    except Exception as error:
        # Exception (not BaseException) so KeyboardInterrupt/SystemExit are
        # not swallowed.
        QgsMessageLog.logMessage(
            "Exception - Unspecified error when writing file: " + str(error),
            level=Qgis.Critical)
    return err
