'''
/***************************************************************************
Name	   :  FEWSTools plugin
Description:  GeoCLIM Qgs Utility Functions for FEWSTools plugin,
              updated from QGIS2
              Do not include gdal type utilities here.
copyright  :  (C) 2019-2023 by FEWS
email      :  minxuansun@contractor.usgs.gov
Created    :  09/30/2019 - CHOLEN
Modified   :  01/07/2020 - cholen - Updated functions
              02/06/2020 - cholen - Added masking functions, config import
              02/20/2020 - cholen - Updated extract shapefile extents
              02/29/2020 - cholen - Change rectangle_equals rounding to 2 plcs
              04/10/2020 - cholen - Added functions to backup and reload map
                                    panel, updated some constants and fix
                                    problem with notify of loaded file.
              06/18/2020 - cholen - Added get_open_files_info
              06/26/2020 - cholen - Updated create_station_shapefile
              06/29/2020 - cholen - Replace deprecated CRS and XFORM_CONTEXT
              06/30/2020 - cholen - Try/except on station file save
              07/09/2020 - cholen - Update color files
              07/14/2020 - cholen - Log on exceptions
              10/13/2020 - cholen - Add display_map_layer
              10/23/2020 - cholen - Tweak display map layer fix reg_dic name
              12/03/2020 - cholen - Handle OSError
              12/23/2020 - cholen - Add translate_datatype and
                                       rst_calc_sum_files_raw
              11/24/2021 - cholen - SCA cleanup, remove util import, move gdal functions
                                       to other util class
              02/10/2022 - cholen - Add add_grid_value_to_shapefile
              04/21/2022 - cholen - Add tiff support
              06/09/2022 - cholen - Rewrite create_station_shapefile
              06/23/2022 - cholen - Fix path for region mask and map files
***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''
import os

from PyQt5.QtCore import QSettings, QFileInfo, QVariant
from PyQt5.QtWidgets import QMessageBox, QDialog
from qgis.analysis import QgsRasterCalculator, QgsRasterCalculatorEntry
from qgis.core import QgsRasterLayer, QgsVectorLayer, QgsVectorFileWriter
from qgis.core import QgsWkbTypes
from qgis.core import QgsFields, QgsField, QgsRectangle
from qgis.core import QgsProject, QgsPointXY, QgsFeature, QgsGeometry
from qgis.core import QgsCoordinateReferenceSystem
from qgis.core import QgsMessageLog, Qgis
from qgis.utils import iface

from fews_tools import fews_tools_config as config
from fews_tools.models.workspace_setup_model import WorkspaceSetupModel


# Raster band number used for the raster calculator entries in this module.
BAND_NUM = 1

# Shared CRS, transform context and writer options used when creating
# shapefiles and raster calculator outputs in this module.
CRS = QgsCoordinateReferenceSystem(config.DEFAULT_CRS)
XFORM_CONTEXT = QgsProject.instance().transformContext()
SAVE_OPTIONS = QgsVectorFileWriter.SaveVectorOptions()
SAVE_OPTIONS.driverName = "ESRI Shapefile"
SAVE_OPTIONS.fileEncoding = "UTF-8"  # alternative considered: "CP1250"

# Folder and file name of the temporary project written by map_panel_backup
# and read back by map_panel_restore.
# NOTE(review): the folder is a hard-coded Windows-style path.
QTEMP_PATH = 'c:\\temp_qgis'
TEMP_PROJECT = 'project_temp.qgz'

MAX_FIELDNAME_LEN = 8  # max length of a date format(for shpfile attributes)


settings = QSettings()
# Change the setting to use the project CRS;
# otherwise, a dialog asking to select the CRS keeps being shown.
# NOTE(review): this key is spelled 'defaultBehavior' while
# display_raster_layer uses 'defaultBehaviour' - confirm which spelling
# QGIS actually reads.
settings.setValue('/Projections/defaultBehavior', 'useProject')


def check_dataset_vs_region_extents(
        dialog_object, ds_extents, reg_extents, m_flag=True):
    '''
    Report whether the region lies inside (or matches) the dataset extents.
    The same check can be used to compare a mask against a region.
    params(object) - dialog_object - Parent handed to the message box
    params(QgsRectangle) - ds_extents - Dataset extents(or mask extent)
    params(QgsRectangle) - reg_extents - Region extents
    params(boolean) - m_flag - Show and log a message on mismatch?
    returns(boolean) - True if region within or equal to dataset, else False
    '''
    # Containment is tried first; the rounded equality test is only
    # evaluated when containment fails.
    if (ds_extents.contains(reg_extents)
            or rectangle_equals(ds_extents, reg_extents)):
        return True

    if m_flag:
        QMessageBox.information(dialog_object,
                                u'Error!! - Mask vs region mismatch',
                                u'Mask vs region mismatch!\nExiting!',
                                QMessageBox.Ok)
        QgsMessageLog.logMessage(u'Mask vs region mismatch',
                                 level=Qgis.Critical)
    return False


def clear_status_msg():
    '''
    Blank out the message shown in the QGIS main window status bar.
    '''
    status_bar = iface.mainWindow().statusBar()
    status_bar.showMessage(u'')


def add_grid_value_to_shapefile(vector_layer, raster_layer, add_field_name):
    '''
    Add a grid value attribute to a vector layer. This modifies the vector
    layer for the calling function.

    For every feature, the raster is sampled at the feature's point
    geometry and the sampled value is stored in the attribute.

    params(QgsVectorLayer) - vector_layer - Point layer to update in place
    params(QgsRasterLayer) - raster_layer - Raster to sample
    params(string) - add_field_name - Attribute name; when longer than
        MAX_FIELDNAME_LEN only the trailing characters are kept
    '''
    # shorten the name from the left so the trailing characters survive
    if len(add_field_name) > MAX_FIELDNAME_LEN:
        add_field_name = add_field_name[-MAX_FIELDNAME_LEN:]

    # add the needed field only when it is not there yet
    existing_names = [field.name() for field in vector_layer.fields()]
    if add_field_name not in existing_names:
        vector_layer.dataProvider().addAttributes(
            [QgsField(add_field_name, QVariant.Double)])
        vector_layer.updateFields()

    # the provider does not change per feature, look it up once
    raster_provider = raster_layer.dataProvider()

    # fill in the attribute for all features
    vector_layer.startEditing()
    for feature in vector_layer.getFeatures():
        sample_pt = feature.geometry().asPoint()
        # the second element (success flag) of the sample is not used
        val, _ = raster_provider.sample(sample_pt, BAND_NUM)
        feature[add_field_name] = val
        vector_layer.updateFeature(feature)

    # commitChanges stops the editing process only when it succeeds
    if not vector_layer.commitChanges():
        QgsMessageLog.logMessage(
            'Exception - Unable to save grid values to layer: ' +
            '; '.join(vector_layer.commitErrors()),
            level=Qgis.Critical)
        # discard the pending edits so the layer does not stay in edit mode
        vector_layer.rollBack()


def create_station_shapefile(station_dic, dst_filename):
    '''
    Function to create the stations shapefile in BASIICS functionality.
    The attribute columns are hard-coded here to show:
     station id, longitude, latitude, station value,
    interpolated station value and corresponding grid or
    cross validated station value.

    The last column is "GridVal" when every station entry carries a
    'Grid_val' key, otherwise it is "XValidatedStnVal" filled from
    'Xvalidated_stn_val'.

    params(dic) - station_dic - Station dictionary, keyed by station id;
        each value is a dictionary with 'Longitude', 'Latitude', 'Stn_val',
        'Intrpltd_stn_val' and either 'Grid_val' or 'Xvalidated_stn_val'
    params(string) - dst_filename - The name of the output file.
    '''
    if os.path.exists(dst_filename):
        QgsVectorFileWriter.deleteShapeFile(dst_filename)

    # 'Grid_val' is a key of the per-station dictionaries, not of
    # station_dic itself, so the check has to look inside the entries.
    has_grid_val = bool(station_dic) and all(
        'Grid_val' in stn_info for stn_info in station_dic.values())
    last_field_name = "GridVal" if has_grid_val else "XValidatedStnVal"

    # define fields for feature attributes. A QgsFields object is needed
    fields = QgsFields()
    fields.append(QgsField("Name", QVariant.String))
    for field_name in ("StnLong", "StnLat", "StnVal", "IntStnVal",
                       last_field_name):
        fields.append(QgsField(field_name, QVariant.Double,
                               'double', 10, 5))

    writer = QgsVectorFileWriter.create(
        dst_filename,
        fields,
        QgsWkbTypes.Point,
        CRS, XFORM_CONTEXT, SAVE_OPTIONS)

    if writer.hasError() != QgsVectorFileWriter.NoError:
        QgsMessageLog.logMessage("Error when creating shapefile: " +
                                 writer.errorMessage(),
                                 level=Qgis.Critical)
        # nothing can be written with a failed writer
        del writer
        return

    # add one point feature per station
    for key, stn_info in station_dic.items():
        point = QgsPointXY(stn_info['Longitude'], stn_info['Latitude'])
        feature = QgsFeature()
        feature.setGeometry(QgsGeometry.fromPointXY(point))
        if has_grid_val:
            last_value = float(stn_info['Grid_val'])
        else:
            last_value = stn_info['Xvalidated_stn_val']
        feature.setAttributes([key,
                               stn_info['Longitude'],
                               stn_info['Latitude'],
                               stn_info['Stn_val'],
                               stn_info['Intrpltd_stn_val'],
                               last_value])
        writer.addFeature(feature)

    # delete the writer to flush features to disk
    del writer


def display_raster_layer(output_file, qml_file):
    '''
    Function to display a raster in the QGIS panel.
    The "/Projections/defaultBehaviour" setting is switched to
    "useProject" while the layer is created and then put back to its
    previous value.
    params(string) - output_file - The name of the output file to display
    params(string) - qml_file - The full path of the color file for the display
    '''
    behaviour_key = "/Projections/defaultBehaviour"
    qsettings = QSettings()
    previous_behaviour = qsettings.value(behaviour_key, "prompt", type=str)
    qsettings.setValue(behaviour_key, "useProject")

    # the layer is listed under the file name without directory or extension
    layer_title, _ = os.path.splitext(os.path.basename(output_file))
    layer = QgsRasterLayer(output_file, layer_title)
    layer.setCrs(CRS)

    qsettings.setValue(behaviour_key, previous_behaviour)

    layer.loadNamedStyle(qml_file)
    QgsProject.instance().addMapLayer(layer)
    iface.layerTreeView().collapseAllNodes()


def _show_region_map(shape_file, color_dir, color_file,
                     reg_dic, open_file_names):
    '''
    Bring one map shapefile to the top of the layer panel, loading it
    first when it is not already listed in open_file_names.
    params(string) - shape_file - Shapefile to show
    params(string) - color_dir - The path to the color directory
    params(string) - color_file - The color filename
    params(dic) - reg_dic - The region dictionary
    params(list) - open_file_names - The existing files in QGIS layer panel
    '''
    layer_name = os.path.basename(shape_file).replace(config.SHP_SUFFIX, '')
    if layer_name in open_file_names:
        # already opened in layer panel, only move it to the top
        move_map_2_top(layer_name)
        return

    layer = QgsVectorLayer(shape_file, layer_name, config.OGR)
    qml_file = os.path.join(color_dir, color_file)
    if not layer.isValid():
        return
    layer.setExtent(get_region_extent(reg_dic))
    layer.loadNamedStyle(qml_file)
    QgsProject.instance().addMapLayer(layer)
    move_map_2_top(layer_name)


def display_map_layer(reg_dic: dict, open_file_names: list):
    '''
    Function to display the map layer(s) of a region in the QGIS panel,
    with the layer extent set to the region extents. The color files come
    from the config (MAP_COLOR_FILE, SECOND_MAP_COLOR_FILE), so none is
    passed in. reg_dic['Map'] is a comma separated list; when it holds
    more than one entry, the second one is handled before the first.
    params(dic) - reg_dic - The region dictionary
    params(list) - open_file_names - The existing files in QGIS layer panel
    '''
    wrksp_setup = WorkspaceSetupModel()
    color_dir = wrksp_setup.get_colors_path()
    map_list = reg_dic['Map'].split(',')

    if len(map_list) > 1:
        # second map file goes in first so the default map ends up on top;
        # its name is used as given, without the map data path
        _show_region_map(map_list[1], color_dir,
                         config.SECOND_MAP_COLOR_FILE,
                         reg_dic, open_file_names)

    # default map file is the first entry, located in the map data path
    default_shape_file = os.path.join(wrksp_setup.get_map_data_path(),
                                      map_list[0])
    _show_region_map(default_shape_file, color_dir,
                     config.MAP_COLOR_FILE,
                     reg_dic, open_file_names)
    iface.layerTreeView().collapseAllNodes()


def display_vector_layer(shape_file, color_dir, color_file=None):
    '''
    Function to display a vector layer in the QGIS panel, styled with
    the given color file or, when none is given, with the config
    MAP_COLOR_FILE. The layer extent is not modified here. An invalid
    layer is not added.
    params(string) - shape_file - Shapefile to display.
    params(string) - color_dir - The path to the color directory
    params(string) - color_file - The color filename(optional)
    '''
    layer_name = os.path.basename(shape_file).replace(config.SHP_SUFFIX, '')
    layer = QgsVectorLayer(shape_file, layer_name, config.OGR)

    style_name = color_file if color_file else config.MAP_COLOR_FILE
    qml_file = os.path.join(color_dir, style_name)

    if layer.isValid():
        layer.loadNamedStyle(qml_file)
        QgsProject.instance().addMapLayer(layer)
    iface.layerTreeView().collapseAllNodes()


def extract_raster_layer(src_filename):
    '''
    Extract layer from file.
    Problems are logged, not raised: an unreadable file yields a layer
    whose isValid() is False, and None is returned when the layer object
    could not be created at all. Callers should check the result.
    params(string): src_filename - Data filename.
    Returns(raster):  src_lyr - Raster representation of input, or None.
    '''
    # defined up front so the return below never hits an unbound name
    src_lyr = None
    try:
        data_basename = QFileInfo(src_filename).baseName()
        src_lyr = QgsRasterLayer(src_filename, data_basename)
        src_lyr.setCrs(CRS)
        if not src_lyr.isValid():
            raise IOError(u"Error, Invalid layer!")
    except Exception as err:  # KeyboardInterrupt/SystemExit pass through
        QgsMessageLog.logMessage(
            'Exception - Unspecified error when extracting raster layer' +
            ' (' + str(src_filename) + '): ' + str(err),
            level=Qgis.Critical)

    return src_lyr


def extract_raster_file_params(src_filename):
    '''
    Extract processing information
    from raster file(rows, cols, extents, cell size).
    Problems (missing file, invalid layer) are logged, not raised; in that
    case the placeholder values (None, 0, 0, 0.0) are returned.
    Args:
    params(string) - src_filename - Data filename.
    Returns tuple of properties:
        extents(object) -  Extents of the input, None on failure.
        col_ct(integer) - Input column count, 0 on failure.
        row_ct(integer) - Input row count, 0 on failure.
        cell_size(float) - Input cell size (X direction), 0.0 on failure.
    '''
    # placeholders so the return is always defined, even after a failure
    extents = None
    col_ct = 0
    row_ct = 0
    cell_size = 0.0
    try:
        if not os.path.exists(src_filename):
            raise IOError(src_filename + " does not exist!!")
        src_lyr = extract_raster_layer(src_filename)
        if src_lyr is None or not src_lyr.isValid():
            raise IOError(u"Invalid layer read:  " + src_filename)
        extents = src_lyr.extent()
        col_ct = src_lyr.width()
        row_ct = src_lyr.height()
        cell_size = src_lyr.rasterUnitsPerPixelX()
    except Exception as err:  # KeyboardInterrupt/SystemExit pass through
        QgsMessageLog.logMessage(
            'Exception - Unspecified error when extracting raster info' +
            ': ' + str(err),
            level=Qgis.Critical)

    return extents, col_ct, row_ct, cell_size


def extract_shapefile_extents(src_filename):
    '''
    Read the extents of a shapefile as plain numbers.
    params(string) Shapefile name
    Returns tuple of properties:
        extents(ymin, ymax, xmin, xmax) -  Extents of the input.
    '''
    rect = extract_shapefile_extents_rectangle(src_filename)
    # note the order: the two y values come first, then the two x values
    return (rect.yMinimum(), rect.yMaximum(),
            rect.xMinimum(), rect.xMaximum())


def extract_shapefile_extents_rectangle(src_filename):
    '''
    Read the extents of a shapefile as a rectangle.
    params(string) Shapefile name
    Returns(QgsRectangle) - extent -  Extents of the input, or None when
        the shapefile does not load as a valid layer.
    '''
    layer_name = os.path.basename(src_filename)
    layer = QgsVectorLayer(src_filename, layer_name, config.OGR)
    if not layer.isValid():
        return None
    return layer.extent()


def get_offsets(ds_file, reg_dic):
    '''
    Gets row and column offsets for processing.
    The offsets are the distance from the dataset's upper left corner to
    the region's upper left corner, expressed in mask cells and truncated
    to whole cells. Problems are logged, not raised; in that case both
    offsets are returned as 0.
    Args:
        params(string) - ds_file - Dataset file(or mask file).
        params(object) - reg_dic - Region info data structure; the keys
            'Mask', 'MinimumLongitude' and 'MaximumLatitude' are read.
        returns(integer) - col_offset - Number of columns to offset.
        returns(integer) - row_offset - Number of rows to offset.
    '''
    col_offset = 0
    row_offset = 0
    try:
        ds_extents, _, _, _ = extract_raster_file_params(ds_file)
        # we use the mask cell size for calculating the offset
        # because inputs will be resampled to the mask cell size
        # when processing.  Most times, data cell size and mask cell
        # size already match.
        _, _, _, mask_cell_size = extract_raster_file_params(reg_dic['Mask'])

        if mask_cell_size == 0:  # just protection against divide by zero
            raise RuntimeError(u"Error calculating offsets!!")

        # NOTE(review): int() truncates, so a quotient that lands just
        # below a whole number through floating point error loses one
        # cell - confirm whether rounding is wanted here.
        cols = int((reg_dic['MinimumLongitude'] - ds_extents.xMinimum()) /
                   mask_cell_size)
        rows = int((ds_extents.yMaximum() - reg_dic['MaximumLatitude']) /
                   mask_cell_size)
        # assign together so a failure never leaves a half-updated pair
        col_offset, row_offset = cols, rows
    except Exception as err:  # KeyboardInterrupt/SystemExit pass through
        QgsMessageLog.logMessage(
            'Exception - Unspecified error when calculating offsets' +
            ': ' + str(err),
            level=Qgis.Critical)

    return col_offset, row_offset


def get_open_files_info():
    '''
    Gets Map Panel layer information
    returns(list) - open_file_info - One [layer id, layer name] pair for
        each map layer of the current project
    '''
    open_file_info = []
    for layer in QgsProject.instance().mapLayers().values():
        open_file_info.append([layer.id(), layer.name()])
    return open_file_info


def get_region_extent(reg_dic):
    '''
    Build the rectangle described by the longitude/latitude limits
    stored in a region dictionary.
    params(dic) - reg_dic - Region dictionary
    returns(QgsRectangle) - extent - QgsRectangle(xmin, ymin, xmax, ymax)
    '''
    x_min = reg_dic['MinimumLongitude']
    y_min = reg_dic['MinimumLatitude']
    x_max = reg_dic['MaximumLongitude']
    y_max = reg_dic['MaximumLatitude']
    return QgsRectangle(x_min, y_min, x_max, y_max)


def map_panel_backup():
    '''
    Write map panel contents out to a temp project file
    (TEMP_PROJECT inside QTEMP_PATH), then empty the map panel.
    A failed write is logged; the panel is emptied in either case.
    '''
    # exist_ok avoids the check-then-create race on the temp folder
    os.makedirs(QTEMP_PATH, exist_ok=True)
    project = QgsProject.instance()
    backup_file = os.path.join(QTEMP_PATH, TEMP_PROJECT)
    if not project.write(backup_file):
        # without this the lost backup would go unnoticed until restore
        QgsMessageLog.logMessage(
            'Exception - Unable to back up map panel to ' + backup_file,
            level=Qgis.Critical)
    project.removeAllMapLayers()
    project.clear()


def map_panel_restore():
    '''
    Reload the map panel contents from the temp project file
    (TEMP_PROJECT inside QTEMP_PATH).
    '''
    backup_file = os.path.join(QTEMP_PATH, TEMP_PROJECT)
    QgsProject.instance().read(backup_file)


def move_map_2_top(map_name: str) -> None:
    '''
    Move a map layer to the top.
    The layer's tree node is cloned, the clone is inserted as first child
    of the node's parent, and the old node is removed, from:
    http://gis.stackexchange.com/questions/134284/
        how-to-move-layers-in-the-qgis-table-of-contents-via-pyqgis
    Nothing happens when no layer with that name is loaded.
    params(str) - map_name - The map name in Layers Panel
    '''
    project = QgsProject.instance()
    matches = project.mapLayersByName(
        map_name.replace(config.SHP_SUFFIX, ''))
    if not matches:
        return

    # only the first layer carrying that name is moved
    tree_node = project.layerTreeRoot().findLayer(matches[0].id())
    node_copy = tree_node.clone()
    group = tree_node.parent()
    group.insertChildNode(0, node_copy)
    group.removeChildNode(tree_node)


def notify_loaded_file(dialog_object: QDialog,
                       file_name_str, map_layer_id=None):
    '''
    Tell the user that a file is open in the layer panel and, when a
    layer id is given, remove that layer from the project afterwards.
    params(QDialog) - dialog_object - Parent handed to the message box
    params(string) - file_name_str - The filename(s) that are open.
    params(string) - map_layer_id - The map layer id to remove.
    '''
    # the name is shown without its extension
    shown_name = os.path.splitext(file_name_str)[0]
    message = (shown_name +
               ' is open in the '
               'layer panel. \n\n This layer will be removed '
               'to re-run the analysis.')
    QMessageBox.information(dialog_object,
                            u'File already loaded',
                            message,
                            QMessageBox.Ok)
    if map_layer_id:
        QgsProject.instance().removeMapLayer(map_layer_id)


def rectangle_equals(rect1, rect2):
    '''
    Compare two QgsRectangles corner by corner after rounding every
    coordinate to two decimal places.
    param(QgsRectangle) - rect1
    param(QgsRectangle) - rect2
    return(boolean) - True if all four rounded coordinates match else False
    '''
    corners_1 = (rect1.xMinimum(), rect1.yMinimum(),
                 rect1.xMaximum(), rect1.yMaximum())
    corners_2 = (rect2.xMinimum(), rect2.yMinimum(),
                 rect2.xMaximum(), rect2.yMaximum())
    return all(round(first, 2) == round(second, 2)
               for first, second in zip(corners_1, corners_2))


def _build_sum_formula(refs, nd_val, scale_factor):
    '''
    Build the raster calculator expression that sums the given entries.
    The result is nd_val where any entry equals nd_val, otherwise the sum
    of all entries divided by scale_factor.
    params(list) - refs - Raster calculator entry references
    params(string) - nd_val - No data value as text
    params(int) - scale_factor - Divisor applied to the sum
    returns(string) - The expression
    '''
    any_nodata = " OR ".join(
        "(" + ref + " = " + nd_val + ")" for ref in refs)
    all_data = " AND ".join(
        "(" + ref + " != " + nd_val + ")" for ref in refs)
    total = " + ".join(refs)
    return ("((" + any_nodata + ") * " + nd_val + ")" +
            " + " +
            "((" + all_data + ") * (" + total + ") / " +
            str(scale_factor) + ")")


def rstr_calc_sum_files_raw(file_list, no_data_val,
                            scale_factor, dst_filename):
    '''
    Function to sum a List of files. Doesn't care about region size or mask
    Just sums the files as is. The output extent and size are taken from
    the first file. The output driver follows the extension of the first
    file: GTiff for the tif/tiff suffixes, EHdr for anything else.
    params(list) - file_list - Source files
    params(int) - no_data_val - Source rasters' no data value
    params(int) - scale_factor - Scale factor for final file, the sum is
        divided by it; 0 is treated as 1
    params(string) - dst_filename - Output file.
    Returns(int) - result - Result of raster calculator calculation(0=success).
    '''
    if scale_factor == 0:
        scale_factor = 1  # protect from divide by zero

    extents, col_ct, row_ct, _ = extract_raster_file_params(file_list[0])
    nd_val = str(no_data_val)

    # compare against both suffixes explicitly; a bare
    # "x == a or b" is always true for a non-empty b
    ext_l = os.path.splitext(file_list[0])[1].lower()
    tiff_suffixes = (config.TIFF_SUFFIX.lower(), config.TIF_SUFFIX.lower())
    drv_txt = config.GTIFF if ext_l in tiff_suffixes else config.EHDR

    # one calculator entry per input file, referenced as boh@1, boh@2, ...
    calculator_entry_list = []
    for count, entry_file in enumerate(file_list, start=1):
        boh = QgsRasterCalculatorEntry()
        boh.ref = "boh@" + str(count)
        boh.raster = extract_raster_layer(entry_file)
        boh.bandNumber = BAND_NUM
        calculator_entry_list.append(boh)

    formula = _build_sum_formula(
        [boh.ref for boh in calculator_entry_list], nd_val, scale_factor)

    # Process calculation with input extent and resolution
    calc = QgsRasterCalculator(formula,
                               dst_filename,
                               drv_txt,
                               extents,
                               col_ct,
                               row_ct,
                               calculator_entry_list,
                               XFORM_CONTEXT)
    result = calc.processCalculation()
    if result != QgsRasterCalculator.Success:
        QgsMessageLog.logMessage('Exception - raster calculation failed',
                                 level=Qgis.Critical)

    return result


def write_image(file_list, color_dir, qml_file=None):
    '''
    Function to write out an image of output.  This will clear the map
    panel so if the contents of the map panel are needed, they should
    be saved before calling this function.
    The image is saved next to the raster, with the raster's extension
    replaced by the config JPG suffix.
    params(list) - file_list - List of raster, station shapefile, and map file,
            in that order. Only the raster is required.
    params(string) - color_dir - The path to the color directory
    params(string) - qml_file - The full path of the raster color file
            (optional); defaults to the config RF_2500_RASTER_COLOR_FILE
            inside color_dir
    '''
    raster_filename = file_list[0]
    # the shapefiles are optional trailing entries of the list
    station_shape_file = file_list[1] if len(file_list) > 1 else None
    map_shape_file = file_list[2] if len(file_list) > 2 else None

    # clear any existing layers
    QgsProject.instance().removeAllMapLayers()

    # swap only the extension; str.replace could also hit the same text
    # elsewhere in the path, or corrupt the path when there is no suffix
    jpg_dst_file = WorkspaceSetupModel().fix_os_sep_in_path(
        os.path.splitext(raster_filename)[0] + config.JPG_SUFFIX)

    if not qml_file:
        # display_raster_layer needs a full path, so anchor the default
        # color file in the color directory when one is given
        qml_file = config.RF_2500_RASTER_COLOR_FILE
        if color_dir:
            qml_file = os.path.join(color_dir, qml_file)

    # load the specified layers
    display_raster_layer(raster_filename, qml_file)
    if station_shape_file:
        display_vector_layer(station_shape_file,
                             color_dir,
                             color_file=config.RF_2500_PTS_COLOR_FILE)
    if map_shape_file:
        display_vector_layer(map_shape_file, color_dir)
    # save the file
    iface.mapCanvas().saveAsImage(jpg_dst_file)