'''
/***************************************************************************
Name		 : interpolate_statins_worker.py
Description  : BASIICS Interpolate Stations for Fews_Tools Plugin
This code adapted from: https://snorfalorpagus.net/blog/2013/12/07/
                             multithreading-in-qgis-python-plugins/
copyright    : (C) 2020-2023 by FEWS
email        : minxuansun@contractor.usgs.gov
created:     : 02/12/2020
last modified: 06/08/2020 cholen - Remove tab from stats prints
               06/18/2020 cholen - Replaced sample_dates with a bat_dic element
               06/23/2020 cholen - Get cumulative lists outside of loop
               07/13/2020 cholen - Add cell size adj to raster outputs
               07/14/2020 cholen - Adjust error
               08/29/2020 cholen - Only display the last 3 periods
               10/23/2020 cholen - Remove masking and process ds extents
               12/03/2020 cholen - Handle os error
               09/29/2021 cholen - Update to use delimiter in stats headings
               01/18/2022 cholen - New gdal utils, refactor.
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''
import os

from PyQt5 import QtCore
from qgis.core import QgsMessageLog, Qgis

from fews_tools.utilities import basiics_batch_utilities as b_util
from fews_tools.utilities import geoclim_utilities as util
from fews_tools.utilities import geoclim_qgs_utilities as qgs_util


class InterpolateStationsWorker(QtCore.QObject):
    '''
    Worker object for the BASIICS Interpolate Stations process.

    run() does all of the work and reports through the progress, error
    and finished signals; kill() requests a cooperative stop that run()
    checks between processing steps.
    '''
    def __init__(self, bat_dic, wrksp_setup):
        '''
        init
        params(dic) - bat_dic - batch settings.  Keys read by this class
                      include 'good_files_list', 'dates_list',
                      'output_prefix', 'curr_output_path', 'delimiter',
                      'stats_out_file', 'output_stats_flag', 'interp_type',
                      'co_interp_type' and 'stn_missing_val'.
        params(obj) - wrksp_setup - workspace setup; only its
                      get_temp_data_path() method is used here.
        '''
        QtCore.QObject.__init__(self)
        self.bat_dic = bat_dic
        self.wrksp_setup = wrksp_setup
        # station dictionary for the sample date currently being processed
        self.stn_dic = None
        # set by kill(), polled by run() between steps
        self.killed = False
        self.err_msg = ''
        self.stn_2_stn_list = []
        self.stn_2_pixel_list = []
        self.region_stns_filename = ''
        self.region_data_filename = ''
        self.split_csv_file_list = None
        self.ds_params = None
        # progress bar increment per step, recomputed in run()
        self.step = 5.0
        self.curr_progress = 0.0

    def __write_cumulative_stats__(self, sample_date):
        '''
        Append one row per station for a sample date to the cumulative
        stats file (bat_dic['stats_out_file']).  Each row holds the
        station name, sample date, longitude, latitude, station value,
        interpolated station value and cross-validated interpolated
        value, separated by bat_dic['delimiter'].  The column headings
        must already have been written to the file.
        Failures are logged, not raised (best effort).
        params(string) - sample_date -  Sample date.
        '''
        spc = self.bat_dic['delimiter']
        try:
            with open(self.bat_dic['stats_out_file'], 'a') as stats:
                for key, stn in self.stn_dic.items():
                    row = [key, sample_date,
                           str(round(stn['Longitude'], 3)),
                           str(round(stn['Latitude'], 3)),
                           str(round(stn['Stn_val'], 2)),
                           str(round(stn['Intrpltd_stn_val'], 2)),
                           str(round(stn['Xvalidated_stn_val'], 2))]
                    stats.write(spc.join(row) + '\n')
        except Exception as exc:
            QgsMessageLog.logMessage(
                'Exception - __write_cumulative_stats__ failed: ' + str(exc),
                level=Qgis.Critical)

    def _finish_step(self, err=False):
        '''
        Common end-of-step handling: fail on a helper error, stop if a
        kill was requested, otherwise advance the progress bar.
        params(bool) - err - error flag returned by a b_util helper.
        '''
        if err is True:
            raise RuntimeError('Interpolate stations step failed')
        if self.killed is True:
            raise KeyboardInterrupt
        self.update_progress()

    def _prepare_station_lookups(self):
        '''
        Steps 1-5: extract the region station/data files, build the
        station-to-station and station-to-pixel lookups, split the region
        data per date and get the fuzzy distance.
        Raises RuntimeError when a helper reports an error and
        KeyboardInterrupt when the worker has been killed.
        '''
        # Step 1 - extract regions and data within selected extents.
        self.region_stns_filename, self.region_data_filename, err = \
            b_util.get_stations(self.bat_dic, self.ds_params)
        self._finish_step(err)

        # Step 2 - find closest stations to each station for lookup,
        # uses the region file created in Step 1
        self.stn_2_stn_list, err = b_util.get_station_2_station_list(
            self.bat_dic, self.ds_params, self.region_stns_filename)
        self._finish_step(err)

        # Step 3 - find the closest stations to pixels for look up,
        # uses the region file created in Step 1
        self.stn_2_pixel_list, err = b_util.get_station_2_pixel_list(
            self.bat_dic, self.ds_params, self.region_stns_filename)
        self._finish_step(err)

        # Step 4 - split up the region data file into one for each date
        self.split_csv_file_list, err = b_util.split_station_file(
            self.bat_dic['curr_output_path'], self.bat_dic,
            self.region_data_filename)
        self._finish_step(err)

        # Step 5 - get fuzzy distance
        b_util.get_fuzzy_dist(self.bat_dic, self.ds_params)

    def _write_stats_header(self):
        '''
        Create (or overwrite) the cumulative stats file and write the
        column headings that the __write_cumulative_stats__ rows follow.
        '''
        headings = ['Name', 'FileName', 'Long', 'Lat', 'StnVal',
                    'IntrpStnVal', 'XValidatedIntrpStnVal']
        # open with 'w' so any pre-existing file gets overwritten
        with open(self.bat_dic['stats_out_file'], 'w') as stats:
            stats.write(self.bat_dic['delimiter'].join(headings) + '\n')

    def _start_sample_stats_file(self, dst_stats1_file):
        '''
        Step 6-7: write the initial per-date stats file (least squares
        stats for station vs interpolated values, then station info).
        params(string) - dst_stats1_file - per-date stats file path.
        Raises OSError if an existing file cannot be removed.
        '''
        try:
            os.remove(dst_stats1_file)
        except FileNotFoundError:
            pass  # nothing to overwrite
        except OSError:  # message user and re-raise the original error
            QgsMessageLog.logMessage(
                'File open in another process:  ' + dst_stats1_file,
                level=Qgis.Critical)
            raise
        stats_dic = util.weighted_least_squares_simple_linear_dic(
            self.stn_dic, 'Stn_val', 'Intrpltd_stn_val',
            None, int(self.bat_dic['stn_missing_val']))

        # NOTE(review): this title describes an included-vs-excluded
        # comparison, but the fields compared are Stn_val (X) and
        # Intrpltd_stn_val (Y) - confirm which one is intended.
        title_l = ('Cross-validated statistical analysis '
                   'comparing IDW-interpolated value with '
                   'station included (X) and station excluded '
                   '(Y) for ' + dst_stats1_file)
        b_util.print_least_squares_stats(
            stats_dic, dst_stats1_file, title_l, False)
        b_util.print_station_info(
            self.bat_dic, self.stn_dic, dst_stats1_file)

    def _process_sample(self, sample):
        '''
        Step 6 for one sample date.  Builds the station dictionary,
        interpolates and cross-validates the station values, then writes
        the station shapefile, the cross-validation graph, the cumulative
        stats rows, the interpolated raster and, when
        bat_dic['output_stats_flag'] is set, the per-date stats file.
        Dates without a station csv file or without valid station data
        are skipped.
        params(string) - sample - Sample date.
        '''
        out_path = self.bat_dic['curr_output_path']
        # Step 6-1
        name_base = self.bat_dic['output_prefix'] + sample
        # Step 6-2 verify that sample date csv file exists
        sample_date_csv_filename, err = b_util.get_csv(
            self.split_csv_file_list, sample)
        if err is True:
            QgsMessageLog.logMessage(
                'No station data file found for ' + sample + ', skipping',
                level=Qgis.Warning)
            return

        # Step 6-4 Start filling in station dictionary
        self.stn_dic = b_util.build_station_dic(
            self.bat_dic, sample_date_csv_filename,
            self.bat_dic['good_files_list'][0])
        # handle cases where no valid data exists for date
        if not self.stn_dic:
            return

        # Step 6-5 Interpolate station values using bat_dic['interp_type']
        b_util.cointerpolate_stations_idw(
            self.bat_dic, self.stn_dic,
            self.bat_dic['interp_type'],
            self.stn_2_stn_list,
            'Stn_val', 'Intrpltd_stn_val', 'Xvalidated_stn_val')
        QgsMessageLog.logMessage('Interpolated station values',
                                 level=Qgis.Info)

        # Step 6-6 Create station shapefile for date
        dst_shp_filename = os.path.join(out_path, name_base + '_stn.shp')
        qgs_util.create_station_shapefile(self.stn_dic, dst_shp_filename)
        QgsMessageLog.logMessage(dst_shp_filename + ' shapefile complete',
                                 level=Qgis.Info)

        # Step 6-7 Output initial stats if flagged
        dst_stats1_file = os.path.join(out_path, name_base + '.stat.txt')
        if self.bat_dic['output_stats_flag']:
            self._start_sample_stats_file(dst_stats1_file)

        # Step 6-8 - interpolate the array
        intrpltd_array = b_util.interpolate_array_idw(
            self.bat_dic, self.ds_params,
            self.stn_dic, self.bat_dic['interp_type'],
            self.stn_2_pixel_list, 'Stn_val')

        # Step 6-9 Plot cross validation
        dst_jpg_file = os.path.join(
            out_path, name_base + '.crossval_graph.jpg')
        title = ('Cross Validation:ppt' + sample)
        b_util.plot_graph_dic(
            self.stn_dic,
            'Intrpltd_stn_val', 'Xvalidated_stn_val',
            'IDW Interpolated Station Value',
            'CrossValidated IDW Interpolated Station Value',
            title, dst_jpg_file)

        # Step 6-11 - cumulative rows are always written because run()
        # always creates the cumulative stats file and always reads it
        # back; only the per-date stats depend on the flag
        self.__write_cumulative_stats__(sample)
        if self.bat_dic['output_stats_flag']:
            stats_dic = util.weighted_least_squares_simple_linear_dic(
                self.stn_dic, 'Intrpltd_stn_val',
                'Xvalidated_stn_val', None,
                self.bat_dic['stn_missing_val'])
            title_l = ('Cross-validated statistical analysis '
                       'comparing interpolated station value (X) '
                       'with Cross-validated BASIICS value (Y) '
                       'for ' + dst_stats1_file)
            b_util.print_least_squares_stats(
                stats_dic, dst_stats1_file, title_l, True)

        temp_path = self.wrksp_setup.get_temp_data_path()
        b_util.create_output_raster_file(
            self.bat_dic, self.ds_params, intrpltd_array,
            temp_path, name_base)

        QgsMessageLog.logMessage(name_base + ' raster file complete',
                                 level=Qgis.Info)
        QgsMessageLog.logMessage(sample + ' loop complete',
                                 level=Qgis.Info)

    def _write_cumulative_outputs(self):
        '''
        Steps 7-8: read back every row of the cumulative stats file, plot
        the cumulative cross-validation graph and append the cumulative
        least squares stats to the stats file.
        Raises IOError when the stats file holds no station rows.
        '''
        cum_interp_vals = []
        cum_xval_vals = []
        # the final argument (1) presumably skips the heading line -
        # confirm in b_util.read_csv_file
        data_list = b_util.read_csv_file(
            self.bat_dic['stats_out_file'],
            self.bat_dic['delimiter'], 1)
        # columns 5 and 6 are IntrpStnVal and XValidatedIntrpStnVal
        for entry in data_list:
            cum_interp_vals.append(float(entry[5]))
            cum_xval_vals.append(float(entry[6]))
        del data_list

        if not cum_interp_vals:
            QgsMessageLog.logMessage('No valid stations found',
                                     level=Qgis.Info)
            raise IOError('No valid stations found')

        # Step 7 - print out cumulative cross validation jpg
        dst_jpg_file = os.path.join(self.bat_dic['curr_output_path'],
                                    'stats.crossval_graph.jpg')
        b_util.plot_graph_list(
            cum_interp_vals, cum_xval_vals,
            'IDW Interpolated Station Value',
            'Cross Validated IDW Interpolated Removed',
            'Cross Validation:stats', dst_jpg_file)

        # Step 8 - print final stats to cumulative stats file
        stats_dic = util.weighted_least_squares_simple_linear_list(
            cum_interp_vals, cum_xval_vals, None,
            self.bat_dic['stn_missing_val'])
        title_l = ('Crossvalidated Comparison statistics for ' +
                   'Interpolated Station Value (X) against ' +
                   'Cross-Validated Interpolated Value(Y)')
        b_util.print_least_squares_stats(
            stats_dic, self.bat_dic['stats_out_file'], title_l, True)

    def _remove_region_files(self):
        '''
        Best-effort removal of the intermediate region files created in
        Step 1; each file is attempted independently.
        '''
        for filename in (self.region_stns_filename,
                         self.region_data_filename):
            try:
                os.remove(filename)
            except OSError:
                pass

    def run(self):
        '''
        run method
        Steps 1-5 prepare the station lookups, step 6 processes every
        date in bat_dic['dates_list'] and steps 7-8 write the cumulative
        graph and stats.  Always finishes by emitting finished(ret_tuple):
        (0, message) on completion or user abort, None after an error
        (the error itself is sent through the error signal).
        '''
        ret_tuple = None
        try:
            QgsMessageLog.logMessage('Beginning interpolate stations',
                                     level=Qgis.Info)
            if not self.bat_dic['good_files_list']:
                raise ValueError('No input raster files to process')
            # figure progress step for bar
            self.step = 65.0 / len(self.bat_dic['good_files_list'])
            self.curr_progress = 0
            # only one cointerp type is available
            self.bat_dic['co_interp_type'] = 'ratio_and_anom'

            self.ds_params = qgs_util.extract_raster_file_params(
                self.bat_dic['good_files_list'][0])
            self._prepare_station_lookups()

            self._write_stats_header()
            # Step 6 - sample date loop
            for sample in self.bat_dic['dates_list']:
                self._process_sample(sample)
                self._finish_step()

            self._write_cumulative_outputs()
            # cleanup
            self._remove_region_files()
            if self.killed is True:
                raise KeyboardInterrupt
            self.progress.emit(100)
            ret_tuple = (0, "Interpolating stations complete")
        # exit with appropriate message on killed (KeyboardInterrupt)
        except KeyboardInterrupt:
            self.progress.emit(0)
            ret_tuple = (0, u"Interpolating stations aborted by user")
        # forward any other exception upstream
        except Exception as exc:
            self.error.emit(exc, u"Unspecified error in Interpolating stations")
        finally:
            self.finished.emit(ret_tuple)

    def kill(self):
        '''
        Request a stop; run() aborts at its next kill check.
        '''
        self.killed = True

    def update_progress(self):
        '''
        Advance the progress bar by self.step, capped at 90 so that only
        a completed run reports 100.
        '''
        self.curr_progress += self.step
        if self.curr_progress > 90:
            self.curr_progress = 90
        self.progress.emit(int(self.curr_progress))

    # emitted once at the end of run(): (0, message) or None after an error
    finished = QtCore.pyqtSignal(object)

    # (exception, description) forwarded from run()
    error = QtCore.pyqtSignal(Exception, str)

    # progress bar value, 0-100
    progress = QtCore.pyqtSignal(int)
