"""
/***************************************************************************
Name       :  geowrsi_climatological_analysis_worker.py
Description:  Geowrsi climatological analysis worker for FEWSTools plugin
copyright  :  (C) 2022-2023 by FEWS
email      :  jhowton@contractor.usgs.gov
Created    :  10/17/2022 - Jacob Howton
Modified   :
 ***************************************************************************/
"""

import os
import numpy as np

from PyQt5 import QtCore
from qgis.core import QgsMessageLog, Qgis

from fews_tools.utilities import climatological_analysis_utilities as ca_util
from fews_tools.utilities import geoclim_gdal_utilities as g_util
from fews_tools.utilities import geowrsi_utilities as geowrsi_util
from fews_tools.utilities import geoclim_utilities as util
from fews_tools.models.workspace_setup_model import WorkspaceSetupModel


class GeoWRSIClimatologicalAnalysisWorker(QtCore.QObject):
    '''
    Worker class for the GeoWRSI climatological analysis tool.

    Builds a data cube from the input files, masks the flag values for the
    selected parameter, computes the requested statistic along axis 0 of
    the cube and writes the result to ``output_path``.  Progress, errors
    and completion are reported through the Qt signals declared below.
    '''

    CLIM_ANAL_METHODS = \
        ["Average", "Median", "Standard Deviation", "Range", "Count",
         "Coefficient of Variation", "Trend", "Percentiles"]

    # Cube values excluded from the statistics (NaN is always excluded too)
    _WRSI_MASK_VALUES = (253, 254, 255, -9999)
    _OTHER_MASK_VALUES = (60, -9999)

    # Emitted once at the end of run(); carries the result tuple, or None
    # when run() ended through the error path
    finished = QtCore.pyqtSignal(object)

    # Emitted with the exception and a short description when run() fails
    error = QtCore.pyqtSignal(Exception, str)

    # Emitted with the current progress value
    progress = QtCore.pyqtSignal(int)

    def __init__(
            self,
            analysis_method,
            parameter,
            smart_trend,
            min_r2,
            percentile,
            year_list,
            files_to_analyze,
            output_path,
            reg_dic):
        '''
        Store the analysis settings.

        :param analysis_method: One of the CLIM_ANAL_METHODS entries
        :param parameter: "WRSI" selects the WRSI mask values, any other
                          value selects the alternative mask values
        :param smart_trend: Passed to the trend function as "smart"
        :param min_r2: Passed to the trend function as "min_r2"
        :param percentile: Percentile handed to np.percentile
        :param year_list: Passed to the trend function as "yr_list"
        :param files_to_analyze: Input files; the first one also supplies
                                 the output geotiff information
        :param output_path: Path of the output file
        :param reg_dic: Region dictionary used to fill the data cube
        '''
        super().__init__()
        self.wrksp_setup = WorkspaceSetupModel()

        self.output_path = output_path
        self.analysis_method = analysis_method
        self.parameter = parameter
        self.smart_trend = smart_trend
        self.min_r2 = min_r2
        self.percentile = percentile

        self.year_list = year_list
        self.files_to_analyze = files_to_analyze
        self.reg_dic = reg_dic

        self.killed = False
        self.step = 0.0
        self.curr_progress = 0.0

        self.row_ct = 0
        self.col_ct = 0
        self.geoxfrm = None
        self.gd_data_type = None
        self.np_data_type = None

    def update_progress(self):
        '''
        Helper for progress bar updates.  Advances the progress by
        self.step, capped at 90, and emits the integer value.
        '''
        self.curr_progress = min(self.curr_progress + self.step, 90)
        self.progress.emit(int(self.curr_progress))

    def _check_killed(self):
        '''
        Raise KeyboardInterrupt when the kill flag has been set.
        '''
        if self.killed is True:
            raise KeyboardInterrupt

    def _mask_data_cube(self, data_cube):
        '''
        Return data_cube as a masked array with the flag values of the
        selected parameter and all NaN cells masked.
        '''
        if self.parameter == "WRSI":
            mask_values = self._WRSI_MASK_VALUES
        else:
            mask_values = self._OTHER_MASK_VALUES

        mask = np.isnan(data_cube)
        for value in mask_values:
            mask = mask | (data_cube == value)
        return np.ma.masked_array(data_cube, mask=mask)

    def _compute_statistic(self, data_cube, data_cube_masked):
        '''
        Compute the statistic selected by self.analysis_method.

        :return: (stat_array, stat_array_intcp, stat_array_r2); the last
                 two are None for every method except "Trend"
        :raises ValueError: If the analysis method is not recognised
        '''
        methods = self.CLIM_ANAL_METHODS
        method = self.analysis_method
        stat_array_intcp = None
        stat_array_r2 = None

        if method == methods[0]:  # Average
            stat_array = np.ma.average(data_cube_masked, axis=0)

        elif method == methods[1]:  # Median
            stat_array = np.ma.median(data_cube_masked, axis=0)

        elif method == methods[2]:  # Standard Deviation
            stat_array = np.ma.std(data_cube_masked, axis=0)

        elif method == methods[3]:  # Range
            stat_array = ca_util.range_analysis(
                data_cube_masked).astype(np.float32)
            stat_array[np.isnan(data_cube[0])] = np.nan

        elif method == methods[4]:  # Count
            stat_array = np.ma.count(
                data_cube_masked, axis=0).astype(np.float32)
            stat_array[np.isnan(data_cube[0])] = np.nan

        elif method == methods[5]:  # Coefficient of Variation
            std_dev_array = np.ma.std(data_cube_masked, axis=0)
            average_array = np.ma.average(data_cube_masked, axis=0)
            stat_array = ca_util.coefficient_of_variation_calculation(
                average_array, std_dev_array, np.nan)

        elif method == methods[6]:  # Trend
            trend_params_dic = {
                "yr_list": self.year_list,
                # The standalone doesn't have this option
                "slp_convert": False,
                "smart": self.smart_trend,
                "min_r2": self.min_r2}
            stat_array, stat_array_intcp, stat_array_r2 = \
                np.apply_along_axis(
                    ca_util.trend_function_masked, 0, data_cube_masked,
                    trend_params_dic, np.nan)
            stat_array[np.isnan(data_cube[0])] = -9999

        elif method == methods[7]:  # Percentiles
            # NOTE(review): computed on the unmasked cube, so the flag
            # values take part in the percentile - confirm this is intended
            stat_array = np.percentile(data_cube, self.percentile, axis=0)

        else:
            raise ValueError(
                "Unknown analysis method: " + str(method))

        return stat_array, stat_array_intcp, stat_array_r2

    @staticmethod
    def _fill_masked_cells(stat_array, data_cube):
        '''
        Give every masked cell of stat_array a value: 253 when any layer of
        data_cube holds 253 at that cell, otherwise the value of the first
        layer of data_cube.
        '''
        masked_cells = np.ma.getmaskarray(stat_array)
        if not masked_cells.any():
            return

        has_253 = (data_cube == 253).any(axis=0)
        # Both selections are computed before assigning because the
        # assignments below clear the mask of stat_array
        use_253 = masked_cells & has_253
        use_first_layer = masked_cells & ~has_253

        stat_array[use_first_layer] = data_cube[0][use_first_layer]
        stat_array[use_253] = 253

    @staticmethod
    def _replace_masked(array, value):
        '''
        Set every masked cell of array to value; None is ignored.
        '''
        if array is not None:
            array[np.ma.getmaskarray(array)] = value

    def _write_outputs(self, stat_array, stat_array_intcp, stat_array_r2):
        '''
        Write the result file(s).  "Trend" writes three files (slope,
        intercept and r2), every other method writes one.

        :return: True when at least one file could not be written
        '''
        if self.analysis_method != self.CLIM_ANAL_METHODS[6]:
            outputs = [(self.output_path, stat_array, self.gd_data_type)]
        else:
            # NOTE(review): all trend outputs are written as Int32 -
            # confirm this is intended for the intercept and r2 arrays
            int32_type = g_util.TYPE_DIC["Int32"]["GDAL"]
            output_file_name, output_file_extension = os.path.splitext(
                self.output_path)
            outputs = [
                (self.output_path, stat_array, int32_type),
                (output_file_name + "intcp" + output_file_extension,
                 stat_array_intcp, int32_type),
                (output_file_name + "r2" + output_file_extension,
                 stat_array_r2, int32_type)]

        # Every file is attempted; a failure of any one of them is kept
        write_failed = False
        for path, array, gdal_type in outputs:
            err = g_util.write_file(
                path, array, self.col_ct, self.row_ct, self.geoxfrm,
                gdal_type)
            if err:
                write_failed = True
        return write_failed

    def run(self):
        '''
        Run function for GeoWRSI climatological analysis worker object.

        Emits progress while working, error when an exception occurs and
        always finished at the end.  finished carries a (code, message)
        tuple, or None when the run ended with an error.
        '''
        ret_tuple = None
        self.progress.emit(10)
        try:
            # Inside the try block so a failure is reported through the
            # error signal and finished is still emitted
            util.remove_temp_files(self.wrksp_setup.get_output_path())
            self._check_killed()

            QgsMessageLog.logMessage("Begin GeoWRSI climatological analysis",
                                     level=Qgis.Info)

            data_cube = geowrsi_util.fill_data_cube_from_files(
                self.reg_dic, self.files_to_analyze,
                self.wrksp_setup.get_output_path())
            self._check_killed()

            data_cube_masked = self._mask_data_cube(data_cube)

            QgsMessageLog.logMessage("Calculating " + self.analysis_method,
                                     level=Qgis.Info)
            self._check_killed()

            stat_array, stat_array_intcp, stat_array_r2 = \
                self._compute_statistic(data_cube, data_cube_masked)
            self._check_killed()
            self.progress.emit(60)

            # Percentiles are computed on the unmasked cube and are
            # therefore not post-processed
            if self.analysis_method != self.CLIM_ANAL_METHODS[7]:
                # If there were only masked values, set a new one
                self._fill_masked_cells(stat_array, data_cube)
                self._replace_masked(stat_array_intcp, -9999)
                self._replace_masked(stat_array_r2, -9999)

            _, self.row_ct, self.col_ct, self.geoxfrm, data_type =\
                g_util.get_geotiff_info(self.files_to_analyze[0])
            self.gd_data_type = g_util.TYPE_DIC[data_type]["GDAL"]
            self._check_killed()
            self.progress.emit(80)

            write_failed = self._write_outputs(
                stat_array, stat_array_intcp, stat_array_r2)

            if not write_failed:
                QgsMessageLog.logMessage(
                    "Completed:  " + str(self.output_path),
                    level=Qgis.Info)
            else:
                QgsMessageLog.logMessage(
                    "Error writing file", level=Qgis.Critical)

            if self.killed is False:
                self.progress.emit(100)
                ret_tuple = (0, "Climatological Analysis complete")

        # exit with appropriate message on killed (KeyboardInterrupt)
        except KeyboardInterrupt:
            self.progress.emit(0)
            ret_tuple = (0, "Climatological Analysis aborted by user")
        # forward any exceptions upstream; the error signal is declared
        # with Exception, so only Exception instances can be emitted
        except Exception as exc:  # pylint: disable=broad-except
            self.error.emit(
                exc, "Unspecified error in Climatological Analysis")

        # NOTE(review): 100 is emitted on every exit path, including after
        # the 0 emitted on abort - confirm the progress bar expects this
        self.progress.emit(100)
        self.finished.emit(ret_tuple)

    def kill(self):
        '''
        Set the kill flag; run() stops at its next kill check.
        '''
        self.killed = True
