'''
/***************************************************************************
Name       :  changes_in_averages_worker.py
Description:  Changes in averages worker for climate data
copyright  :  (C) 2020-2023 by FEWS
email      :  minxuansun@contractor.usgs.gov
Author     :  Austin Christianson
Modified   :  07/14/2020 cholen - Update exception
              11/02/2020 cholen - Updates for resample of inputs
              01/06/2021 cholen - Use new utilities, SCA cleanup
              06/23/2022 cholen - Fix path for region mask file
 ***************************************************************************/
'''
from PyQt5 import QtCore

from qgis.core import QgsMessageLog, Qgis
from fews_tools.models.workspace_setup_model import WorkspaceSetupModel
from fews_tools.utilities import geoclim_gdal_utilities as g_util
from fews_tools.utilities import geoclim_utilities as util


class ChangesInAveragesWorker(QtCore.QObject):
    '''
    Changes in Averages worker class.

    Writes a masked seasonal sum file for every unique year of the two year
    series, averages the seasonal sums of each series into its own output
    file, and writes the difference between the two averages (computed by
    g_util.calc_masked_array_diff with the series 2 average passed first).
    Progress, errors and completion are reported through the progress,
    error and finished signals.
    '''

    def __init__(self, ds_dic, reg_dic, jul_2_jun,
                 series_years_list_ts1, series_years_list_ts2,
                 selected_period_list, avg_ts1_filename, avg_ts2_filename,
                 output_path, output_filename):
        QtCore.QObject.__init__(self)
        self.ds_dic = ds_dic
        self.reg_dic = reg_dic
        self.wrksp_setup = WorkspaceSetupModel()
        self.series_years_list_ts1 = series_years_list_ts1
        self.series_years_list_ts2 = series_years_list_ts2
        self.selected_periods_list = selected_period_list
        self.avg_ts1_filename = avg_ts1_filename
        self.avg_ts2_filename = avg_ts2_filename
        self.output_path = output_path
        self.output_filename = output_filename
        # first/last period of the season, filled in by calc_seasonal_sums
        self.beg_period = ""
        self.end_period = ""
        # set by kill(); polled between processing steps
        self.killed = False
        # True selects the cross-year (July to June) seasonal file lookup
        self.jul_2_jun = jul_2_jun
        # progress increment per year and running progress total
        self.step = 0.0
        self.curr_progress = 0.0
        # resample the mask (done once, at construction time)
        self.mask_file, _ = g_util.resample_input_files(
            self.output_path, self.reg_dic['Mask'], [], self.reg_dic)
        # raster geometry/data type, taken from the first seasonal input set
        self.row_ct = 0
        self.col_ct = 0
        self.geoxfrm = None
        self.gd_data_type = ""
        self.np_data_type = ""

    def calc_seasonal_sums(self):
        '''
        Calculate seasonal sum files.
        One masked seasonal sum file is written for every unique year in
        the two year series. The raster geometry and data type of the
        first input set are stored on the instance for the later steps,
        and beg_period/end_period are updated for each year processed.
        Processing stops early if the worker is killed or a write fails.
        returns(boolean) err - True if error (also True when there are no
                               years to process), else False.
        '''
        err = True
        # get sorted no-duplicates list of years
        process_list = sorted(set(self.series_years_list_ts1 +
                                  self.series_years_list_ts2))
        need_geo_info = True
        for entry in process_list:
            if not self.jul_2_jun:
                seasonal_sum_file_list, self.beg_period, self.end_period =\
                    util.get_input_info_for_seasonals(
                        self.ds_dic, entry, self.selected_periods_list)
            else:
                # cross-year seasons are looked up by the first 4 characters
                # of the entry
                seasonal_sum_file_list, self.beg_period, self.end_period =\
                    util.get_input_info_for_cross_year_seasonals(
                        self.ds_dic, entry[0:4], self.selected_periods_list)
            _, seasonal_sum_file_list = g_util.resample_input_files(
                self.output_path, None, seasonal_sum_file_list, self.reg_dic)
            if need_geo_info:  # only need to do this once
                _, self.row_ct, self.col_ct, self.geoxfrm, data_type =\
                       g_util.get_geotiff_info(seasonal_sum_file_list[0])
                # sometimes we have inputs as byte type, but we also have nodata
                # as a negative number, the outputs will be int16 to
                # handle the negative number(USGSGDAS PET datasets are an example)
                if self.ds_dic['DATAMISSINGVALUE'] < 0 and data_type == "Byte":
                    data_type = "Int16"
                self.gd_data_type = g_util.TYPE_DIC[data_type]["GDAL"]
                self.np_data_type = g_util.TYPE_DIC[data_type]["NP"]
                need_geo_info = False
            dst_filename = util.get_seasonal_file_names(
                self.ds_dic, "", [entry],
                [self.beg_period, self.end_period], self.output_path)[0]

            # run the calculation
            data_cube = g_util.get_data_cube(seasonal_sum_file_list)
            full_mask_array = g_util.get_data_cube_and_region_mask(
                data_cube, self.mask_file, self.ds_dic['DATAMISSINGVALUE'])
            masked_sum_array = g_util.calc_masked_data_cube_sum(
                data_cube, full_mask_array, self.ds_dic['DATAMISSINGVALUE'])
            err = g_util.write_file(dst_filename, masked_sum_array,
                                    self.col_ct, self.row_ct, self.geoxfrm,
                                    self.gd_data_type)
            self.curr_progress += self.step
            if self.killed is True or err is True:
                break
        # drop references to the (potentially large) arrays
        data_cube = None
        masked_sum_array = None
        full_mask_array = None
        return err

    def kill(self):
        '''
        Set the kill flag.
        The flag is polled between processing steps, so the run stops at
        the next check rather than immediately.
        '''
        self.killed = True

    def _check_status(self, err, step_desc):
        '''
        Abort the run if the worker was killed or the last step failed.
        params(boolean) err - error flag returned by the last step.
        params(string) step_desc - what failed, used in the error message.
        raises KeyboardInterrupt if killed (checked first, as before).
        raises RuntimeError if err is True.
        '''
        if self.killed is True:
            raise KeyboardInterrupt
        if err is True:
            raise RuntimeError(
                u'Changes in Averages failed: {}'.format(step_desc))

    def _write_average(self, files_to_avg, dst_filename):
        '''
        Average a list of seasonal sum files and write the result.
        params(list) files_to_avg - seasonal sum files to average.
        params(string) dst_filename - output file name.
        returns(tuple) (err, dst_array) - err is the flag returned by
            g_util.write_file, dst_array is the averaged array.
        '''
        data_cube = g_util.get_data_cube(files_to_avg)
        full_mask_array = g_util.get_data_cube_and_region_mask(
            data_cube, self.mask_file, self.ds_dic['DATAMISSINGVALUE'])
        dst_array = g_util.calc_masked_data_cube_np_stat(
            data_cube, full_mask_array, "Average",
            self.ds_dic['DATAMISSINGVALUE'])
        err = g_util.write_file(
            dst_filename, dst_array,
            self.col_ct, self.row_ct, self.geoxfrm, self.gd_data_type)
        return err, dst_array

    def run(self):
        '''
        Run function for average worker object.
        Always ends by emitting finished: with a (0, message) tuple on
        success or user abort, or with None after an error (the error is
        reported through the error signal first).
        '''
        ret_tuple = None
        dst_array1, dst_array2, dst_array = None, None, None
        mask_array = None
        try:
            QgsMessageLog.logMessage(
                u'Begin Changes in Averages...', level=Qgis.Info)
            # each series must be averaged, so an empty one cannot work
            # (and both empty would otherwise divide by zero below)
            if not self.series_years_list_ts1 or\
                    not self.series_years_list_ts2:
                raise ValueError(
                    u'Both time series must contain at least one year')
            # one progress step per unique year: calc_seasonal_sums only
            # processes each year once, even when the two series overlap
            self.step = 90.0 / len(set(self.series_years_list_ts1 +
                                       self.series_years_list_ts2))
            self.curr_progress = 0
            # seasonal sums
            err = self.calc_seasonal_sums()
            self.progress.emit(int(self.curr_progress))
            self._check_status(err, u'could not create seasonal sum files')

            # get list of seasonal sums for each time series
            ts1_files_to_avg = util.get_seasonal_file_names(
                self.ds_dic, "", self.series_years_list_ts1,
                [self.beg_period, self.end_period], self.output_path)
            ts2_files_to_avg = util.get_seasonal_file_names(
                self.ds_dic, "", self.series_years_list_ts2,
                [self.beg_period, self.end_period], self.output_path)
            # average ts1
            err, dst_array1 = self._write_average(
                ts1_files_to_avg, self.avg_ts1_filename)
            self._check_status(
                err, u'could not write {}'.format(self.avg_ts1_filename))
            # average ts2
            err, dst_array2 = self._write_average(
                ts2_files_to_avg, self.avg_ts2_filename)
            self._check_status(
                err, u'could not write {}'.format(self.avg_ts2_filename))
            QgsMessageLog.logMessage(
                u'Time Series Average files complete', level=Qgis.Info)
            # create the diff
            mask_array = g_util.get_inverted_mask_array(self.mask_file)
            dst_array = g_util.calc_masked_array_diff(
                dst_array2, dst_array1, mask_array,
                self.ds_dic['DATAMISSINGVALUE'])
            err = g_util.write_file(
                self.output_filename, dst_array,
                self.col_ct, self.row_ct, self.geoxfrm, self.gd_data_type)
            self._check_status(
                err, u'could not write {}'.format(self.output_filename))
            self.progress.emit(90)
            QgsMessageLog.logMessage(
                u'Differences complete', level=Qgis.Info)
            if self.killed is False:
                self.progress.emit(100)
                ret_tuple = (0, "Changes in Averages complete")
        # exit with appropriate message on killed (KeyboardInterrupt)
        except KeyboardInterrupt:
            self.progress.emit(0)
            ret_tuple = (0, u"Changes in Averages aborted by user")
        # forward any exceptions upstream
        except BaseException as exc:
            self.error.emit(exc, u'Unspecified error in Changes in Averages')
        # release the large arrays before handing control to listeners
        dst_array1, dst_array2, dst_array = None, None, None
        mask_array = None
        self.finished.emit(ret_tuple)

    finished = QtCore.pyqtSignal(object)

    error = QtCore.pyqtSignal(Exception, str)

    progress = QtCore.pyqtSignal(int)
