"""
/***************************************************************************
Name	   :  climatological_analysis_worker.py
Description:  climatological analysis worker for FEWSTools plugin,
              updated from QGIS2
copyright  :  (C) 2019-2023 by FEWS
email      :  minxuansun@contractor.usgs.gov
Created    :  02/24/2020 - Minxuan Sun
Modified   :  12/03/2020 - cholen - Handle OSError
              01/05/2021 - cholen - Formatting cleanup add seasonal average
                           capability for temperature datasets
              01/13/2021 - cholen - Handle regex problems
              01/19/2021 - cholen - Fix bug for cross year seasonals
              04/09/2021 - cholen - Cleanup, use ca_util, removed unused args
              01/13/2022 - cholen - New gdal utils, refactor.
              03/22/2022 - cholen - Add tiff support.
              05/27/2022 - cholen - Fix alpha, beta, prob of rf suffixes
              06/23/2022 - cholen - Fix path for region mask file
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""
import os
import re

from PyQt5 import QtCore

from qgis.core import QgsMessageLog, Qgis

from fews_tools.models.workspace_setup_model import WorkspaceSetupModel
from fews_tools.utilities import climatological_analysis_utilities as ca_util
from fews_tools import fews_tools_config as config
from fews_tools.utilities import geoclim_gdal_utilities as g_util
from fews_tools.utilities import geoclim_utilities as util


class ClimatologicalAnalysisWorker(QtCore.QObject):
    '''
    Worker class for the climatological analysis tool.

    The worker is a QObject and reports back through the Qt signals
    declared on the class (progress, error, finished).
    '''
    # NOTE(review): not referenced anywhere in this class - presumably a
    # sentinel for a missing percentile value; confirm before removing.
    MISSING_PCTL = -1

    # Supported analysis methods.  The methods of this class select an
    # analysis by its position in this list, so the order must not change:
    #   0 Average, 1 Median, 2 Standard Deviation, 3 Count,
    #   4 Coefficient of Variation, 5 Trend, 6 Percentiles,
    #   7 Frequency, 8 SPI
    CLIM_ANAL_METHODS = \
        ["Average", "Median", "Standard Deviation", "Count",
         "Coefficient of Variation", "Trend", "Percentiles",
         "Frequency", "SPI"]

    def __init__(
            self,
            current_dataset_info,
            current_region_info,
            output_path,
            analysis_method,
            season_total,
            season_type,
            jul_2_jun,
            dataset_averages,
            smart_trends,
            min_r2,
            percentile,
            min_frequency,
            max_frequency,
            selected_period_list,
            selected_year_list,
            selected_spi_year_list,
            good_file_list,
            output_file_list):
        '''
        Store the analysis settings and prepare the region mask.

        All arguments are kept on the instance unchanged.  Note that a few
        attribute names differ slightly from the argument names
        (dataset_averages -> dataset_average,
        selected_period_list -> selected_periods_list,
        good_file_list -> good_files_list).

        The region mask named by current_region_info['Mask'] is resampled
        into output_path as part of construction.
        '''
        QtCore.QObject.__init__(self)
        self.wrksp_setup = WorkspaceSetupModel()

        # dataset, region and destination
        self.ds_dic = current_dataset_info
        self.reg_dic = current_region_info
        self.output_path = output_path

        # analysis options
        self.analysis_method = analysis_method
        self.season_total = season_total
        self.season_type = season_type
        self.jul_2_jun = jul_2_jun
        self.dataset_average = dataset_averages
        self.smart_trends = smart_trends
        self.min_r2 = min_r2
        self.percentile = percentile
        self.min_frequency = min_frequency
        self.max_frequency = max_frequency

        # selections and file lists
        self.selected_periods_list = selected_period_list
        self.selected_year_list = selected_year_list
        self.selected_spi_year_list = selected_spi_year_list
        self.good_files_list = good_file_list
        self.output_file_list = output_file_list
        self.seasonal_total_file_list = None

        # run state
        self.killed = False
        self.step = 0.0
        self.curr_progress = 0.0

        # raster geometry and data types, filled in by run()
        self.row_ct = 0
        self.col_ct = 0
        self.geoxfrm = None
        self.gd_data_type = None
        self.np_data_type = None

        # resample the mask
        self.mask_file, _ = g_util.resample_input_files(
            self.output_path, self.reg_dic['Mask'], [], self.reg_dic)

    def update_progress(self):
        '''
        Advance the progress value by one step and emit it.

        The stored value never goes above 90; the final jump to 100 is
        emitted by run() on completion.
        '''
        self.curr_progress = min(self.curr_progress + self.step, 90)
        self.progress.emit(int(self.curr_progress))

    def get_output_files(self, interval):
        """
        Gets output file names for interval.

        The interval is converted to its key with
        util.get_key_from_period_value and every entry of
        self.output_file_list containing that key is considered.  When
        several entries qualify for the same slot, the last one wins.

        Arguments:
        interval -- Period value to look up.

        Returns:
        output_file -- Main output file; for Trend the entry that also
                       contains 'slp'.
        output_file_int -- Trend only, entry that also contains 'int',
                           otherwise "".
        output_file_r2 -- Trend only, entry that also contains 'r2',
                          otherwise "".

        Raises:
        RuntimeError -- No main output file was found for the interval.

        """
        output_file = ""
        output_file_int = ""
        output_file_r2 = ""
        interval_string = util.get_key_from_period_value(
            self.ds_dic, interval)
        # the method does not change while looping, so test it only once
        is_trend = self.analysis_method == self.CLIM_ANAL_METHODS[5]
        for entry in self.output_file_list:
            if interval_string not in entry:
                continue
            if not is_trend:
                output_file = entry
                continue
            # Trend has three output files
            # NOTE(review): the markers are matched against the whole
            # entry, so a folder or region name containing 'slp', 'int' or
            # 'r2' would match as well - confirm the naming rules out that.
            if 'slp' in entry:
                output_file = entry
            if 'int' in entry:
                output_file_int = entry
            if 'r2' in entry:
                output_file_r2 = entry
        if output_file == "":
            # this exception is forwarded through the error signal by run(),
            # so say what was being looked for
            raise RuntimeError(
                "No output file found for interval '{}' ({})".format(
                    interval_string, self.analysis_method))
        return output_file, output_file_int, output_file_r2

    def get_process_list(self, interval):
        """
        Gets the process list for non-seasonals.

        A file of self.good_files_list is selected when it contains the
        text of a selected year immediately followed by interval.

        Returns:
        process_list -- Sorted list, without duplicates, of the files to
                        use for processing.

        """
        matched = set()
        for year in self.selected_year_list:
            wanted = year + interval
            matched.update(
                path for path in self.good_files_list if wanted in path)
        return sorted(matched)

    def run(self):
        '''
        Run function for climatological analysis worker object.

        Steps, in order:
        1. optionally recompute the dataset averages (this has to happen
           before any resampling),
        2. resample the input files and read the raster geometry and data
           type from the first resampled file,
        3. either build the seasonal files (a single analysis pass) or
           build one input/output set per selected period,
        4. run the selected analysis method on each set and write the
           output file(s).

        Nothing is returned; results are reported through the signals:
        progress -- int percentage, 100 only on successful completion
        error    -- the caught exception and a message on any failure
        finished -- always emitted last with ret_tuple: (0, message) on
                    completion or user abort, None when an error occurred
        '''
        ret_tuple = None
        try:
            QgsMessageLog.logMessage(u"Begin climatological analysis...",
                                     level=Qgis.Info)

            # the steps share a budget of 90, the remaining jump to 100 is
            # emitted on completion
            if self.season_total:
                self.step = int(90.0 / (len(self.selected_year_list)))
            else:
                self.step = int(90.0 / (len(self.selected_periods_list) *
                                        len(self.selected_year_list)))
            # int() truncates to 0 when there are more than 90 units of work
            if not self.step:
                self.step = 1
            self.curr_progress = 0
            self.progress.emit(int(self.curr_progress))
            # calculate dataset averages if necessary
            # this must happen before any resampling of data occurs
            if self.dataset_average:
                msg = ("Climatological Analysis - Calculating dataset averages")
                QgsMessageLog.logMessage(msg, level=Qgis.Info)
                # average each period for selected years
                err = self.update_dataset_averages()
                # KeyboardInterrupt is used internally to signal a user abort
                if self.killed is True:
                    raise KeyboardInterrupt
                if err is True:
                    raise RuntimeError

            # resample data files
            msg = ("Climatological Analysis - Resampling inputs")
            QgsMessageLog.logMessage(msg, level=Qgis.Info)
            _, self.good_files_list = g_util.resample_input_files(
                self.output_path, None, self.good_files_list, self.reg_dic)

            _, self.row_ct, self.col_ct, self.geoxfrm, data_type =\
                g_util.get_geotiff_info(self.good_files_list[0])
            # sometimes we have inputs as byte type, but we also have nodata
            # as a negative number, the outputs will be int16 to
            # handle the negative number(USGSGDAS PET datasets are an example)
            if self.ds_dic['DATAMISSINGVALUE'] < 0 and data_type == "Byte":
                data_type = "Int16"
            self.gd_data_type = g_util.TYPE_DIC[data_type]["GDAL"]
            self.np_data_type = g_util.TYPE_DIC[data_type]["NP"]

            # Calculate seasonal sums if necessary
            if self.season_total:
                msg = ('Climatological Analysis - Calculating seasonals')
                QgsMessageLog.logMessage(msg, level=Qgis.Info)
                err = self.calculate_seasonals()
                if self.killed is True:
                    raise KeyboardInterrupt
                if err is True:
                    raise RuntimeError
                # build dictionary, a seasonal run has a single entry that
                # uses the complete output file list
                info_dic = {"00":
                            {"input_list": self.seasonal_total_file_list,
                             "output_list": self.output_file_list}}
            else:
                info_dic = {}
                # otherwise we really have a bigger dictionary, one entry
                # per selected period; output_list is then the 3-tuple
                # returned by get_output_files
                for interval in self.selected_periods_list:
                    info_dic[interval] = {
                        "input_list": self.get_process_list(interval),
                        "output_list": self.get_output_files(interval)}

            msg = ('Climatological Analysis - Calculating ' +
                   self.analysis_method)
            QgsMessageLog.logMessage(msg, level=Qgis.Info)
            slp_convert = False
            # if dataset type is ppt or pet
            if self.ds_dic["DATATYPE"].lower() in [
                    config.DATA_TYPES[0][1].lower(),
                    config.DATA_TYPES[4][1].lower()]:
                slp_convert = True
            # only used by the Trend method
            trend_params_dic = {
                "yr_list": self.selected_year_list,
                "slp_convert": slp_convert,
                "smart": self.smart_trends,
                "min_r2": self.min_r2}
            for val in info_dic.values():
                data_cube = g_util.get_data_cube(val["input_list"])
                full_mask_array = g_util.get_data_cube_and_region_mask(
                        data_cube, self.mask_file,
                        self.ds_dic['DATAMISSINGVALUE'])
                # result arrays, whichever is not None after the method
                # dispatch below gets written out
                dst_array = None
                dst_array_frequency_count = None
                dst_array_intcp = None
                dst_array_r2 = None
                if (self.analysis_method in
                      [self.CLIM_ANAL_METHODS[0],
                       self.CLIM_ANAL_METHODS[1],
                       self.CLIM_ANAL_METHODS[2],
                       self.CLIM_ANAL_METHODS[6]]):
                    # Average, Median, Std deviation or percentile
                    dst_array = g_util.calc_masked_data_cube_np_stat(
                        data_cube, full_mask_array, self.analysis_method,
                        self.ds_dic['DATAMISSINGVALUE'], self.percentile)
                elif self.analysis_method == self.CLIM_ANAL_METHODS[3]:
                    # Count
                    dst_array = ca_util.count_analysis(
                         data_cube, full_mask_array,
                         self.ds_dic['DATAMISSINGVALUE'])
                elif self.analysis_method == self.CLIM_ANAL_METHODS[4]:
                    # Coefficient of Variation, computed from the rounded
                    # standard deviation and rounded average
                    std_dev_array = g_util.calc_masked_data_cube_np_stat(
                        data_cube, full_mask_array, self.CLIM_ANAL_METHODS[2],
                        self.ds_dic['DATAMISSINGVALUE'])
                    average_array = g_util.calc_masked_data_cube_np_stat(
                        data_cube, full_mask_array, self.CLIM_ANAL_METHODS[0],
                        self.ds_dic['DATAMISSINGVALUE'])
                    dst_array = ca_util.coefficient_of_variation_calculation(
                        average_array.round(),
                        std_dev_array.round(),
                        self.ds_dic['DATAMISSINGVALUE'])
                elif self.analysis_method == self.CLIM_ANAL_METHODS[5]:
                    # Trend, produces three arrays
                    dst_array, dst_array_intcp, dst_array_r2 = \
                        ca_util.trend_analysis(
                            data_cube, full_mask_array,
                            trend_params_dic, self.ds_dic['DATAMISSINGVALUE'])
                elif self.analysis_method == self.CLIM_ANAL_METHODS[7]:
                    # Frequency, produces two arrays
                    dst_array = ca_util.frequency_analysis(
                        data_cube, full_mask_array,
                        self.min_frequency, self.max_frequency,
                        self.ds_dic['DATAMISSINGVALUE'])
                    dst_array_frequency_count = ca_util.count_range_analysis(
                        data_cube, full_mask_array,
                        self.min_frequency, self.max_frequency,
                        self.ds_dic['DATAMISSINGVALUE'])
                elif self.analysis_method == self.CLIM_ANAL_METHODS[8]:
                    # SPI writes all of its outputs inside this branch and
                    # leaves dst_array as None
                    # think it is fixed just need retest
                    # currently spi years must be within pdf years, per Greg
                    # so no extra seasonal files need to be created
                    dst_beta_file = os.path.join(
                        self.output_path, config.BETA_FILENAME + self.ds_dic["DATASUFFIX"])
                    dst_alpha_file = os.path.join(
                        self.output_path, config.ALPHA_FILENAME + self.ds_dic["DATASUFFIX"])
                    dst_array = None
                    # there is a lt_array and spi_array for each selected pdf
                    # year, need to extract the spi indexes from that for
                    # saving outputs.....
                    mask_array = g_util.get_inverted_mask_array(self.mask_file)
                    alpha_array, beta_array, lt_array, spi_array =\
                        ca_util.spi_analysis(data_cube, mask_array,
                                             self.selected_year_list,
                                             self.selected_spi_year_list,
                                             self.ds_dic['DATAMISSINGVALUE'])
                    # alpha and beta are always written as Float32
                    err = g_util.write_file(
                        dst_alpha_file, alpha_array,
                        self.col_ct, self.row_ct, self.geoxfrm,
                        g_util.TYPE_DIC["Float32"]["GDAL"])
                    if err is True:
                        raise RuntimeError
                    QgsMessageLog.logMessage(
                        u"Completed:  " + str(dst_alpha_file), level=Qgis.Info)
                    err = g_util.write_file(
                        dst_beta_file, beta_array,
                        self.col_ct, self.row_ct, self.geoxfrm,
                        g_util.TYPE_DIC["Float32"]["GDAL"])
                    if err is True:
                        raise RuntimeError
                    QgsMessageLog.logMessage(
                        u"Completed:  " + str(dst_beta_file), level=Qgis.Info)
                    # there should be a one to one correspondence between
                    # the output files and the year index of the output
                    # cubes, make sure we are sorted
                    for file_path in sorted(val["output_list"]):
                        # position of the file in the complete output list,
                        # used as the index into lt_array and spi_array
                        idx = self.output_file_list.index(file_path)
                        # we have built the basename in the tool,
                        # so we can use reg exp and basename but have to
                        # split after SPI in case region name causes issues
                        try:
                            if 'YYYY' in self.ds_dic['DATADATEFORMAT']:
                                yr_reg_exp = r'\d{4}'
                            else:
                                yr_reg_exp = r'\d{2}'
                            temp_str = os.path.basename(
                                file_path).split('SPI_')[1]
                            yr_str = re.findall(yr_reg_exp, temp_str)[0]
                        except IndexError:
                            # no 'SPI_' in the name or no year after it
                            QgsMessageLog.logMessage(
                                u"Problem with name:  " + file_path,
                                level=Qgis.Critical)
                            raise IndexError
                        output_string = (yr_str +
                                         config.PROB_OF_RAIN_FILENAME +
                                         self.ds_dic["DATASUFFIX"])
                        lt_output_file_name = os.path.join(
                            self.output_path, output_string)
                        # even though LtOutputArray is a float,
                        # write it out as an integer(same as Windows)
                        err = g_util.write_file(
                            lt_output_file_name, lt_array[idx].round(),
                            self.col_ct, self.row_ct, self.geoxfrm,
                            self.gd_data_type)
                        if err is True:
                            raise RuntimeError
                        QgsMessageLog.logMessage(
                            u"Completed:  " + str(lt_output_file_name),
                            level=Qgis.Info)
                        err = g_util.write_file(
                            file_path, spi_array[idx].round(),
                            self.col_ct, self.row_ct, self.geoxfrm,
                            self.gd_data_type)
                        if err is True:
                            raise RuntimeError
                        QgsMessageLog.logMessage(
                            u"Completed:  " + str(file_path), level=Qgis.Info)
                # main output of every method except SPI
                if dst_array is not None:
                    err = g_util.write_file(
                        val["output_list"][0], dst_array,
                        self.col_ct, self.row_ct, self.geoxfrm,
                        self.gd_data_type)
                    if err is True:
                        raise RuntimeError
                    QgsMessageLog.logMessage(
                        u"Completed:  " + str(val["output_list"][0]),
                        level=Qgis.Info)
                # save extra frequency outputs
                # NOTE(review): for non-seasonal runs output_list comes from
                # get_output_files(), which only fills its second element
                # for Trend - confirm that index 1 is a valid path for a
                # non-seasonal Frequency run.
                if dst_array_frequency_count is not None:
                    err = g_util.write_file(
                        val["output_list"][1], dst_array_frequency_count,
                        self.col_ct, self.row_ct, self.geoxfrm,
                        self.gd_data_type)
                    if err is True:
                        raise RuntimeError
                    QgsMessageLog.logMessage(
                        u"Completed:  " + str(val["output_list"][1]),
                        level=Qgis.Info)
                # save extra Trend outputs, both written as Int32
                if dst_array_intcp is not None:
                    err = g_util.write_file(
                        val["output_list"][1], dst_array_intcp,
                        self.col_ct, self.row_ct, self.geoxfrm,
                        g_util.TYPE_DIC["Int32"]["GDAL"])
                    if err is True:
                        raise RuntimeError
                    QgsMessageLog.logMessage(
                        u"Completed:  " + str(val["output_list"][1]),
                        level=Qgis.Info)
                if dst_array_r2 is not None:
                    err = g_util.write_file(
                        val["output_list"][2], dst_array_r2,
                        self.col_ct, self.row_ct, self.geoxfrm,
                        g_util.TYPE_DIC["Int32"]["GDAL"])
                    if err is True:
                        raise RuntimeError
                    QgsMessageLog.logMessage(
                        u"Completed:  " + str(val["output_list"][2]),
                        level=Qgis.Info)
                if self.killed is True:
                    raise KeyboardInterrupt
                # NOTE(review): err is only assigned by the optional steps
                # and the writes above; if analysis_method matches none of
                # the branches and neither optional step ran, this line
                # fails on an unassigned variable (reported through the
                # error signal by the handler below).
                if err is True:
                    raise RuntimeError
                self.update_progress()
            if self.killed is False:
                self.progress.emit(100)
                ret_tuple = (0, "Climatological Analysis complete")
        # exit with appropriate message on killed (KeyboardInterrupt)
        except KeyboardInterrupt:
            self.progress.emit(0)
            ret_tuple = (0, u"Climatological Analysis aborted by user")
        # forward any exceptions upstream; ret_tuple stays None in this case
        except BaseException as exc:
            self.error.emit(exc, u"Unspecified error in Climatological Analysis")
        # drop the references to the large arrays before signalling the end
        data_cube, dst_array = None, None
        dst_array_intcp, dst_array_r2 = None, None
        full_mask_array = None
        self.finished.emit(ret_tuple)

    def calculate_seasonals(self):
        '''
        Calculate one seasonal file per selected year.

        The analysis years and the SPI years are combined into a sorted
        list without duplicates.  For each year the input files of the
        season are resampled, then summed (season_type "sum") or averaged
        (any other season_type), and the result is written to the output
        path.  The names of the seasonal files are collected in
        self.seasonal_total_file_list, which run() uses as analysis input.

        The loop stops early when the worker was killed or a write failed.

        returns(boolean) err - True if error, else False (also False when
        there is no year to process).
        '''
        # start from "no error" so that an empty year list returns False
        # instead of failing on an unassigned variable
        err = False
        self.seasonal_total_file_list = []
        reg_str = self.reg_dic['RegionName'].replace(' ', '')
        # get sorted no-duplicates list of years
        process_list = sorted(set(self.selected_year_list +
                                  self.selected_spi_year_list))
        for entry in process_list:
            if not self.jul_2_jun:
                seasonal_sum_inputs, beg_per, end_per =\
                    util.get_input_info_for_seasonals(
                        self.ds_dic, entry, self.selected_periods_list)
            else:
                # cross year seasons only use the first four characters of
                # the entry (presumably the starting year - confirm)
                seasonal_sum_inputs, beg_per, end_per =\
                    util.get_input_info_for_cross_year_seasonals(
                        self.ds_dic, entry[0:4], self.selected_periods_list)
            _, seasonal_sum_inputs = g_util.resample_input_files(
                self.output_path, None, seasonal_sum_inputs, self.reg_dic)
            # build the sum filenames
            dst_filename = util.get_seasonal_file_names(
                self.ds_dic, reg_str, [entry],
                [beg_per, end_per], self.output_path)[0]
            self.seasonal_total_file_list.append(dst_filename)
            # run the calculation
            data_cube = g_util.get_data_cube(seasonal_sum_inputs)
            full_mask_array = g_util.get_data_cube_and_region_mask(
                data_cube, self.mask_file, self.ds_dic['DATAMISSINGVALUE'])
            if self.season_type == "sum":
                dst_array = g_util.calc_masked_data_cube_sum(
                    data_cube, full_mask_array,
                    self.ds_dic['DATAMISSINGVALUE'])
            else:  # avg
                dst_array = g_util.calc_masked_data_cube_np_stat(
                    data_cube, full_mask_array, "Average",
                    self.ds_dic['DATAMISSINGVALUE'])
            err = g_util.write_file(dst_filename, dst_array,
                                    self.col_ct, self.row_ct, self.geoxfrm,
                                    self.gd_data_type)
            if self.killed is True or err is True:
                break
            # update_progress() adds the step itself; adding it here as well
            # would advance the bar twice per year
            self.update_progress()
        return err

    def update_dataset_averages(self):
        '''
        Recompute the average file of every period of the dataset.

        The first and last selected year are recorded in avgInfo.txt in
        the data folder.  Then, for each period, the existing average file
        is removed, the files of all selected years are averaged and the
        result is written to the average data folder.  The progress step
        is reset so that this stage spreads over the periods, and the
        progress is raised to at least 50 when the stage ends.

        The loop stops early when the worker was killed or a write failed.

        Returns:
        err -- True if err, else false.
        '''
        err = False
        ds = self.ds_dic
        all_periods = util.get_all_period_string(ds)
        info_path = os.path.join(ds['DATAFOLDER'], 'avgInfo.txt')
        with open(info_path, 'w') as info_file:
            info_file.write(
                "Dataset averages computed on years: {} through {}".format(
                    str(self.selected_year_list[0]),
                    str(self.selected_year_list[-1])))
        # this stage takes the majority of the time, so the progress bar
        # step is re-based on the number of periods
        self.step = 90.0 / len(all_periods)

        # write arguments (columns, rows, geotransform, GDAL type) taken
        # from the first input file; they only need to be read once
        raster_args = None
        for period in all_periods:
            avg_path = os.path.join(
                ds['AVGDATAFOLDER'],
                ds['AVGDATAPREFIX'] + period + ds["AVGDATASUFFIX"])
            # sometimes the removal of an old file needs a little time to
            # finish, delete existing file before doing calculation
            try:
                if os.path.exists(avg_path):
                    os.remove(avg_path)
                    QgsMessageLog.logMessage(
                        "Removed old:  " + str(avg_path), level=Qgis.Info)
            except OSError as error:
                # a failed removal is logged only, the calculation goes on
                QgsMessageLog.logMessage(
                    "Exception:  " + str(error.strerror),
                    level=Qgis.Critical)

            year_files = [
                os.path.join(
                    ds['DATAFOLDER'],
                    ds['DATAPREFIX'] + str(year) + period + ds["DATASUFFIX"])
                for year in self.selected_year_list]
            if raster_args is None:
                _, n_rows, n_cols, xfrm, type_name =\
                    g_util.get_geotiff_info(year_files[0])
                raster_args = (n_cols, n_rows, xfrm,
                               g_util.TYPE_DIC[type_name]["GDAL"])

            data_cube = g_util.get_data_cube(set(year_files))
            cube_mask = g_util.get_data_cube_mask(
                data_cube, ds['DATAMISSINGVALUE'])
            avg_array = g_util.calc_masked_data_cube_np_stat(
                data_cube, cube_mask, "Average", ds['DATAMISSINGVALUE'])
            err = g_util.write_file(avg_path, avg_array, *raster_args)
            if self.killed is True or err is True:
                break
            QgsMessageLog.logMessage(u"Completed:  " + str(avg_path),
                                     level=Qgis.Info)
            self.update_progress()

        if self.curr_progress < 50.0:
            self.curr_progress = 50.0
            self.progress.emit(int(self.curr_progress))
        return err

    def kill(self):
        '''
        Set the kill flag.

        The flag is only polled: the processing methods of this class
        check self.killed between steps and stop at the next check, so a
        step that is already running is finished first.
        '''
        self.killed = True

    # Emitted once at the end of run() with the result tuple
    # (code, message), or with None when the run ended with an error.
    finished = QtCore.pyqtSignal(object)

    # Emitted by run() with the caught exception and a message string.
    error = QtCore.pyqtSignal(Exception, str)

    # Emitted with the current progress value converted to int.
    progress = QtCore.pyqtSignal(int)
