'''
/***************************************************************************
Name	   :  composites_worker.py
Description:  Composites Worker for FEWSTools plugin, updated from QGIS2
copyright  :  (C) 2019-2023 by FEWS
email      :  minxuansun@contractor.usgs.gov
Created    :  12/31/2019 CHOLEN
Modified   :  07/16/2020 CHOLEN - Emit error on exception
              11/03/2020 MSUN - Resample input raster based on region extents
                                and pixel sizes before processing
              11/06/2020 CHOLEN - Change pct of avg calc to use numpy
              12/03/2020 CHOLEN - Handle OSError
              12/14/2020 CHOLEN - Update formulas per Diego and Chris Funk
              04/12/2021 CHOLEN - Remove unused arg from calc_np_stat
              01/06/2022 CHOLEN - Use new gdal utilities, SCA cleanup
              02/23/2022 CHOLEN - Add tiff support
              06/23/2022 CHOLEN - Fix path for region mask file
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''
import datetime
import os

import numpy as np

from PyQt5 import QtCore

from qgis.core import QgsMessageLog, Qgis

from fews_tools.models.workspace_setup_model import WorkspaceSetupModel

from fews_tools.utilities import geoclim_gdal_utilities as g_util
from fews_tools.utilities import geoclim_utilities as util


# Band number constant; not referenced by the code visible in this module
BAND_NUM = 1
# Small offset added to the numerator and denominator of the percent of
# average and standardized anomaly formulas to protect against divide by zero
FACTOR_1 = 0.1  # used to protect against divide by zero
# Weight applied to the baseline average term in percent of average; that
# term is all zeros whenever composite 2 years are selected
FACTOR_05 = 0.05  # used for comp1 only percent of avg
# Multiplier that converts a ratio to a percentage
FACTOR_P = 100  # to percentage


class CompositesWorker(QtCore.QObject):
    '''
    Composites Worker class
    '''
    def __init__(self, ds_dic, reg_dic, output_path,
                 analysis, jul_2_jun,
                 selected_periods_list, comp1_years_list,
                 comp2_years_list, baseline_years_list, available_years_list,
                 dst_filename,
                 comp1_dst_file, comp2_dst_file, comp_diff_dst_file,
                 baseline_avg, all_years_std, comp_analysis_list):
        '''
        Store the run configuration and resample the region mask.
        params(dict) - ds_dic - Dataset information; the DATASUFFIX and
                                DATAMISSINGVALUE entries are read here.
        params(dict) - reg_dic - Region information; its Mask entry is
                                 resampled during construction.
        params(string) - output_path - Output folder.
        params(string) - analysis - Selected analysis name.
        params(boolean) - jul_2_jun - Truthy for cross-year seasonals.
        params(list) - selected_periods_list - Selected periods.
        params(list) - comp1_years_list - Composite 1 years.
        params(list) - comp2_years_list - Composite 2 years (may be empty).
        params(list) - baseline_years_list - Baseline years.
        params(list) - available_years_list - All available years.
        params(string) - dst_filename - Anomaly / standardized anomaly output.
        params(string) - comp1_dst_file - Composite 1 average output.
        params(string) - comp2_dst_file - Composite 2 average output.
        params(string) - comp_diff_dst_file - Percent of average output.
        params(string) - baseline_avg - Baseline average output.
        params(string) - all_years_std - All years standard deviation output.
        params(list) - comp_analysis_list - Analysis names; entries 1, 2 and 3
                       select percent of average, anomaly and standardized
                       anomaly respectively.
        '''
        super().__init__()

        # dataset, region and workspace information
        self.ds_dic = ds_dic
        self.reg_dic = reg_dic
        self.wrksp_setup = WorkspaceSetupModel()
        self.output_path = output_path

        # analysis selections from the gui
        self.analysis = analysis
        self.jul_2_jun = jul_2_jun
        self.selected_periods_list = selected_periods_list
        self.comp_analysis_list = comp_analysis_list

        # year selections
        self.comp1_years_list = comp1_years_list
        self.comp2_years_list = comp2_years_list
        self.baseline_years_list = baseline_years_list
        self.available_years_list = available_years_list

        # output filenames
        self.dst_filename = dst_filename
        self.comp1_dst_file = comp1_dst_file
        self.comp2_dst_file = comp2_dst_file
        self.comp_diff_dst_file = comp_diff_dst_file
        self.baseline_avg = baseline_avg
        self.all_years_std = all_years_std

        # run state, filled in while the worker executes
        self.killed = False
        self.beg_per = self.end_per = ''
        self.step = self.curr_progress = 0.0

        # raster geometry and data types, set by calc_seasonal_sums
        self.row_ct = self.col_ct = 0
        self.geoxfrm = None
        self.gd_data_type = self.np_data_type = ""

        # resample the mask
        resampled = g_util.resample_input_files(
            self.output_path, self.reg_dic['Mask'], [], self.reg_dic)
        self.mask_file = resampled[0]

    def calc_anomaly(self):
        '''
        Calculate the anomaly: the masked difference between a composite
        average and the baseline average.
        Without composite 2 years the composite 1 result is written to the
        destination file.  With composite 2 years each composite result is
        written to a timestamped temporary file in the output folder and the
        masked difference of the two results is written to the destination
        file.
        returns(boolean) err - True if error, else False.
        '''
        def save(path, values):
            # every output shares the geometry captured by calc_seasonal_sums
            return g_util.write_file(path, values,
                                     self.col_ct, self.row_ct, self.geoxfrm,
                                     self.gd_data_type)

        QgsMessageLog.logMessage(u'Calculating Anomaly', level=Qgis.Info)
        mask_inputs = [self.comp1_dst_file, self.baseline_avg]
        stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
        if self.comp2_years_list:
            comp1_out = os.path.join(
                self.output_path,
                'tempA_' + stamp + self.ds_dic["DATASUFFIX"])
            comp2_out = os.path.join(
                self.output_path,
                'tempB_' + stamp + self.ds_dic["DATASUFFIX"])
            mask_inputs.append(self.comp2_dst_file)
        else:
            comp1_out = self.dst_filename

        comp1_vals = g_util.extract_raster_array(self.comp1_dst_file)
        baseline_vals = g_util.extract_raster_array(self.baseline_avg)

        # the cube is only needed long enough to derive the combined mask
        cube = g_util.get_data_cube(mask_inputs, self.np_data_type)
        nd_val = self.ds_dic['DATAMISSINGVALUE']
        combined_mask = g_util.get_data_cube_and_region_mask(
            cube, self.mask_file, nd_val)
        del cube

        comp1_diff = g_util.calc_masked_array_diff(
            comp1_vals, baseline_vals, combined_mask, nd_val=nd_val)
        err = save(comp1_out, comp1_diff)

        if err is False and self.comp2_years_list:
            comp2_vals = g_util.extract_raster_array(self.comp2_dst_file)
            comp2_diff = g_util.calc_masked_array_diff(
                comp2_vals, baseline_vals, combined_mask, nd_val=nd_val)
            err = save(comp2_out, comp2_diff)
            if err is False:
                # final product is the difference of the two anomalies
                final_diff = g_util.calc_masked_array_diff(
                    comp1_diff, comp2_diff, combined_mask, nd_val=nd_val)
                err = save(self.dst_filename, final_diff)

        self.curr_progress += self.step
        self.progress.emit(int(self.curr_progress))
        QgsMessageLog.logMessage(u'Anomaly complete', level=Qgis.Info)
        return err

    def calc_average(self, years_list, dst_filename):
        '''
        Average the seasonal sum files of the given years and write the
        masked result.
        params(list) - years_list - Years list.
        params(string) - dst_filename - Output filename.
        returns(boolean) err - True if error, else False.
        '''
        fix_path = self.wrksp_setup.fix_os_sep_in_path
        out_name = fix_path(dst_filename)
        candidates = util.get_seasonal_file_names(
            self.ds_dic, "", years_list,
            [self.beg_per, self.end_per], self.output_path)

        # keep each candidate once if its basename holds any requested year
        # (safe because the basename is like pptsumYYYY...)
        inputs = []
        for candidate in candidates:
            base = os.path.basename(candidate)
            if any(year in base for year in years_list):
                inputs.append(fix_path(candidate))

        nd_val = self.ds_dic['DATAMISSINGVALUE']
        cube = g_util.get_data_cube(inputs)
        combined_mask = g_util.get_data_cube_and_region_mask(
            cube, self.mask_file, nd_val)
        averaged = g_util.calc_masked_data_cube_np_stat(
            cube, combined_mask, "Average", nd_val)

        err = g_util.write_file(out_name, averaged,
                                self.col_ct, self.row_ct, self.geoxfrm,
                                self.gd_data_type)
        self.curr_progress += self.step
        self.progress.emit(int(self.curr_progress))
        QgsMessageLog.logMessage(u'Average complete', level=Qgis.Info)
        return err

    def calc_all_available_years_std_dev_for_period(self):
        '''
        Calculate standard deviation of all available years.
        The seasonal sum files of the available years are stacked, masked
        with the region mask and missing values, and the per-pixel standard
        deviation is written to the all years standard deviation file.
        returns(boolean) err - True if error, else False.
        '''
        temp_list = util.get_seasonal_file_names(
            self.ds_dic, "", self.available_years_list,
            [self.beg_per, self.end_per], self.output_path)

        # Match the year against the basename only and add each file once.
        # Matching against the full path would also hit a year string in the
        # output folder name, and appending per matching year would put the
        # same file into the stack several times, skewing the statistic.
        file_list = []
        for t_entry in temp_list:
            b_name = os.path.basename(t_entry)
            if any(entry in b_name for entry in self.available_years_list):
                file_list.append(t_entry)

        # calculate the standard deviation
        nd_val = self.ds_dic['DATAMISSINGVALUE']
        data_cube = g_util.get_data_cube(file_list)
        full_mask_array = g_util.get_data_cube_and_region_mask(
            data_cube, self.mask_file, nd_val)

        std_dev_array = g_util.calc_masked_data_cube_np_stat(
            data_cube, full_mask_array, 'Standard Deviation', nd_val)

        err = g_util.write_file(
            self.all_years_std, std_dev_array,
            self.col_ct, self.row_ct, self.geoxfrm,
            self.gd_data_type)
        self.curr_progress += self.step
        self.progress.emit(int(self.curr_progress))
        return err

    def calc_pct_of_average(self):
        '''
        Calculate percent of average.
        FACTOR_1 is added to both numerator and denominator so a zero
        denominator cannot occur from a zero baseline; see documentation
        for formula details.
        With composite 2 years:
            (comp1 - comp2 + 0.1) / (blAvg + 0.1) * 100
        Without composite 2 years:
            (comp1 + 0.05 * blAvg + 0.1) / (blAvg + 0.05 * blAvg + 0.1) * 100
        returns(boolean) err - True if error, else False.
        '''
        QgsMessageLog.logMessage(u'Calculating Percent of Average',
                                 level=Qgis.Info)
        comp1_vals = g_util.extract_raster_array(self.comp1_dst_file)
        baseline_vals = g_util.extract_raster_array(self.baseline_avg)
        mask_inputs = [self.comp1_dst_file, self.baseline_avg]
        grid_shape = (self.row_ct, self.col_ct)
        if self.comp2_years_list:
            comp2_vals = g_util.extract_raster_array(self.comp2_dst_file)
            weighted_src = np.full(grid_shape, 0)
            mask_inputs.append(self.comp2_dst_file)
        else:
            comp2_vals = np.full(grid_shape, 0)
            # the weighted baseline term only applies when there is no comp2
            weighted_src = baseline_vals

        # the cube is only built so the combined mask can be derived
        cube = g_util.get_data_cube(mask_inputs, self.np_data_type)
        combined_mask = g_util.get_data_cube_and_region_mask(
            cube, self.mask_file, self.ds_dic['DATAMISSINGVALUE'])
        del cube

        weighted_term = FACTOR_05 * weighted_src
        numerator = comp1_vals + weighted_term - comp2_vals + FACTOR_1
        denominator = baseline_vals + weighted_term + FACTOR_1
        pct_vals = (numerator / denominator) * FACTOR_P
        pct_vals = np.ma.masked_array(
            pct_vals, mask=combined_mask,
            fill_value=self.ds_dic['DATAMISSINGVALUE']).filled()

        err = g_util.write_file(self.comp_diff_dst_file, pct_vals,
                                self.col_ct, self.row_ct, self.geoxfrm,
                                self.gd_data_type)
        self.curr_progress += self.step
        self.progress.emit(int(self.curr_progress))
        QgsMessageLog.logMessage(u'Percent of Average complete', level=Qgis.Info)
        return err

    def calc_seasonal_sums(self):
        '''
        Calculate seasonal sums. Resample of inputs is done here.
        One seasonal sum file is written per required year.  The raster
        geometry and data types of the first resampled input are stored on
        the worker (row_ct, col_ct, geoxfrm, gd_data_type, np_data_type) for
        use by the later calculations.  Processing stops early when the kill
        flag is set or a write fails.
        returns(boolean) err - True if error (including when there are no
                               years to process), else False.
        '''
        # if standardized anomaly is selected,
        # we need to get all available years seasonal sums
        if self.analysis == self.comp_analysis_list[3]:
            process_list = sorted(set(self.available_years_list))
        else:
            # otherwise use gui year selections to get
            # sorted no-duplicates list of years
            process_list = sorted(set(self.baseline_years_list +
                                      self.comp1_years_list +
                                      self.comp2_years_list))
        if not process_list:
            # nothing to sum; report an error rather than falling through to
            # the return with the status variable never assigned
            QgsMessageLog.logMessage(u'No years available for seasonal sums',
                                     level=Qgis.Warning)
            return True

        err = False
        raster_info_loaded = False
        for entry in process_list:
            if not self.jul_2_jun:
                seasonal_sum_file_list, self.beg_per, self.end_per =\
                    util.get_input_info_for_seasonals(
                        self.ds_dic, entry, self.selected_periods_list)
            else:
                # only the first four characters of the entry are passed on
                # for cross-year seasonals
                seasonal_sum_file_list, self.beg_per, self.end_per =\
                    util.get_input_info_for_cross_year_seasonals(
                        self.ds_dic, entry[0:4], self.selected_periods_list)
            # resample the seasonal sum files
            _, seasonal_sum_file_list = \
                g_util.resample_input_files(self.output_path, None,
                                            seasonal_sum_file_list,
                                            self.reg_dic)
            if not raster_info_loaded:  # only need to do this once
                _, self.row_ct, self.col_ct, self.geoxfrm, data_type =\
                    g_util.get_geotiff_info(seasonal_sum_file_list[0])
                # sometimes we have inputs as byte type, but we also have
                # nodata as a negative number, the outputs will be int16 to
                # handle the negative number
                # (USGSGDAS PET datasets are an example)
                if self.ds_dic['DATAMISSINGVALUE'] < 0 and data_type == "Byte":
                    data_type = "Int16"
                self.gd_data_type = g_util.TYPE_DIC[data_type]["GDAL"]
                self.np_data_type = g_util.TYPE_DIC[data_type]["NP"]
                raster_info_loaded = True
            # build the sum filenames
            dst_filename = util.get_seasonal_file_names(
                self.ds_dic, "", [entry],
                [self.beg_per, self.end_per], self.output_path)[0]
            # run the calculation
            data_cube = g_util.get_data_cube(seasonal_sum_file_list)
            full_mask_array = g_util.get_data_cube_and_region_mask(
                data_cube, self.mask_file, self.ds_dic['DATAMISSINGVALUE'])
            masked_sum_array = g_util.calc_masked_data_cube_sum(
                data_cube, full_mask_array, self.ds_dic['DATAMISSINGVALUE'])
            err = g_util.write_file(dst_filename, masked_sum_array,
                                    self.col_ct, self.row_ct, self.geoxfrm,
                                    self.gd_data_type)
            self.curr_progress += self.step
            self.progress.emit(int(self.curr_progress))
            if self.killed is True or err is True:
                break
        return err

    def calc_standardized_anomaly(self):
        '''
        Calculate standardized anomaly.
        Because blstd is an integer type, we protect against divide
        by zero by adding 0.1.
        If no comp2:
            output = (comp1 - blAvg + 0.1) / (blstd + 0.1) * 100
        If comp2:
            tempoutput1 = (comp1 - blAvg + 0.1) / (blstd + 0.1) * 100
            tempoutput2 = (comp2 - blAvg + 0.1) / (blstd + 0.1) * 100
            output = tempoutput1 - tempoutput2
        The two temporary outputs are written to timestamped files in the
        output folder.
        returns(boolean) err - True if error, else False.
        '''
        def save(path, values):
            # every output shares the geometry captured by calc_seasonal_sums
            return g_util.write_file(path, values,
                                     self.col_ct, self.row_ct, self.geoxfrm,
                                     self.gd_data_type)

        def standardize(comp_vals):
            # scaled departure from the baseline average, with masked pixels
            # replaced by the dataset missing value
            scaled = (FACTOR_P * (comp_vals - avg_vals + FACTOR_1) /
                      (std_vals + FACTOR_1))
            return np.ma.masked_array(
                scaled, mask=combined_mask,
                fill_value=self.ds_dic['DATAMISSINGVALUE']).filled()

        stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
        mask_inputs = [self.comp1_dst_file, self.all_years_std,
                       self.baseline_avg]
        if self.comp2_years_list:
            comp1_out = os.path.join(
                self.output_path,
                'tempA_' + stamp + self.ds_dic["DATASUFFIX"])
            comp2_out = os.path.join(
                self.output_path,
                'tempB_' + stamp + self.ds_dic["DATASUFFIX"])
            mask_inputs.append(self.comp2_dst_file)
        else:
            comp1_out = self.dst_filename

        # the cube is only built so the combined mask can be derived
        cube = g_util.get_data_cube(mask_inputs, self.np_data_type)
        combined_mask = g_util.get_data_cube_and_region_mask(
            cube, self.mask_file, self.ds_dic['DATAMISSINGVALUE'])
        del cube

        comp1_vals = g_util.extract_raster_array(self.comp1_dst_file)
        std_vals = g_util.extract_raster_array(self.all_years_std)
        avg_vals = g_util.extract_raster_array(self.baseline_avg)

        comp1_std_anom = standardize(comp1_vals)
        err = save(comp1_out, comp1_std_anom)

        if err is False and self.comp2_years_list:
            comp2_vals = g_util.extract_raster_array(self.comp2_dst_file)
            comp2_std_anom = standardize(comp2_vals)
            err = save(comp2_out, comp2_std_anom)
            if err is False:
                final_diff = g_util.calc_masked_array_diff(
                    comp1_std_anom, comp2_std_anom,
                    combined_mask, nd_val=self.ds_dic['DATAMISSINGVALUE'])
                err = save(self.dst_filename, final_diff)

        self.curr_progress += self.step
        self.progress.emit(int(self.curr_progress))
        QgsMessageLog.logMessage(u'Standardized Anomaly Complete', level=Qgis.Info)
        return err

    def kill(self):
        '''
        Set the kill flag.
        The flag is checked by run after each processing step and by
        calc_seasonal_sums after each year, so the abort takes effect once
        the step in progress has finished.
        '''
        self.killed = True

    def run(self):
        '''
        Run function for composites worker object.
        Builds the seasonal sums, then the baseline and composite averages,
        then the selected analysis product.  Progress is emitted along the
        way.  A user abort ends with an "aborted" result tuple, any other
        exception is forwarded through the error signal, and finished is
        emitted last with the result tuple (None after an error).
        '''
        def stop_if_needed(status):
            # a user abort takes priority over an error status
            if self.killed is True:
                raise KeyboardInterrupt
            if status is True:
                raise RuntimeError

        def report_progress():
            self.progress.emit(int(self.curr_progress))

        result = None
        QgsMessageLog.logMessage(u'Starting composites thread',
                                 level=Qgis.Info)
        try:
            needs_baseline = self.analysis != u'Average'
            if needs_baseline:
                year_ct, pad = len(self.baseline_years_list), 5
            else:
                year_ct, pad = len(self.comp1_years_list), 1
            self.step = int(90.0 / year_ct + pad)
            if self.step == 0:
                self.step = 1
            self.curr_progress = 0
            report_progress()

            # Step 1 - calculate the required seasonal sums
            status = self.calc_seasonal_sums()
            stop_if_needed(status)
            report_progress()
            QgsMessageLog.logMessage(u'Seasonal sums complete',
                                     level=Qgis.Info)

            # Step 2 - baseline average, not needed for a plain average
            if needs_baseline:
                status = self.calc_average(self.baseline_years_list,
                                           self.baseline_avg)
                stop_if_needed(status)
                QgsMessageLog.logMessage(u'Baseline average complete',
                                         level=Qgis.Info)
            report_progress()

            # Step 3 - composite averages
            status = self.calc_average(self.comp1_years_list,
                                       self.comp1_dst_file)
            stop_if_needed(status)
            report_progress()
            if self.comp2_years_list:
                status = self.calc_average(self.comp2_years_list,
                                           self.comp2_dst_file)
                stop_if_needed(status)
                QgsMessageLog.logMessage(u'Composite 2 avg complete',
                                         level=Qgis.Info)
            report_progress()

            # Step 4 - the selected analysis product
            if self.analysis == self.comp_analysis_list[1]:
                status = self.calc_pct_of_average()
            elif self.analysis == self.comp_analysis_list[2]:
                status = self.calc_anomaly()
            elif self.analysis == self.comp_analysis_list[3]:
                status = self.calc_all_available_years_std_dev_for_period()
                if status is True:
                    raise RuntimeError
                status = self.calc_standardized_anomaly()
            stop_if_needed(status)

            if self.killed is False:
                self.progress.emit(100)
                result = (0, "Composites complete")
        except KeyboardInterrupt:
            # exit with appropriate message on killed
            self.progress.emit(0)
            result = (0, u"Composites aborted by user")
        except BaseException as exc:
            # forward any exceptions upstream
            self.error.emit(exc, u'Unspecified error in Composites')
        self.finished.emit(result)

    # emitted at the end of run with the result tuple, or None when an
    # exception was reported through the error signal
    finished = QtCore.pyqtSignal(object)

    # emitted from run with the caught exception and a message string
    error = QtCore.pyqtSignal(Exception, str)

    # emitted with the current progress value as an int
    progress = QtCore.pyqtSignal(int)
