'''
/***************************************************************************
Name       :  climatological_analysis_utilities.py
Description:  Utility functions for the climatological analysis tool,
              separated from geoclim_utilities.py to improve maintainability
copyright  :  (C) 2021-2023 by FEWS
email      :  minxuansun@contractor.usgs.gov
Created    :  03/23/2021 - cholen
Modified   :  06/30/2021 - cholen - Removed LUT usage
              01/13/2022 - cholen - New gdal utils, refactor.
              09/02/2025 - dhackman - Fix for SPI: check whether all valid
                                    values are identical, so that gamma.fit
                                    does not raise the error we were seeing

 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''
import decimal

import numpy as np
from qgis.core import QgsMessageLog, Qgis

from scipy.stats import gamma
from scipy.stats import linregress
from scipy.stats import norm

from fews_tools.utilities import geoclim_utilities as util


def calc_pixel_spi(ppt_ssum_vals, missing_val,
                   min_pos_obs=12, norm_thresh=160.0):
    '''
    Calculate the SPI value array for a single pixel. The output arrays hold
    the mapped values for every historical year of the input timeseries
    (which must include the SPI year(s)). This code is modified from a
    version provided by Greg Husak.

    params(array) - ppt_ssum_vals - Historical ppt pixel timeseries.
    params(int) - missing_val - Missing data value.
    params(int) - min_pos_obs - Required number of historical, positive obs.
    params(float) - norm_thresh - Gamma alpha above which a normal
        distribution (mean/std of the positive values) is used instead of
        the fitted gamma distribution.
    return(float) - alpha - Gamma shape value (0 when no gamma fit is done).
    return(float) - beta - Gamma scale value (0 when no gamma fit is done).
    return(array) - pix_norm_prob_arr - Standard-normal quantiles of the
        cumulative probabilities, as a numpy array.
    return(array) - pix_spi_arr - SPI array (quantiles * 100) as numpy array.
    '''
    ts_arr = np.reshape(ppt_ssum_vals, len(ppt_ssum_vals))
    # Equals len(ts_arr) when no value is missing.
    good_period_ct = np.sum(ts_arr != missing_val)
    p_norain = np.sum(ts_arr == 0.00) / good_period_ct
    poslocs = np.where(ts_arr > 0.000)  # this filters 0 and no data(-9999)
    posvals = ts_arr[poslocs]

    if len(posvals) < min_pos_obs:
        # Not enough positive observations to characterise the distribution.
        return (0, 0, np.zeros(len(ppt_ssum_vals)),
                np.zeros(len(ppt_ssum_vals)))

    if len(set(posvals)) == 1:
        # All positive values are identical: gamma.fit would fail, so each
        # positive value is mapped to a cumulative probability of 0.5.
        alpha = 0
        beta = 0
        x_i = 0.5
    else:
        alpha, loc1, beta = gamma.fit(posvals, floc=0.0)
        if alpha <= norm_thresh:
            x_i = gamma.cdf(posvals, alpha, loc=loc1, scale=beta)
        else:
            x_i = norm.cdf(
                posvals, loc=np.mean(posvals), scale=np.std(posvals))

    pix_norm_prob_arr, pix_spi_arr = _spi_from_positive_cdf(
        ts_arr, poslocs, x_i, p_norain)
    return alpha, beta, pix_norm_prob_arr, pix_spi_arr


def _spi_from_positive_cdf(ts_arr, poslocs, x_i, p_norain):
    '''
    Combine the no-rain probability with the CDF values of the positive
    observations and map the result onto standard-normal quantiles.

    params(array) - ts_arr - Full pixel timeseries.
    params(tuple) - poslocs - np.where result locating the positive values.
    params(array|float) - x_i - CDF value(s) for the positive observations.
    params(float) - p_norain - Fraction of valid periods with zero rain.
    return(array) - Normal quantiles of the mixed probabilities.
    return(array) - Those quantiles multiplied by 100 (SPI array).
    '''
    px_i = np.zeros(len(ts_arr))
    px_i[poslocs] = x_i
    prob = p_norain + ((1.0 - p_norain) * px_i)
    if p_norain > 0.5:
        # Mostly-dry pixel: values below 7 are pinned to the median.
        prob[ts_arr < 7] = 0.5
    # Keep probabilities strictly below 1 so norm.ppf stays finite.
    prob[prob >= 1.0] = 0.9999999
    pix_norm_prob_arr = norm.ppf(prob)
    return pix_norm_prob_arr, pix_norm_prob_arr * 100


def coefficient_of_variation_calculation(avg_array, std_dev_array, nd_val):
    '''
    Calculate the coefficient of variation element-wise from the avg and
    std dev arrays (each element is computed by cv_math_help).
    Args:
        avg_array(array) - Average values array.
        std_dev_array(array) - Standard deviation values array.
        nd_val(int) - Missing data value
    Returns:
        cv_array(array(integer)) - Output coefficient of variation value array.
    '''
    if np.size(avg_array) == 0:
        # np.vectorize cannot infer an output type from size-0 inputs and
        # raises ValueError, so return an (empty) nodata array directly.
        return np.full(np.shape(avg_array), nd_val)
    cv_math_help_v = np.vectorize(cv_math_help)
    return cv_math_help_v(avg_array, std_dev_array, nd_val)


def count_analysis(src_data_cube, mask_array, nd_val):
    '''
    Count, per pixel, the layers of the data cube that are not nodata.
    Args:
        src_data_cube(3d np array) - Source data, layers stacked on axis 0.
        mask_array(2d np array) - Mask array; masked cells get nd_val.
        nd_val(int) - Missing data value
    Returns:
        count_array(array(integer)) - Per-pixel counts of valid values.
    '''
    per_pixel_counts = np.apply_along_axis(
        count_function, 0, src_data_cube, nd_val)
    masked_counts = np.ma.masked_array(
        per_pixel_counts, mask=mask_array, fill_value=nd_val)
    return masked_counts.filled()


def count_function(src_array, nd_val):
    '''
    Count the elements of a 1d array that differ from the nodata value.
    params(1d numpy array) - src_array
    params(float) - Nodata value
    returns(int) - count of valid elements, or the nodata value when
        there are none
    '''
    valid_total = int(np.count_nonzero(src_array != nd_val))
    # A pixel without any valid observation is reported as nodata.
    return nd_val if valid_total == 0 else valid_total


def count_range_function(src_array, f_min, f_max, nd_val):
    '''
    Count the elements of an array that lie within [f_min, f_max]
    (both bounds inclusive).
    Args:
        src_array(1d numpy array) - Src data.
        f_min(int) - Frequency minimum.
        f_max(int) - Frequency maximum.
        nd_val(int) - No data value.
    Returns:
        freq(int) - Number of in-range elements, or nd_val if the array
            contains the no data value.
    '''
    if nd_val in src_array:
        return nd_val
    within_bounds = (src_array >= f_min) & (src_array <= f_max)
    return int(np.count_nonzero(within_bounds))


def cv_math_help(avg, std_dev, nd_val):
    '''
    Compute one coefficient of variation value (std dev as an integer
    percentage of the average, truncated toward zero).
    Args:
        avg(float) - Average value.
        std_dev(float) - Standard deviation value.
        nd_val(int) - Missing data value.
    Returns:
        ret_val - integer - cv value, or nd_val when the average is zero or
            either input is the missing data value.
    '''
    if avg == 0 or avg == nd_val or std_dev == nd_val:
        return nd_val
    return int(std_dev * 100 / avg)


def float_2_decimal_2d(val):
    '''
    Converts a number to a "Decimal" rounded half-up to 2 decimal places,
    to avoid problems comparing float types.
    params(float) - val - Source value (float, numpy float, int, str or
        Decimal).
    returns(decimal.Decimal) - Input rounded to 2 decimal places.
    '''
    if isinstance(val, (float, np.floating)):
        # Decimal(float) keeps the exact binary expansion (2.675 becomes
        # 2.67499999...), which defeats ROUND_HALF_UP. str() gives the
        # shortest decimal repr of the float instead.
        val = str(val)
    return decimal.Decimal(val).quantize(
        decimal.Decimal('0.01'), rounding=decimal.ROUND_HALF_UP)


def frequency_analysis(src_data_cube, mask_array, freq_min, freq_max, nd_val):
    '''
    For every pixel, compute the percentage of layers whose value lies in
    [freq_min, freq_max] (computed by frequency_function along axis 0).
    Args:
        src_data_cube(3d numpy array) - Src data.
        mask_array(2d numpy array) - Mask array; masked cells get nd_val.
        freq_min(int) - Frequency minimum.
        freq_max(int) - Frequency maximum.
        nd_val(int) - No data value.
    Returns:
        frequency_array(2d numpy array)
    '''
    nodata_fill = float(nd_val)
    pct_values = np.apply_along_axis(
        frequency_function, 0, src_data_cube, freq_min, freq_max, nd_val)
    masked_pct = np.ma.masked_array(
        pct_values, mask=mask_array, fill_value=nodata_fill)
    return masked_pct.filled()


def count_range_analysis(src_data_cube, mask_array, freq_min, freq_max, nd_val):
    '''
    For every pixel, count the layers whose value lies in
    [freq_min, freq_max] (computed by count_range_function along axis 0).
    Args:
        src_data_cube(3d numpy array) - Src data.
        mask_array(2d numpy array) - Mask array; masked cells get nd_val.
        freq_min(int) - Frequency minimum.
        freq_max(int) - Frequency maximum.
        nd_val(int) - No data value.
    Returns:
        count_array(2d numpy array)
    '''
    nodata_fill = float(nd_val)
    range_counts = np.apply_along_axis(
        count_range_function, 0, src_data_cube, freq_min, freq_max, nd_val)
    masked_counts = np.ma.masked_array(
        range_counts, mask=mask_array, fill_value=nodata_fill)
    return masked_counts.filled()


def frequency_function(src_array, f_min, f_max, nd_val):
    '''
    Gets the percentage of elements lying between min and max (both
    inclusive) in an array.
    Args:
        src_array(1d numpy array) - Src data.
        f_min(int) - Frequency minimum.
        f_max(int) - Frequency maximum.
        nd_val(int) - No data value.
    Returns:
        freq(int) - Percentage (0-100, rounded) of elements between min and
            max, or nd_val when the array contains nd_val or is empty.
    '''
    if len(src_array) == 0 or nd_val in src_array:
        # An empty series has no defined percentage (and would divide by 0).
        return nd_val
    freq = src_array[np.where(f_min <= src_array)]
    freq = freq[np.where(freq <= f_max)]
    return round(len(freq) / len(src_array) * 100.0)


def get_idx_of_spi_years(pdf_years, spi_years):
    '''
    Return the data cube indices of the SPI years.

    The data cube is expected to hold the PDF years in sorted order,
    followed by the SPI years that are not PDF years, also sorted. The
    years need to go into the datacube in that order so we can tell which
    idx are the spi years.
    params(list) - pdf_years - PDF years (any order).
    params(list) - spi_years - SPI years (any order).
    returns(list) - Indices of the SPI years: PDF-year matches first (in
        sorted year order), then the appended non-PDF SPI years.
    '''
    pdf_years_l = sorted(pdf_years)
    spi_years_l = sorted(spi_years)
    spi_idx_list = []
    for year in pdf_years_l:
        if year in spi_years:
            spi_idx_list.append(pdf_years_l.index(year))
    # SPI years that are not in the pdf are added at the end of the cube,
    # so their indices continue after the LAST PDF year. They do not follow
    # the last matched SPI index, and no match is required at all.
    count = len(pdf_years_l) - 1
    for year in spi_years_l:
        if year not in pdf_years_l:
            count += 1
            spi_idx_list.append(count)
    return spi_idx_list


def spi_analysis(src_data_cube, mask_array,
                 pdf_year_list, spi_year_list, nd_val):
    '''
    Loop through the unmasked pixels and calculate their SPI values.
    param(3d np array) - src_data_cube - Seasonal sum data cube indexed as
        [year, row, col]; the year order must match get_idx_of_spi_years.
    param(2d np array) - mask_array - Mask array, only pixels equal to 0
        are processed.
    param(list) - pdf_year_list - PDF year list
    param(list) - spi_year_list - SPI year list
    param(int) - nd_val - Missing data value

    returns(np array) - alpha array (nd_val for unprocessed pixels)
    returns(np array) - beta array (nd_val for unprocessed pixels)
    returns(np array) - Probability output cube, one layer per SPI year
        (-1.0 for unprocessed pixels)
    returns(np array) - SPI output cube, one layer per SPI year
        (nd_val for unprocessed pixels)
    On any error the message is logged and all four returns are None.
    '''
    try:
        spi_idx = get_idx_of_spi_years(pdf_year_list, spi_year_list)
        # setup outputs
        n_rows = mask_array.shape[0]
        n_cols = mask_array.shape[1]
        a_array = np.full(mask_array.shape, float(nd_val))
        b_array = np.full(mask_array.shape, float(nd_val))

        output_cube_shape = (len(spi_year_list), n_rows, n_cols)
        pr_array = np.full(output_cube_shape, -1.0)
        spi_array = np.full(output_cube_shape, float(nd_val))
        for row in range(n_rows):
            for col in range(n_cols):
                if mask_array[row, col] == 0:
                    # get ts for pixel
                    ppt_ssum_vals = src_data_cube[:, row, col]
                    # this brings back arrays the length of the datacube
                    (a_array[row, col], b_array[row, col], pr_pix, spi_pix) = \
                        calc_pixel_spi(ppt_ssum_vals, nd_val)
                    # from the np arrays pr and spi, extract the idxs of
                    # the selected spi years
                    pr_array[:, row, col] = pr_pix[spi_idx]
                    spi_array[:, row, col] = spi_pix[spi_idx]
    except Exception as ex:
        # Best effort: report the failure and return no outputs. Catching
        # Exception (not BaseException) lets KeyboardInterrupt/SystemExit
        # propagate.
        QgsMessageLog.logMessage(str(ex), level=Qgis.Critical)
        return None, None, None, None
    return a_array, b_array, pr_array, spi_array


def trend_analysis(src_data_cube, mask_array, trend_params_dic, nd_val):
    '''
    Trend analysis. Adapted from Windows GeoCLIM version (Image Regression)
    written by T. Tamuka Magadzire, USGS/FEWSNET. Runs trend_function on
    every pixel series (axis 0) and fills masked pixels with nd_val.
    Args:
        src_data_cube(3d numpy array) - Input data.
        mask_array(2d numpy array) - Mask array.
        trend_params_dic(dic) - Trend parameters from gui, must contain
            "yr_list" plus the keys trend_function reads.
        nd_val(int) - Missing data value.
    Returns:
        slope_array, intcp_array, rsq_array
    '''
    fill = float(nd_val)
    # The x values are the years, identical for every pixel: build them once.
    years = [float(yr) for yr in sorted(trend_params_dic["yr_list"])]
    if len(years) != src_data_cube.shape[0]:
        # handle cross years: drop the last year when the layer count differs
        years = years[:-1]
    per_pixel_results = np.apply_along_axis(
        trend_function, 0, src_data_cube, years, trend_params_dic, fill)
    slope_array, intcp_array, rsq_array = (
        np.ma.masked_array(layer, mask=mask_array, fill_value=fill).filled()
        for layer in per_pixel_results)
    return slope_array, intcp_array, rsq_array


def trend_function(y_array, x_array, trend_params_dic, nd_val):
    '''
    Does the trend calculations for one pixel series.
    params(1d numpy array) - y_array
    params(1d numpy array) - x_array
    params(dic) - Trend parameters dictionary ("slp_convert", "smart",
        "min_r2")
    params(float) - Nodata value
    returns(tuple(floats)) - Slope*100, intercept*100 and r squared*100
        values. Slope is further divided by 10 when "slp_convert" is True.
        Every value is nd_val when y_array contains nd_val.
    '''
    if nd_val in y_array:
        # Any missing observation invalidates the whole series.
        return nd_val, nd_val, nd_val

    stats_dic = util.weighted_least_squares_simple_linear_list(
        x_array, y_array, w_list=None, missing_val=nd_val)
    slope, intcp = linregress(x_array, y_array)[0:2]
    r_squared = nd_val
    if slope == nd_val:
        return slope, intcp, r_squared

    slope = slope * 100
    if intcp != nd_val:
        intcp = intcp * 100
    if stats_dic['res_squared'] != nd_val:
        r_squared = stats_dic['res_squared'] * 100

    # For ppt or pet datasets convert slope*100 back to slope (divide by
    # 100) and multiply by 10 for mm per decade: net result divide by 10.
    if trend_params_dic["slp_convert"] is True:
        slope /= 10.0
    # "smart" mode drops slopes whose r squared is below the threshold.
    if trend_params_dic["smart"] is True and \
            round(r_squared) < trend_params_dic["min_r2"]:
        slope = nd_val
    return slope, intcp, r_squared


def trend_function_masked(y_array, trend_params_dic, nd_val):
    '''
    Does the trend calculations for a masked 1d array, ignoring masked
    entries.
    params(1d numpy masked array) - y_array
    params(dic) - Trend parameters dictionary ("slp_convert", "smart",
        "min_r2")
    params(float) - Nodata value
    returns(tuple(floats)) - Slope*100, intercept*100 and r squared*100
        values (slope further divided by 10 when "slp_convert" is True).
        All nd_val when every entry is masked, when the unmasked values are
        all equal, or when they contain nd_val.
    '''
    slope = nd_val
    intcp = nd_val
    r_squared = nd_val
    # The x coordinate of each valid value is its position in the ORIGINAL
    # series. Using range(len(compressed)) would close up masked gaps and
    # distort the slope.
    x_array = np.flatnonzero(~np.ma.getmaskarray(y_array)).tolist()
    y_array = y_array.compressed()

    if len(y_array) == 0:
        return slope, intcp, r_squared
    # Carried over from the VB version (unsure whether we want it): a
    # constant series is reported as having no trend.
    if np.count_nonzero(y_array == y_array[0]) == len(y_array):
        return slope, intcp, r_squared
    if nd_val in y_array:
        return slope, intcp, r_squared

    stats_dic = util.weighted_least_squares_simple_linear_list(
        x_array, y_array, w_list=None, missing_val=nd_val)
    slope, intcp = linregress(x_array, y_array)[0:2]
    if slope != nd_val:
        slope = slope * 100
        if intcp != nd_val:
            intcp = intcp * 100
        if stats_dic['res_squared'] != nd_val:
            r_squared = stats_dic['res_squared'] * 100

        # if the dataset type is ppt or pet, convert slope*100 back
        # to slope(divide by 100) and multiply by 10 for mm per
        # decade (net result is divide by 10)
        if trend_params_dic["slp_convert"] is True:
            slope /= 10.0
        if trend_params_dic["smart"] is True and \
                round(r_squared) < trend_params_dic["min_r2"]:
            slope = nd_val

    return slope, intcp, r_squared


def range_analysis(src_data_cube):
    '''
    Does the range analysis on a data cube: range_function is applied to
    every pixel series along axis 0.
    params(3d numpy array) - src_data_cube - presumably a masked array,
        since range_function calls .compressed() on each series
    returns(2d numpy array) - per-pixel range values
    '''
    return np.apply_along_axis(range_function, 0, src_data_cube)


def range_function(in_array):
    '''
    Does the range (max - min) calculation over the unmasked values of a
    1d masked array.
    params(1d numpy masked array) - in_array
    returns - max minus min of the unmasked values, or 0 when every value
        is masked
    '''
    valid_vals = in_array.compressed()
    if valid_vals.size == 0:
        return 0
    return np.max(valid_vals) - np.min(valid_vals)
