'''
/***************************************************************************
Name          :  extract_stats_worker.py
Description   :  Extract statistics worker class
copyright     :  (C) 2022 - 2023 by FEWS
email         :  minxuansun@contractor.usgs.gov
Created       :  02/16/2022 - cholen
Modified      :  03/25/2022 - cholen - Add tiff support
                 07/28/2023 - jhowton - Added update CSV

 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''
import os
import csv
import shutil

from PyQt5 import QtCore

from qgis.analysis import QgsZonalStatistics
from qgis.core import QgsMessageLog, Qgis
from qgis.core import QgsVectorLayer, QgsRasterLayer, QgsWkbTypes


from fews_tools import fews_tools_config as config
from fews_tools.utilities import geoclim_gdal_utilities as g_util
from fews_tools.utilities import geoclim_qgs_utilities as qgs_util
from fews_tools.utilities import geoclim_utilities as util


class ExtractStatsWorker(QtCore.QObject):
    '''
    Worker class for the extract stats tool.

    For every raster in raster_dic one value is extracted per feature of
    the vector file: a zonal statistic for polygon data, or the pixel value
    under each point for single-part point data. The results are written to
    output_file as one CSV row per feature, optionally merged into an
    existing CSV. Progress, errors and completion are reported through the
    progress, error and finished signals declared at the end of the class.
    '''

    ROUNDING_POSITION = 5  # Round the result to 5th decimals

    # GUI statistic name -> field-name prefix checked by test_fieldname.
    # These appear to match the field names QgsZonalStatistics creates with
    # the " " column prefix used in this class - confirm if QGIS changes.
    Z_STATS_TYPES_DIC = {'Average': ' mean',
                         'Count': ' count',
                         'Maximum': ' max',
                         'Minimum': ' min',
                         'Median': ' median',
                         'Range': ' range',
                         'StdDev': ' stdev',
                         'Sum': ' sum'}

    def __init__(self,
                 orig_vector_path,
                 unique_field_name,
                 raster_dic,
                 temp_path,
                 output_file,
                 stat_type,
                 nd_val,
                 use_existing_csv,
                 existing_csv_path):
        QtCore.QObject.__init__(self)
        self.orig_vector_path = orig_vector_path
        self.unique_field_name = unique_field_name
        self.raster_dic = raster_dic
        self.temp_path = temp_path
        self.output_file = output_file
        self.analysis = stat_type
        self.nd_val = nd_val
        self.use_existing_csv = use_existing_csv
        self.existing_csv_path = existing_csv_path
        self.temp_shp = ""
        self.curr_progress = 0
        self.stat_type = QgsZonalStatistics.Mean
        self.killed = False
        self.step = 1

    def run(self):
        '''
        Run function for extract stats worker object.

        Copies the vector file into temp_path and extracts one value per
        feature for every raster in raster_dic. If use_existing_csv is set,
        the results are merged into existing_csv_path. The result is then
        written to output_file.

        Always emits finished: with (0, message) on completion, user abort
        or a ValueError (reported as a geometry problem), and with None
        after forwarding any other exception through the error signal.
        '''
        ret_tuple = None
        vector_layer = None
        try:
            QgsMessageLog.logMessage(u"Begin extract stats...",
                                     level=Qgis.Info)

            # max() avoids ZeroDivisionError for an empty raster_dic
            self.step = 90.0 / max(len(self.raster_dic), 1)

            self.curr_progress = 0
            self.progress.emit(int(self.curr_progress))
            self.create_temp_vector_file()
            vector_layer = QgsVectorLayer(self.temp_shp, "Shp", config.OGR)

            # getFeatures() returns a single-pass iterator: materialise it
            # once and reuse the list for the keys AND the geometry check
            features = list(vector_layer.getFeatures())
            feature_keys = [f[self.unique_field_name] for f in features]
            results = {key: [] for key in feature_keys}
            col_headings = ["Feature"] + [
                val["prefix"] + val["attr"]
                for val in self.raster_dic.values()]

            # Raises ValueError for multi-part point geometry
            point_features = self._is_point_layer(features)

            for val in self.raster_dic.values():
                # Honour kill() between rasters instead of only at the end
                if self.killed:
                    raise KeyboardInterrupt
                self._extract_raster_stat(vector_layer, val,
                                          point_features, results)
                self.update_progress()

            # An aborted run must not write (or overwrite) the output file
            if self.killed:
                raise KeyboardInterrupt

            # Header row, then one row per feature in layer feature order
            attribute_list = [tuple(col_headings)]
            attribute_list.extend(
                tuple([key] + results[key]) for key in feature_keys)

            if self.use_existing_csv:
                attribute_list = self._merge_existing_csv(attribute_list)

            self._write_csv(attribute_list)

            QgsMessageLog.logMessage(
                u"Completed:  " + self.output_file,
                level=Qgis.Info)
            self.progress.emit(100)
            ret_tuple = (0, "Extract Stats complete")
        # exit with appropriate message on killed (KeyboardInterrupt)
        except KeyboardInterrupt:
            self.progress.emit(0)
            ret_tuple = (0, u"Extract Stats aborted by user")
        except ValueError:
            self.progress.emit(0)
            ret_tuple = (0, u"Shapefile geometry problem")
        # forward any exceptions upstream
        except BaseException as exc:
            self.error.emit(exc, u"Unspecified error in Extract Stats")
        vector_layer = None
        self.finished.emit(ret_tuple)

    @staticmethod
    def _is_point_layer(features):
        '''
        Decide from the first feature whether the layer holds point data.

        All features are assumed to share the first feature's geometry type.

        params(list) - features - features of the temporary vector layer
        returns(bool) - True for single-part point geometry, else False
        raises ValueError - for multi-part point geometry (unsupported)
        '''
        if not features:
            return False
        geom = features[0].geometry()
        if geom.type() != QgsWkbTypes.PointGeometry:
            return False
        if not QgsWkbTypes.isSingleType(geom.wkbType()):
            raise ValueError("Multi-part point geometry is not supported")
        return True

    def _extract_raster_stat(self, vector_layer, raster_info,
                             point_features, results):
        '''
        Append one raster's value for every feature to results.

        Writes a copy of the raster into temp_path with nd_val set as
        nodata. The value is computed into a temporary field of
        vector_layer and read back per feature. Afterwards the temporary
        raster and the temporary field are removed.

        params(QgsVectorLayer) - vector_layer - temporary vector layer
        params(dict) - raster_info - raster_dic entry ("path", "attr", ...)
        params(bool) - point_features - True to sample pixel values
        params(dict) - results - feature key -> list of values (appended)
        '''
        # Build the temp name from the basename. str.replace() on the
        # source directory inserts temp_path between every character when
        # that directory is "" (bare file name).
        temp_filename = os.path.join(
            self.temp_path, os.path.basename(raster_info["path"]))
        g_util.set_nodata(raster_info["path"], temp_filename,
                          str(self.nd_val))
        raster_layer = QgsRasterLayer(temp_filename)

        stat_field = self._add_stat_field(vector_layer, raster_layer,
                                          raster_info["attr"],
                                          point_features)

        # Store stat for each feature by key
        for feature in vector_layer.getFeatures():
            results[feature[self.unique_field_name]].append(
                self._round_value(feature[stat_field]))

        # Release our reference to the layer before deleting its file
        raster_layer = None
        util.remove_raster_file(temp_filename)

        # Clean up the field from the vector layer
        vector_layer.dataProvider().deleteAttributes([
            vector_layer.fields().indexOf(stat_field)])
        vector_layer.updateFields()

    def _add_stat_field(self, vector_layer, raster_layer, attr,
                        point_features):
        '''
        Compute one value per feature into a new field of vector_layer.

        For point layers, add_grid_value_to_shapefile samples the pixel at
        each point, and the selected statistic is ignored. For other layers,
        the statistic chosen in the GUI is computed with QgsZonalStatistics.

        returns(str) - name of the field that was added
        raises RuntimeError - if no new field appeared on the layer
        '''
        before_fields = {f.name() for f in vector_layer.fields()}

        if point_features:
            qgs_util.add_grid_value_to_shapefile(
                vector_layer, raster_layer, attr)
        else:
            self.get_analysis()
            zone_stats = QgsZonalStatistics(vector_layer, raster_layer,
                                            " ", 1, self.stat_type)
            zone_stats.calculateStatistics(None)

        new_fields = [f.name() for f in vector_layer.fields()
                      if f.name() not in before_fields]
        if not new_fields:
            raise RuntimeError(
                "No statistics field was added for raster attribute "
                + str(attr))
        return new_fields[0]

    def _round_value(self, value):
        '''
        Round a statistic to ROUNDING_POSITION decimals.

        Values round() rejects are written as an empty CSV cell instead of
        aborting the whole run. An example is a NULL attribute, which
        features not covered by the raster presumably receive.
        '''
        try:
            return round(value, self.ROUNDING_POSITION)
        except TypeError:
            return ""

    def _merge_existing_csv(self, attribute_list):
        '''
        Merge attribute_list into the CSV at existing_csv_path.

        Rows are matched by position: data row i of the existing file
        receives data row i of the new results, so the existing CSV must
        come from the same vector file. Columns whose header already exists
        are overwritten, and unknown headers are appended as new columns.

        params(list) - attribute_list - header tuple followed by data rows
        returns(list) - merged rows as tuples, header first
        raises RuntimeError - if the numbers of data rows differ
        '''
        with open(self.existing_csv_path, 'r', newline='') as csvfile:
            # csv.reader yields [] for blank lines (e.g. a trailing newline)
            existing_csv_data = [row for row in csv.reader(csvfile) if row]

        if not existing_csv_data:
            # Empty file: nothing to merge into, write the new data as-is
            return attribute_list

        headers_existing = existing_csv_data[0]
        existing_rows = existing_csv_data[1:]
        headers_new = attribute_list[0]
        new_rows = attribute_list[1:]

        # Positional merge is only meaningful if the row counts agree;
        # previously extra new rows were silently dropped
        if len(existing_rows) != len(new_rows):
            raise RuntimeError(
                "Existing CSV {} has {} data rows but {} features were "
                "processed".format(self.existing_csv_path,
                                   len(existing_rows), len(new_rows)))

        # Pad short (ragged) rows so every existing header has a cell
        for row in existing_rows:
            row.extend([""] * (len(headers_existing) - len(row)))

        for new_idx, header in enumerate(headers_new):
            column = [row[new_idx] for row in new_rows]
            if header in headers_existing:
                # Header already present: overwrite that column's values
                col_idx = headers_existing.index(header)
                for row, value in zip(existing_rows, column):
                    row[col_idx] = value
            else:
                # New header: append it and its values as a new column
                headers_existing.append(header)
                for row, value in zip(existing_rows, column):
                    row.append(value)

        return [tuple(row) for row in existing_csv_data]

    def _write_csv(self, rows):
        '''
        Write rows to output_file, overwriting any existing file.
        '''
        with open(self.output_file, 'w', newline='') as csv_file:
            csv.writer(csv_file).writerows(rows)

    def test_fieldname(self, field_name):
        '''
        Return True if field_name starts with any Z_STATS_TYPES_DIC value.
        '''
        return any(field_name.startswith(val)
                   for val in self.Z_STATS_TYPES_DIC.values())

    def create_temp_vector_file(self):
        '''
        Copy the shapefile (.shp, .dbf, .shx) into temp_path.

        The optional .prj and .cpg sidecars are copied too when present, so
        the copy keeps the source's CRS and encoding. Sets self.temp_shp to
        the copied .shp path.
        '''
        # Derive sidecar paths from the stem; str.replace(".shp", ...) on
        # the full path also rewrote matching text in directory names
        src_base = os.path.splitext(self.orig_vector_path)[0]
        temp_base = os.path.join(self.temp_path,
                                 os.path.basename(src_base))

        self.temp_shp = temp_base + config.SHP_SUFFIX

        # create new copy of shapefile
        shutil.copy(self.orig_vector_path, self.temp_shp)
        for suffix in (config.DBF_SUFFIX, config.SHX_SUFFIX):
            shutil.copy(src_base + suffix, temp_base + suffix)

        for suffix in (".prj", ".cpg"):
            if os.path.exists(src_base + suffix):
                shutil.copy(src_base + suffix, temp_base + suffix)

    def get_analysis(self):
        '''
        Convert gui passed value to QgsZonalStatistics enum.

        Unrecognised values fall back to Sum.
        '''
        if self.analysis == 'Average':
            self.stat_type = QgsZonalStatistics.Mean
        elif self.analysis == 'Count':
            self.stat_type = QgsZonalStatistics.Count
        elif self.analysis == 'Maximum':
            self.stat_type = QgsZonalStatistics.Max
        elif self.analysis == 'Minimum':
            self.stat_type = QgsZonalStatistics.Min
        elif self.analysis == 'Median':
            self.stat_type = QgsZonalStatistics.Median
        elif self.analysis == 'Range':
            self.stat_type = QgsZonalStatistics.Range
        elif self.analysis == 'StdDev':
            self.stat_type = QgsZonalStatistics.StDev
        else:  # if self.analysis == 'Sum':
            self.stat_type = QgsZonalStatistics.Sum

    def kill(self):
        '''
        Set the kill flag; run() checks it between rasters.
        '''
        self.killed = True

    def update_progress(self):
        '''
        Advance the progress bar by one step, capped at 90.
        '''
        self.curr_progress = min(self.curr_progress + self.step, 90)
        self.progress.emit(int(self.curr_progress))

    finished = QtCore.pyqtSignal(object)

    error = QtCore.pyqtSignal(Exception, str)

    progress = QtCore.pyqtSignal(int)
