'''
/***************************************************************************
Name		     : validate_rfe_worker.py
Description      : BASIICS Validate RFE for Fews_Tools Plugin
This code adapted from: https://snorfalorpagus.net/blog/2013/12/07/
                             multithreading-in-qgis-python-plugins/
copyright   : (C) 2020-2023 by FEWS
email       : minxuansun@contractor.usgs.gov
Created:    : 02/07/2020
Modified    : 06/08/2020 cholen - Remove tab from stats prints
              06/18/2020 cholen - Replaced sample_dates with bat_dic element
              06/23/2020 cholen - Get cumulative lists outside of loop
              07/13/2020 cholen - Add cell size adj to raster outputs
              07/14/2020 cholen - Adjust error
              08/29/2020 cholen - Only display the last 3 periods
              09/29/2020 cholen - Remove redundant calls to get grid vals
              10/23/2020 cholen - Remove masking and process ds extents
              12/03/2020 cholen - Handle os error
              09/29/2021 cholen - Update to use delimiter in stats headings
              01/18/2022 cholen - New gdal utils, refactor.
              02/23/2022 cholen - New gdal utils function used.
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''
import os

from PyQt5 import QtCore
from qgis.core import QgsMessageLog, Qgis

from fews_tools.utilities import basiics_batch_utilities as b_util
from fews_tools.utilities import geoclim_gdal_utilities as g_util
from fews_tools.utilities import geoclim_utilities as util
from fews_tools.utilities import geoclim_qgs_utilities as qgs_util


class ValidateRFEWorker(QtCore.QObject):
    '''
    Worker for the BASIICS Validate RFE process.

    Follows the QGIS worker pattern referenced in the module header:
    run() does the work and reports back only through the progress,
    error and finished signals; kill() sets a flag that run() checks
    between steps and answers by aborting.
    '''
    def __init__(self, bat_dic, wrksp_setup):
        '''
        init
        params(dic) - bat_dic - batch settings.  NOTE: modified in place,
                      the interpolation settings below are forced to the
                      values validation needs.
        params(dic) - wrksp_setup - workspace setup (stored only; not read
                      by this class).
        '''
        super().__init__()
        self.bat_dic = bat_dic
        self.wrksp_setup = wrksp_setup
        # region station/data files written by b_util.get_stations;
        # run() deletes them when it ends
        self.region_stns_filename = ''
        self.region_data_filename = ''
        self.stn_2_stn_list = None
        self.split_csv_file_list = None
        # per-date stats csv most recently written
        self.dst_stats1_file = ''
        # station dictionary for the sample date being processed
        self.stn_dic = None
        # set by kill(); checked by run() between steps
        self.killed = False
        # make sure bat_dic members are correct for validation
        self.bat_dic['wt_power'] = 2.0
        self.bat_dic['min_stations'] = 1
        self.bat_dic['max_stations'] = 10
        self.bat_dic['search_radius'] = 500
        self.bat_dic['fuzz_factor'] = 1
        self.ds_params = None
        # progress-bar increment per step and running total (percent)
        self.step = 5.0
        self.curr_progress = 0.0

    def __write_cumulative_stats__(self, sample_date):
        '''
        Append one line per station for sample_date to the cumulative
        stats file (bat_dic['stats_out_file']).  Each line holds the
        station name, sample date, longitude, latitude, station value,
        interpolated-at-station value and grid value, separated by
        bat_dic['delimiter'].  The column headings must already have been
        written.  Failures are logged (with their cause) rather than
        raised, so a bad date does not abort the whole run.
        params(string) - sample_date -  Sample date.
        '''
        spc = self.bat_dic['delimiter']
        try:
            with open(self.bat_dic['stats_out_file'], 'a') as stats:
                for key, stn in self.stn_dic.items():
                    stats.write(spc.join([
                        key,
                        sample_date,
                        str(round(stn['Longitude'], 3)),
                        str(round(stn['Latitude'], 3)),
                        str(round(stn['Stn_val'], 2)),
                        str(round(stn['Intrpltd_stn_val'], 2)),
                        str(round(stn['Grid_val'], 2))]) + '\n')
        except Exception as err:
            QgsMessageLog.logMessage(
                'Exception - __write_cumulative_stats__ failed: ' +
                str(err),
                level=Qgis.Critical)

    def run(self):
        '''
        run method - execute the whole Validate RFE process.
        Steps 1-4 prepare the station data, step 5 writes the outputs for
        each sample date, and the cumulative graphs/statistics are
        written last.  Always ends by emitting finished with
        (0, "Validate RFE complete"), (0, "Validate RFE aborted by user"),
        or None after an error has been forwarded on the error signal.
        The region station/data files are removed however run() ends.
        '''
        ret_tuple = None
        try:
            QgsMessageLog.logMessage('Begin ValidateRFE',
                                     level=Qgis.Info)
            good_files = self.bat_dic['good_files_list']
            if not good_files:
                raise ValueError('Validate RFE: no input files to process')
            # progress increment: 65% spread over the input files
            # (update_progress caps the running total at 90)
            self.step = 65.0 / len(good_files)
            self.curr_progress = 0
            self.ds_params = qgs_util.extract_raster_file_params(
                good_files[0])
            # Steps 1-4
            self._prepare_station_data()
            self._write_stats_heading()
            reg_exp = b_util.get_datestring_reg_expression(self.bat_dic)
            # Step 5 - loop through sample dates; dates_list is assumed to
            # be parallel to good_files_list (indexed by position)
            for idx, sample in enumerate(self.bat_dic['dates_list']):
                self._process_sample(idx, sample, reg_exp)
                if self.killed:
                    raise KeyboardInterrupt
                self.update_progress()
            self.progress.emit(90)

            # printout cumulative graphics and stats
            self._write_cumulative_results()
            if self.killed:
                raise KeyboardInterrupt
            self.progress.emit(100)
            ret_tuple = (0, "Validate RFE complete")
        # exit with appropriate message on killed (KeyboardInterrupt)
        except KeyboardInterrupt:
            self.progress.emit(0)
            ret_tuple = (0, u"Validate RFE aborted by user")
        # forward any exceptions upstream; run() is the worker's top level
        except BaseException as exc:
            self.error.emit(exc, u"Unspecified error in Validate RFE")
        finally:
            self._remove_region_files()
        self.finished.emit(ret_tuple)

    def _finish_step(self, err, fail_msg):
        '''
        Common end-of-step handling for steps 1-3.
        params(bool) - err - error flag returned by the b_util call.
        params(string) - fail_msg - message for the RuntimeError.
        Raises RuntimeError(fail_msg) if err is True, KeyboardInterrupt
        if the user killed the worker; otherwise advances the progress.
        '''
        if err is True:
            raise RuntimeError(fail_msg)
        if self.killed:
            raise KeyboardInterrupt
        self.update_progress()

    def _prepare_station_data(self):
        '''
        Steps 1-4: extract the region station/data files, build the
        station-to-station lookup list, split the region data file by
        date and compute the fuzzy distance.  Results are stored on self.
        '''
        # Step 1 - extract regions and data within selected extents.
        self.region_stns_filename, self.region_data_filename, err =\
            b_util.get_stations(self.bat_dic, self.ds_params)
        self._finish_step(err, 'Extracting stations within extents failed')

        # Step 2 - find closest stations to each station for lookup
        # uses the region file created in Step 1
        self.stn_2_stn_list, err =\
            b_util.get_station_2_station_list(
                self.bat_dic, self.ds_params,
                self.region_stns_filename)
        self._finish_step(err, 'Building station-to-station list failed')

        # Step 3 - split up the region data file into one for each date
        self.split_csv_file_list, err =\
            b_util.split_station_file(
                self.bat_dic['curr_output_path'], self.bat_dic,
                self.region_data_filename)
        self._finish_step(err, 'Splitting station data by date failed')

        # Step 4 get fuzzy distance
        b_util.get_fuzzy_dist(self.bat_dic, self.ds_params)

    def _write_stats_heading(self):
        '''
        Create the cumulative stats file and write its column headings.
        '''
        delim = self.bat_dic['delimiter']
        headings = ('Name', 'FileName', 'Long', 'Lat', 'StnVal',
                    'InterpAtStnVal', 'GridVal')
        # open with 'w' so any pre-existing file gets overwritten
        with open(self.bat_dic['stats_out_file'], 'w') as stats:
            stats.write(delim.join(headings) + '\n')

    def _process_sample(self, index, sample, reg_exp):
        '''
        Step 5: produce all outputs for one sample date.
        params(int) - index - position of sample in bat_dic['dates_list'],
                      used to pick the matching good_files_list entry.
        params(string) - sample - sample date.
        params(object) - reg_exp - value from
                         b_util.get_datestring_reg_expression.
        Returns early, writing nothing for the date, when there is no
        valid station data.  Raises RuntimeError when the date's csv or
        grid file check fails.
        '''
        out_path = self.bat_dic['curr_output_path']
        # Step 5-1
        name_base = self.bat_dic['output_prefix'] + sample
        # Step 5-2 verify that sample date csv file exists
        sample_date_csv_filename, err = b_util.get_csv(
            self.split_csv_file_list, sample)
        if err is True:
            raise RuntimeError('Station csv check failed for ' + sample)
        # Step 5-3 verify that grid file exists
        rain_grid_filename, err = b_util.get_grid(
            self.bat_dic, sample, reg_exp)
        if err is True:
            raise RuntimeError('Grid file check failed for ' + sample)
        # Step 5-4 Start filling in station dictionary
        self.stn_dic = b_util.build_station_dic(
            self.bat_dic, sample_date_csv_filename,
            self.bat_dic['good_files_list'][index])
        # handle cases where no valid data exists for date
        if not self.stn_dic:
            return
        # Step 5-5 Interpolate station values using 'ordinary' type
        b_util.cointerpolate_stations_idw(
            self.bat_dic, self.stn_dic,
            'Ordinary', self.stn_2_stn_list,
            'Stn_val', 'Intrpltd_stn_val', 'Xvalidated_stn_val')
        QgsMessageLog.logMessage('Interpolated station values',
                                 level=Qgis.Info)

        # Step 5-7 remove any of the stations that show missing val
        bad_vals = [b_util.MINSHORT,
                    self.bat_dic['ds_dic']['DATAMISSINGVALUE']]
        b_util.remove_nodata_keys(self.stn_dic, 'Stn_val', bad_vals)
        if not self.stn_dic:
            # every station was missing; the station extents used for the
            # clip below would be empty (min()/max() would fail)
            QgsMessageLog.logMessage(
                sample + ' has no valid station values, skipped',
                level=Qgis.Info)
            return
        # Step 5-8 Create station shapefile for date
        dst_shp_filename = os.path.join(out_path, name_base + '_stn.shp')
        qgs_util.create_station_shapefile(self.stn_dic, dst_shp_filename)
        QgsMessageLog.logMessage(
            dst_shp_filename + ' shapefile complete',
            level=Qgis.Info)

        # Step 5-9 Output initial stats
        self._write_sample_stats(
            os.path.join(out_path, name_base + '.csv'))

        # Step 5-10 Plot stngrid graph
        b_util.plot_graph_dic(
            self.stn_dic, 'Stn_val', 'Grid_val',
            'Station Value', 'Original Grid Value',
            'Comparison between Station and Grid:ppt' + sample,
            os.path.join(out_path, name_base + '.stngrid_graph.jpg'))

        # Step 5-12 - print stats validation
        self.__write_cumulative_stats__(sample)
        # clip the rainfall file so it's available for the jpg output
        self._clip_grid_to_stations(
            rain_grid_filename,
            os.path.join(out_path,
                         name_base + self.bat_dic['ds_dic']['DATASUFFIX']))
        QgsMessageLog.logMessage(sample + ' loop complete',
                                 level=Qgis.Info)

    def _write_sample_stats(self, dst_stats_file):
        '''
        Step 5-9: write the station info and the least-squares statistics
        (measured and interpolated station values vs grid values) for the
        current date to dst_stats_file, replacing any existing file.
        params(string) - dst_stats_file - per-date stats csv path.
        Raises OSError (after logging) if the old file can't be removed.
        '''
        self.dst_stats1_file = dst_stats_file
        try:
            if os.path.exists(dst_stats_file):
                os.remove(dst_stats_file)
        except OSError:  # message use and re-raise the original error
            QgsMessageLog.logMessage(
                ('File open in another process:  ' + dst_stats_file),
                level=Qgis.Critical)
            raise
        b_util.print_station_info(self.bat_dic, self.stn_dic, dst_stats_file)

        missing_val = int(self.bat_dic['stn_missing_val'])
        for x_key, x_label in (('Stn_val', 'Measured station values'),
                               ('Intrpltd_stn_val',
                                'Interpolated station values')):
            stats_dic = util.weighted_least_squares_simple_linear_dic(
                self.stn_dic, x_key, 'Grid_val', None, missing_val)
            title_l = ('Statistical analysis comparing ' + x_label +
                       ' (X) with background grid values (Y) for ' +
                       dst_stats_file)
            b_util.print_least_squares_stats(
                stats_dic, dst_stats_file, title_l, True)

    def _clip_grid_to_stations(self, src_grid, dst_file):
        '''
        Clip src_grid to the bounding box of the current stations and
        write it to dst_file, then delete the .aux.xml side file.
        params(string) - src_grid - rainfall grid for the date.
        params(string) - dst_file - clipped output raster.
        '''
        lons = [stn['Longitude'] for stn in self.stn_dic.values()]
        lats = [stn['Latitude'] for stn in self.stn_dic.values()]
        bbox = {"MinLongitude": min(lons),
                "MaxLatitude": max(lats),
                "MaxLongitude": max(lons),
                "MinLatitude": min(lats)}
        # add 1 to the extents to get just past stns extents
        g_util.clip_raster_to_bbox(
            src_grid, dst_file, bbox,
            str(self.bat_dic['ds_dic']['DATAMISSINGVALUE']), 1)
        QgsMessageLog.logMessage(dst_file + ' created', level=Qgis.Info)
        # get rid of the .aux.xml file, not needed (best effort)
        aux_file = dst_file + '.aux.xml'
        try:
            if os.path.isfile(aux_file):
                os.remove(aux_file)
        except OSError:
            pass

    def _write_cumulative_results(self):
        '''
        Read back the cumulative stats file, plot the cumulative station
        vs grid and interpolated vs grid graphs, and append the
        cumulative least-squares statistics to the stats file.
        Raises OSError if the file holds no station rows.
        '''
        out_path = self.bat_dic['curr_output_path']
        cum_stn_vals = []
        cum_interp_vals = []
        cum_grid_vals = []
        for entry in b_util.read_csv_file(self.bat_dic['stats_out_file'],
                                          self.bat_dic['delimiter'], 1):
            # columns: Name, FileName, Long, Lat, StnVal,
            #          InterpAtStnVal, GridVal
            cum_stn_vals.append(float(entry[4]))
            cum_interp_vals.append(float(entry[5]))
            cum_grid_vals.append(float(entry[6]))

        if not cum_interp_vals:
            QgsMessageLog.logMessage('No valid stations found',
                                     level=Qgis.Info)
            raise OSError('No valid stations found')

        b_util.plot_graph_list(
            cum_stn_vals, cum_grid_vals,
            'Point Station Value', 'RFE Grid Value',
            'Comparison between Point Station and Grid:stats',
            os.path.join(out_path, 'stats.stngrid_graph.jpg'))
        b_util.plot_graph_list(
            cum_interp_vals, cum_grid_vals,
            'At Station Interpolated Value', 'RFE Grid Value',
            'Comparison between At Station Interpolated and Grid:stats',
            os.path.join(out_path, 'stats.intstngrid_graph.jpg'))

        # print final stats to cumulative stats file; missing value is
        # converted the same way as for the per-date statistics
        missing_val = int(self.bat_dic['stn_missing_val'])
        comparisons = (
            (cum_stn_vals,
             'Comparison statistics for Measured Station '
             'Value (X) against Original Grid Value (Y)\n'),
            (cum_interp_vals,
             'Crossvalidated Comparison statistics for '
             'Interpolated Station Value (X) against '
             'Original Grid Value (Y)\n'))
        for x_vals, title_l in comparisons:
            stats_dic = util.weighted_least_squares_simple_linear_list(
                x_vals, cum_grid_vals, None, missing_val)
            b_util.print_least_squares_stats(
                stats_dic, self.bat_dic['stats_out_file'], title_l, True)

    def _remove_region_files(self):
        '''
        Best-effort removal of the region station/data files created in
        step 1.  Each file is tried on its own; unset paths are skipped
        and OSErrors (missing or locked file) are ignored.
        '''
        for filename in (self.region_stns_filename,
                         self.region_data_filename):
            if not filename:
                continue
            try:
                os.remove(filename)
            except OSError:
                pass

    def kill(self):
        '''
        Kill method - request that run() abort at its next check.
        '''
        self.killed = True

    def update_progress(self):
        '''
        Helper for progress bar updates: advance by self.step, capped at
        90 percent (run() emits the final 90/100 itself), and emit it.
        '''
        self.curr_progress = min(self.curr_progress + self.step, 90)
        self.progress.emit(int(self.curr_progress))

    # finished(ret_tuple): (0, message) on completion/abort, None on error
    finished = QtCore.pyqtSignal(object)

    # error(exception, description) forwarded from run()
    error = QtCore.pyqtSignal(Exception, str)

    # progress(percent complete)
    progress = QtCore.pyqtSignal(int)
