'''
/***************************************************************************
Name		: basiics_batch_utilites.py
Description : Batch Utilities - for BASIICS, ValidateRFE and Interpolate
                   Stations
This code adapted from: https://snorfalorpagus.net/blog/2013/12/07/
                             multithreading-in-qgis-python-plugins/
copyright   : (C) 2020 - 2023 by FEWS
email       : minxuansun@contractor.usgs.gov
Created:    : 01/23/2020
Modified    : 04/14/2020 cholen - Fixed header row check
              04/17/2020 cholen - Fixed AVGDATADATEFORMAT in
                                  get_datestring_reg_expression
              06/08/2020 cholen - Remove tab from stats prints
              06/17/2020 cholen - Specify utf-8 when opening csv files
              06/23/2020 cholen - Add read_csv_file function
              06/29/2020 cholen - Exclude masked stations in interp list
              07/14/2020 cholen - Change raises to logged messages
              10/08/2020 cholen - Fix check for fill value on array interp
              10/23/2020 cholen - Fix for using ds extents
              11/17/2020 cholen - Check point vs region or dataset dependent on
                                  what analysis is selected
              12/03/2020 cholen - Handle OSError
              01/13/2021 cholen - Handle regex problems
              05/10/2021 cholen - Fix MMK and MMP date format problems with
                                  finding columns in csv
              06/04/2021 cholen - Handle bad station values
              01/18/2022 cholen - New gdal utilities
              04/25/2022 cholen - Add tif and new date format support
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''
import csv
import datetime
import glob
import math
import os
import re
import shutil
import matplotlib.pyplot as plt
import numpy

from qgis.core import QgsMessageLog, Qgis
from qgis.core import QgsPointXY, QgsRectangle

from fews_tools.utilities import geoclim_gdal_utilities as g_util
from fews_tools.utilities import geoclim_utilities as util
from fews_tools.utilities import geoclim_qgs_utilities as qgs_util
from fews_tools import fews_tools_config as config

"""
Blending station dictionary keys:

'Stn_id': ie: 1119   comes from source csv file
'Longitude': float(row[1]), comes from source csv file
'Latitude': float(row[2]), comes from source csv file
'Stn_val': float(row[3]), comes from source csv file
'Grid_val': ie 247  grid values at the station locations -
                   comes from source data file for specified date
'Intrpltd_stn_val': float  calculated from 'Stn_val'
'Xvalidated_stn_val': float  calculated from 'Stn_val'
'Ratio':  float  calculated from ('Intrpltd_stn_val'/'Grid_val')
'Interp_ratio_val': float  calculated from 'Ratio'
'Xvalidated_ratio_val': float  calculated from 'Ratio'
'Anomaly_val': float  ('Intrpltd_stn_val' - ('Grid_val' * 'Interp_ratio_val'))
'Interp_anomaly_val': float  calculated from 'Anomaly_val'
'Xvalidated_anomaly_val': float  calculated from 'Anomaly_val'
'Interp_array_val': float  comes from the interpolated array value
                           at the station locations
"""

# Sphere radius used by get_distance to turn an arc angle into kilometres.
EARTH_RADIUS = 6370.997  # km
# Multiply an angle in degrees by this factor to get radians.
DEG2RAD = math.pi / 180.0
# Used to convert the search radius (km) into a lat/long window in degrees.
DEGPERKM = 0.0089  # approximation of degrees per km(use value at equator)
# Added to both numerator and denominator in calc_station_ratios.
EPSILON = 10.0
# NOTE(review): the constants below are not referenced in the part of the
# file reviewed here; the descriptions are inferred from the names only.
ANOMMISSINGVAL = -29999.0  # presumably the missing-value flag for anomalies
ANOM_ARRAY_FILL_VAL = 0.0  # presumably the start value for anomaly arrays
MINSHORT = -32767  # presumably the lowest value stored in Int16 outputs
MAXSHORT = 32767  # presumably the highest value stored in Int16 outputs
MIN_CROSS_VAL_DIST = 0.0  # presumably a cross validation distance limit


def build_station_dic(bat_dic, stn_filename, grid_filename):
    '''
    Start the station dictionary for one period from the station csv file
    and, unless the analysis type is 3, add the grid value found at each
    station location.
    params(dic) - bat_dic - Batch parameters
    params(string) - stn_filename - Station csv file for the period
    params(string) - grid_filename - Raster file for the period
    returns(dic) - Station dictionary, or None when the station reader
                   returned None
    '''
    stations = get_stations_dic_for_period(stn_filename,
                                           bat_dic['delimiter'])
    needs_grid_vals = (stations is not None and
                       bat_dic['analysis_type'] != 3)
    if needs_grid_vals:
        extract_grid_stn_loc_values(bat_dic, stations, grid_filename)
    return stations


def calc_station_ratios(bat_dic, stn_dic):
    '''
    Fill in the 'Ratio' entry of every station in stn_dic (in place).
    A station whose interpolated value is the station missing value gets
    the dataset missing value.  Otherwise the ratio is
    (interpolated value + EPSILON) / (grid value + EPSILON), limited to
    the batch 'max_ratio'.
    params(dic) - bat_dic - Batch parameters
    params(dic) - stn_dic - Station parameters
    '''
    for stn in stn_dic.values():
        if stn['Intrpltd_stn_val'] == bat_dic['stn_missing_val']:
            stn['Ratio'] = bat_dic['ds_dic']['DATAMISSINGVALUE']
            continue

        denominator = stn['Grid_val'] + EPSILON
        if denominator == 0:
            # guard against a divide by zero, use the ceiling instead
            stn['Ratio'] = bat_dic['max_ratio']
            continue

        ratio = (stn['Intrpltd_stn_val'] + EPSILON) / denominator
        stn['Ratio'] = (bat_dic['max_ratio']
                        if ratio > bat_dic['max_ratio'] else ratio)


def check_point_extents(bat_dic, ds_params, data_point,
                        lat_col, long_col, year_col):
    '''
    Test a station record against the batch start/end years and against
    either the dataset extents or the region extents.
    params(dic) - bat_dic - Batch parameters
    params(tuple) - ds_params - Raster parameters, element 0 supplies the
                                extents.  Pass None (or any false value)
                                to test against the region extents
    params(list) - data_point - One station record
    params(int) - lat_col - Index of the latitude in data_point
    params(int) - long_col - Index of the longitude in data_point
    params(int) - year_col - Index of the year in data_point
    returns(boolean) - True when the point is inside the extents and the
                       year is within start_year..end_year (inclusive)
    '''
    pt_obj = QgsPointXY(float(data_point[long_col]),
                        float(data_point[lat_col]))
    if ds_params:
        src = ds_params[0]
        bounds = (src.xMinimum(), src.yMinimum(),
                  src.xMaximum(), src.yMaximum())
    else:
        region = bat_dic['reg_dic']
        bounds = (region['MinimumLongitude'], region['MinimumLatitude'],
                  region['MaximumLongitude'], region['MaximumLatitude'])
    bbox = QgsRectangle(*bounds)

    first_year = int(bat_dic['start_year'])
    last_year = int(bat_dic['end_year'])
    if not bbox.contains(pt_obj):
        return False
    # the year is only parsed for points that are inside the box
    return first_year <= int(data_point[year_col]) <= last_year


def create_output_raster_file(
        bat_dic, ds_params, src_array, temp_path, name_base, anom_flag=False):
    '''
    Write src_array as an Int16 raster in temp_path using the full dataset
    dimensions, clip it to the region extents into the current output
    path, then remove the temp files that start with name_base.
    params(dic) - bat_dic - Batch parameters
    params(tuple) - ds_params - Raster parameters, elements 1 and 2 are
                                passed to the writer
    params(array) - src_array - Values to write
    params(string) - temp_path - Folder for the unclipped file
    params(string) - name_base - Output name without suffix
    params(boolean) - anom_flag - True inserts '_anom' before the suffix
    returns(string) - dst_clipped_filename - Path of the clipped raster
    raises(RuntimeError) - when the writer reports an error
    '''
    suffix = bat_dic['ds_dic']['DATASUFFIX']
    if anom_flag:
        dst_base_filename = name_base + '_anom' + suffix
    else:
        dst_base_filename = name_base + suffix
    dst_full_extents_filename = os.path.join(temp_path, dst_base_filename)
    dst_clipped_filename = os.path.join(bat_dic['curr_output_path'],
                                        dst_base_filename)
    # NOTE(review): this function used to call src_array.astype(numpy.int16)
    # for anomalies and throw the result away, so it never had any effect.
    # The dead call is removed; the array still reaches the writer
    # unchanged, with the Int16 GDAL type requested below.  Confirm whether
    # an explicit cast (with rounding/clipping) is wanted before adding one.
    err = g_util.write_file(
        dst_full_extents_filename, src_array,
        ds_params[1], ds_params[2], bat_dic['geotransform'],
        g_util.TYPE_DIC["Int16"]["GDAL"])
    if err is True:
        raise RuntimeError(
            'Unable to write raster file: ' + dst_full_extents_filename)
    g_util.clip_raster_to_region(bat_dic['reg_dic'],
                                 dst_full_extents_filename,
                                 dst_clipped_filename)

    QgsMessageLog.logMessage(dst_base_filename + ' complete',
                             level=Qgis.Info)
    # cleanup temp files, best effort: one locked file must not stop the
    # removal of the others and must not fail the run
    for entry in glob.glob(temp_path + os.sep + name_base + '*'):
        try:
            os.remove(entry)
        except OSError as exc:
            QgsMessageLog.logMessage(
                'Unable to remove temp file ' + entry + ': ' + str(exc),
                level=Qgis.Warning)
    return dst_clipped_filename


def extract_grid_stn_loc_values(bat_dic, stn_dic, grid_filename):
    '''
    Function to extract the grid values at the station locations and
    store them in the station dictionary as 'Grid_val'.  Stations that
    fall outside of the grid get the dataset missing value.
    params(dic) - bat_dic - Batch parameters
    params(dic) - stn_dic - Station parameters (updated in place)
    params(string) - grid_filename - Filename of the rainfall file.
    '''
    _, _, _, _, data_type = g_util.get_geotiff_info(grid_filename)
    np_data_type = g_util.TYPE_DIC[data_type]["NP"]
    grid_array = g_util.extract_raster_array(grid_filename, np_data_type)
    # params = (extents, col_ct, row_ct, cell_size)
    params = qgs_util.extract_raster_file_params(grid_filename)
    x_min = params[0].xMinimum()
    y_max = params[0].yMaximum()
    for key in stn_dic:
        # floor, not int(): int() truncates toward zero, which would move
        # a station up to one cell west/north of the grid into col/row 0
        pt_col = math.floor(
            (stn_dic[key]['Longitude'] - x_min) / params[3])
        pt_row = math.floor(
            (y_max - stn_dic[key]['Latitude']) / params[3])

        if params[1] > pt_col >= 0 and params[2] > pt_row >= 0:
            stn_dic[key]['Grid_val'] = (grid_array[pt_row, pt_col])
        else:
            stn_dic[key]['Grid_val'] = bat_dic['ds_dic']['DATAMISSINGVALUE']


def extract_point_stats(bat_dic, ds_params, stn_dic,
                        key_val, src_array):
    '''
    Function to extract grid values from an input array at specified
    stn locations.  Adds to station dictionary.  Stations that fall
    outside of the grid get the dataset missing value.
    params(dic) - bat_dic - Batch parameters
    params(tuple) - ds_params - Raster parameters
                                (extents, col_ct, row_ct, cell_size)
    params(dic) - stn_dic - Station parameters (updated in place)
    params(string) - key_val - Station dictionary key to store the value in
    params(array) - src_array - Image array, indexed [row, col].
    '''
    x_min = ds_params[0].xMinimum()
    y_max = ds_params[0].yMaximum()
    cell_size = ds_params[3]
    for key in stn_dic:
        # floor, not int(): int() truncates toward zero, which would move
        # a station up to one cell west/north of the grid into col/row 0
        pt_col = math.floor((stn_dic[key]['Longitude'] - x_min) / cell_size)
        pt_row = math.floor((y_max - stn_dic[key]['Latitude']) / cell_size)

        if 0 <= pt_col < ds_params[1] and 0 <= pt_row < ds_params[2]:
            stn_dic[key][key_val] = src_array[pt_row, pt_col]
        else:
            stn_dic[key][key_val] = bat_dic['ds_dic']['DATAMISSINGVALUE']


def get_csv(file_list, sample):
    '''
    Build the station csv filename ('raingauge' + sample + '.csv') in the
    folder of the first entry of file_list and check that it exists.
    params(list) - file_list - Filenames, only the first one is used
    params(string) - sample - Date string used in the csv name
    returns(string) - Expected csv filename
    returns(boolean) - err - True when the csv file does not exist
    '''
    folder = os.path.dirname(file_list[0])
    csv_name = 'raingauge' + sample + '.csv'
    sample_date_csv_filename_l = os.path.join(folder, csv_name)
    if os.path.exists(sample_date_csv_filename_l):
        return sample_date_csv_filename_l, False
    QgsMessageLog.logMessage(('CSV file missing: ' +
                              str(sample_date_csv_filename_l)),
                             level=Qgis.Critical)
    return sample_date_csv_filename_l, True


def get_datestring_reg_expression(bat_dic):
    '''
    Look up the date regular expression for the batch inputs.  When the
    input prefix equals the dataset data prefix the dataset
    'DATADATEFORMAT' is used, otherwise 'AVGDATADATEFORMAT'.
    params(dic) - bat_dic - Batch parameter dictionary
    returns - 'REG_EXPR' entry of the selected date format in
              config.DATE_FORMATS_DIC
    '''
    uses_data_prefix = (
        bat_dic['input_prefix'] == bat_dic['ds_dic']['DATAPREFIX'])
    format_key = 'DATADATEFORMAT' if uses_data_prefix else 'AVGDATADATEFORMAT'
    date_format = config.DATE_FORMATS_DIC[bat_dic['ds_dic'][format_key]]
    return date_format['REG_EXPR']


def get_distance(long1, lat1, long2, lat2):
    '''
    Function to calculate the distance between two locations in km,
    using the spherical law of cosines on a sphere of EARTH_RADIUS.
    params(float) - long1 - Longitude 1 (degrees)
    params(float) - lat1 - Latitude 1 (degrees)
    params(float) - long2 - Longitude 2 (degrees)
    params(float) - lat2 - Latitude 2 (degrees)
    returns(float) - dist - Distance in km, never 0 (0.1 is returned
                            for coincident points)
    '''
    lat1_rad = lat1 * DEG2RAD
    lat2_rad = lat2 * DEG2RAD
    cos_dist = (math.sin(lat1_rad) *
                math.sin(lat2_rad) +
                math.cos(lat1_rad) *
                math.cos(lat2_rad) *
                math.cos((long2 - long1) * DEG2RAD))

    # rounding can push the cosine just outside of [-1, 1] (coincident or
    # antipodal points), and math.acos raises ValueError out of that range
    if cos_dist > 1:
        cos_dist = 1.0
    elif cos_dist < -1:
        cos_dist = -1.0
    dist = math.acos(cos_dist) * EARTH_RADIUS
    # protect against a divide by 0 crash later
    if dist == 0:
        dist = 0.1
    return dist


def get_fuzzy_dist(bat_dic, ds_params):
    '''
    Function to set the fuzzy distance used by the idw calculations.
    Measures the distance between the middle of the grid and the point
    one cell size further in both longitude and latitude, multiplies it
    by the batch 'fuzz_factor' and stores it in bat_dic['fuzzy_distance'].
    Nothing is returned.
    params(dic) - bat_dic - Batch parameters (updated in place)
    params(tuple) - ds_params - Raster parameters
                                (extents, col_ct, row_ct, cell_size)
    '''
    extents = ds_params[0]
    cell_size = ds_params[3]
    centre_long = extents.xMinimum() + (ds_params[1] * cell_size / 2)
    centre_lat = extents.yMaximum() - (ds_params[2] * cell_size / 2)
    one_cell_dist = get_distance(centre_long, centre_lat,
                                 centre_long + cell_size,
                                 centre_lat + cell_size)
    bat_dic['fuzzy_distance'] = one_cell_dist * bat_dic['fuzz_factor']


def get_grid(bat_dic, sample, reg_exp):
    '''
    Find the raster file in bat_dic['good_files_list'] that belongs to
    the sample date string and check that it exists.
    params(dic) - bat_dic - Batch parameters
    params(string) - sample - Date string to look for
    params(string) - reg_exp - Regular expression of the date string
    returns(string) - grid_filename_l - Matching filename, None when no
                                        file in the list matches
    returns(boolean) - err - True when no file matches or the matching
                             file does not exist
    '''
    err = False
    # stays None when nothing matches, so the check below can report it
    # instead of failing on an unbound name
    grid_filename_l = None
    prefix = bat_dic['ds_dic']['DATAPREFIX']
    for data_file in bat_dic['good_files_list']:
        # have to split down to what we think is the date string
        # and then check if it matches the expected length of the date string
        # because prefix and/or file path could contain digits that would
        # return an incorrect value if we just use the regex alone
        name_parts = os.path.splitext(
            os.path.basename(data_file))[0].split(prefix)
        if len(name_parts) < 2:
            # filename does not contain the prefix at all
            continue
        date_string = name_parts[1]
        # in case there are bad or weird filenames, like -Copy etc
        found = re.findall(reg_exp, date_string)
        if not found or found[0] != date_string:
            continue

        if bat_dic['input_prefix'] == prefix:
            if sample == date_string:
                grid_filename_l = data_file
                break
        else:
            _, _, per = util.split_date_string(
                date_string, bat_dic['ds_dic']['DATADATEFORMAT'])
            if sample[-2:] == per:
                grid_filename_l = data_file
                break
    if grid_filename_l is None or not os.path.exists(grid_filename_l):
        QgsMessageLog.logMessage(('Raster file missing: ' +
                                  str(grid_filename_l)),
                                 level=Qgis.Critical)
        err = True
    return grid_filename_l, err


def get_station_2_pixel_list(bat_dic, ds_params, stn_filename):
    '''
    Function to find stations closest to each pixel.  Only pixels in the
    vicinity of the stations (station extents plus search radius) are
    examined.  Each result entry is
    [[row, col], [[stn_id, distance_km], ...]] with the stations sorted
    by distance, pixels without a station in range are left out.
    params(dic) - bat_dic - Batch parameters
    params(tuple) - ds_params - Raster parameters
                                (extents, col_ct, row_ct, cell_size)
    params(string)- stn_filename - Filename for list of region stns.
    returns(list) - stn_2_pixel_list
    returns(boolean) - err - True if an exception occurred or no list
                             was produced, else False
    '''
    stn_2_pixel_list = []
    err = False
    radius_km = bat_dic['search_radius']
    search_radius = radius_km * DEGPERKM
    try:
        data_list = read_csv_file(stn_filename, bat_dic['delimiter'], 0)
        cell_size = ds_params[3]
        # centre of the top left pixel
        long_start_pt = ds_params[0].xMinimum() + (cell_size / 2)
        lat_start_pt = ds_params[0].yMaximum() - (cell_size / 2)

        src_array = g_util.extract_raster_array(bat_dic['good_files_list'][0])
        row_ct, col_ct = src_array.shape
        #  in case data is huge, calc max min lat/long of station vicinity
        #  and figure cols and rows of interest,  this uses the search radius
        stn_extents = get_station_location_extent(data_list)
        min_row, max_row, min_col, max_col = get_station_rows_columns(
            ds_params, stn_extents, search_radius)
        # parse the station coordinates once instead of once per pixel
        stations = [(stn[0], float(stn[1]), float(stn[2]))
                    for stn in data_list]

        # walk only the window of interest (same row-major order as a
        # full scan), limited to the rows/cols that exist in the array
        first_row = max(min_row, 0)
        first_col = max(min_col, 0)
        last_row = min(max_row, row_ct - 1)
        last_col = min(max_col, col_ct - 1)
        for row in range(first_row, last_row + 1):
            # calc lat-long of row and col,
            # note that this is the center of the pixel not the edge
            lat1 = lat_start_pt - (row * cell_size)
            temp_min_lat = lat1 - search_radius
            temp_max_lat = lat1 + search_radius
            for col in range(first_col, last_col + 1):
                long1 = long_start_pt + (col * cell_size)
                temp_min_long = long1 - search_radius
                temp_max_long = long1 + search_radius
                used_stn_array = []

                for stn_id, stn_long, stn_lat in stations:
                    # only calculate distances if in the neighborhood
                    if (temp_min_long <= stn_long <= temp_max_long and
                            temp_min_lat <= stn_lat <= temp_max_lat):
                        dist = get_distance(long1, lat1, stn_long, stn_lat)
                        # check station against search radius
                        if dist <= radius_km:
                            used_stn_array.append([stn_id, dist])
                if used_stn_array:
                    used_stn_array.sort(key=lambda item: item[1])
                    stn_2_pixel_list.append([[row, col], used_stn_array])
    except Exception as exc:
        QgsMessageLog.logMessage(
            'Exception - get_station_2_pixel_list failed: ' + str(exc),
            level=Qgis.Critical)
        # a partial list must not be reported as a success
        err = True
    if not stn_2_pixel_list:
        QgsMessageLog.logMessage('Stations to pixel mismatch',
                                 level=Qgis.Critical)
        err = True
    else:
        QgsMessageLog.logMessage('Completed station to pixel',
                                 level=Qgis.Info)
    return stn_2_pixel_list, err


def get_station_2_station_list(bat_dic, ds_params, stn_filename):
    '''
    Function to find closest stations to stations. Will return a list
    with one entry per station:
    [station_row, [[stn1_id, stn1_distance], ....]]
    where the inner list is sorted by distance (km).  Distances are
    measured from the centre of the pixel that holds the station.
    params(dic) - bat_dic - Batch parameters
    params(tuple) - ds_params - Raster parameters
                                (extents, col_ct, row_ct, cell_size)
    params(string) - stn_filename - File with station information.
    returns(list) - stn_2_stn_list
    returns(boolean) - err - True if an exception occurred or no list
                             was produced, else False
    '''
    err = False
    stn_2_stn_list = []
    radius_km = bat_dic['search_radius']
    search_radius = radius_km * DEGPERKM
    try:
        data_list = read_csv_file(stn_filename, bat_dic['delimiter'], 0)
        # parse the station coordinates once instead of in the n x n loop
        stations = [(stn[0], float(stn[1]), float(stn[2]))
                    for stn in data_list]
        x_min = ds_params[0].xMinimum()
        y_max = ds_params[0].yMaximum()
        cell_size = ds_params[3]

        for curr_stn, (_, curr_long, curr_lat) in zip(data_list, stations):

            # have this reviewed, not sure that the conversion actually
            # helps, why not just use the lat and long of the station?
            temp_col = math.floor((curr_long - x_min) / cell_size)
            temp_row = math.floor((y_max - curr_lat) / cell_size)

            long1 = (x_min + temp_col * cell_size + (cell_size / 2))
            lat1 = (y_max - temp_row * cell_size - (cell_size / 2))

            # end this should be reviewed

            temp_min_lat = lat1 - search_radius
            temp_min_long = long1 - search_radius
            temp_max_lat = lat1 + search_radius
            temp_max_long = long1 + search_radius

            used_stn_array = []
            for ck_id, ck_long, ck_lat in stations:
                # only calculate distances if we are in the neighborhood
                if (temp_min_long <= ck_long <= temp_max_long and
                        temp_min_lat <= ck_lat <= temp_max_lat):
                    dist = get_distance(long1, lat1, ck_long, ck_lat)
                    # check against search radius
                    if dist <= radius_km:
                        used_stn_array.append([ck_id, dist])
            # one stable sort after collecting gives the same order as
            # sorting after every append, without the repeated work
            used_stn_array.sort(key=lambda item: item[1])

            stn_2_stn_list.append([curr_stn, used_stn_array])

    except Exception as exc:
        QgsMessageLog.logMessage(
            'Exception - get_station_2_station_list failed: ' + str(exc),
            level=Qgis.Critical)
        err = True
    if not stn_2_stn_list:
        QgsMessageLog.logMessage(
            'Stations to station mismatch', level=Qgis.Critical)
        err = True
    else:
        QgsMessageLog.logMessage(
            'Completed station to station', level=Qgis.Info)
    return stn_2_stn_list, err


def get_stations(bat_dic, ds_params):
    '''
    Function to find stations within the region extents, or within the
    dataset extents when the analysis type is 2, that have data in the
    start/end year range.  Note the controller has already checked for
    the destination files open, so is not done here.
    Two csv files are written when at least one station is found:
    'regionStations.csv' (id, longitude, latitude of each station found)
    next to the station file, and '<station file>_region.csv' with the
    station data rows.

    params(dic) - bat_dic - Batch parameters
    params(tuple) - ds_params - Dataset params( when not using region )
    returns(string) - dst_csv_filename1 - Filename for region stations.
    returns(string) - dst_csv_filename2 - Filename for region station data.
    returns(boolean) - err - True if an exception occurred or either
                             output file does not exist
    '''
    err = False
    dst_csv_filename1 = None
    dst_csv_filename2 = None
    id_col = bat_dic['csv_stn_id_col'] - 1
    long_col = bat_dic['csv_long_col'] - 1
    lat_col = bat_dic['csv_lat_col'] - 1
    year_col = bat_dic['csv_year_col'] - 1
    try:
        # first output file will be a list of region stations
        path = os.path.split(bat_dic['station_filename'])[0]
        dst_csv_filename1 = os.path.join(path, 'regionStations.csv')
        # second output is a copy of the input csv,
        # with any stations outside of the specified extents removed
        dst_csv_filename2 =\
            bat_dic['station_filename'][:-4] + '_region.csv'
        data_list = read_station_file(bat_dic)

        extracted_stns_list = []
        extracted_data_list = []
        used_ids = set()  # set gives constant time duplicate checks
        # analysis type 2 checks vs dataset extents, all others vs region
        extent_params = ds_params if bat_dic['analysis_type'] == 2 else None

        for entry in data_list:
            # need to check that the station data is there, some files
            # have lines where missing data filled into the period columns
            # but no station id, lat, long, etc.
            # also need to check if we have data for correct year range,
            # don't want to spend time processing stations that don't even
            # have data for the years of interest
            if not (entry[id_col] and entry[long_col] and entry[lat_col]):
                continue
            test = check_point_extents(bat_dic, extent_params, entry,
                                       lat_col, long_col, year_col)
            if test and entry[id_col] not in used_ids:
                extracted_stns_list.append(
                    [entry[id_col], entry[long_col], entry[lat_col]])
                used_ids.add(entry[id_col])
            # NOTE(review): every row with id/long/lat is copied, also
            # rows that fail the extent/year test - confirm this is wanted
            extracted_data_list.append(entry)

        if not extracted_stns_list:
            QgsMessageLog.logMessage('No stations found within region extents',
                                     level=Qgis.Warning)
        else:
            with open(dst_csv_filename1, 'w', newline='') as dst_csv:
                writer = csv.writer(dst_csv,
                                    delimiter=bat_dic['delimiter'])
                for row in extracted_stns_list:
                    writer.writerow([row[0], row[1], row[2]])
            with open(dst_csv_filename2, 'w', newline='') as dst_csv:
                writer = csv.writer(dst_csv,
                                    delimiter=bat_dic['delimiter'])
                for row in extracted_data_list:
                    writer.writerow(row)
    except Exception as exc:
        QgsMessageLog.logMessage(
            'Stations within extents failed: ' + str(exc),
            level=Qgis.Critical)
        err = True
    # the names are still None when the try block failed early, and
    # os.path.exists(None) raises TypeError
    outputs_exist = all(
        name is not None and os.path.exists(name)
        for name in (dst_csv_filename1, dst_csv_filename2))
    if not outputs_exist:
        QgsMessageLog.logMessage('Stations and extents mismatch',
                                 level=Qgis.Critical)
        err = True
    else:
        QgsMessageLog.logMessage('Stations within extents found',
                                 level=Qgis.Info)
    return dst_csv_filename1, dst_csv_filename2, err


def get_station_location_extent(data_list):
    '''
    Figure out the bounding box of the station locations.
    params(list) - data_list - Station rows, longitude at index 1 and
                               latitude at index 2
    returns(obj) - extent - QgsRectangle built from the min/max
                            longitude and latitude of the stations
    '''
    # seed with the opposite ends of the valid ranges so that the first
    # station always replaces them.  Latitude seeds of +-90 (not +-50)
    # keep the box tight for stations poleward of 50 degrees as well
    x_min = 180.0
    x_max = -180.0
    y_min = 90.0
    y_max = -90.0

    for entry in data_list:
        temp_long = float(entry[1])
        temp_lat = float(entry[2])
        x_min = min(x_min, temp_long)
        x_max = max(x_max, temp_long)
        y_min = min(y_min, temp_lat)
        y_max = max(y_max, temp_lat)

    extent = QgsRectangle(x_min, y_min, x_max, y_max)
    return extent


def get_station_rows_columns(ds_params, stn_extents, search_radius):
    '''
    Gets the raster rows and columns that cover the station extents
    widened by the search radius on every side.
    params(tuple) - ds_params - Raster parameters
                                (extents, col_ct, row_ct, cell_size)
    params(obj) - stn_extents - Extents of the station locations
    params(float) - search_radius - Margin, same units as the extents
    returns(int) - min_row, max_row, min_col, max_col - minimums are not
                   below 0, maximums are not above the row/column count
    '''
    grid_x_min = ds_params[0].xMinimum()
    grid_y_max = ds_params[0].yMaximum()
    col_ct, row_ct, cell_size = ds_params[1], ds_params[2], ds_params[3]

    west = stn_extents.xMinimum() - search_radius
    east = stn_extents.xMaximum() + search_radius
    min_col = max(int((west - grid_x_min) / cell_size), 0)
    max_col = min(int((east - grid_x_min) / cell_size), col_ct)

    # rows count down from the top, so the southern edge gives the
    # largest row and the northern edge the smallest
    south = stn_extents.yMinimum() - search_radius
    north = stn_extents.yMaximum() + search_radius
    max_row = min(int((grid_y_max - south) / cell_size), row_ct)
    min_row = max(int((grid_y_max - north) / cell_size), 0)
    return min_row, max_row, min_col, max_col


def get_value_average(stn_dic, vals_key):
    '''
    Averages the vals_key value over all stations in the dictionary.
    params(dic) - stn_dic - Station parameters
    params(string) - vals_key - Key whose values to average
    returns(float) - Sum of the values divided by the station count
    '''
    total = sum((stn[vals_key] for stn in stn_dic.values()), 0.0)
    return total / float(len(stn_dic))


def interpolate_array_idw(bat_dic, ds_params, stn_dic,
                          interp_type, stn_2_pixel_list, vals_key,
                          fill_val=None):
    '''
    Inverse distance weighted interpolation of station values onto the
    grid, for both ordinary and simple.  With interp_type 'Simple' a
    background value (batch 'lr_value', or the station average when
    'lr_value' is the dataset missing value) takes part in every
    weighted mean, weighted by the batch 'back_eq_distance'.
    This routine is adapted from the Windows GeoCLIM version was based on:
    http://mathworld.wolfram.com/SphericalTrigonometry.html).
    params(dic) - bat_dic - Batch parameters
    params(tuple) - ds_params - Raster parameters, element 1 is the
                                column count and element 2 the row count
    params(dic) - stn_dic - Station parameters
    params(string) - interp_type - 'Simple', anything else is ordinary
    params(list) - stn_2_pixel_list - Station to pixel list, entries are
                                      [[row, col], [[stn_id, dist], ...]]
    params(string) - vals_key - Station dictionary key of the values to
                                interpolate (measured, ratio, etc).
    params(float) - fill_val - Value to start the array, defaults to the
                               dataset missing value
    returns(array) - dst_interp_array - The array of interpolated values,
                                        None when stn_2_pixel_list is empty
    '''
    if fill_val is None:
        fill_val = float(bat_dic['ds_dic']['DATAMISSINGVALUE'])

    if not stn_2_pixel_list:
        QgsMessageLog.logMessage('Stations to pixels mismatch',
                                 level=Qgis.Critical)
        return None

    dst_interp_array = numpy.full((ds_params[2], ds_params[1]), fill_val)

    # starting points for the weight sums, only 'Simple' has a background
    base_wt = 0.0
    base_wtd_val = 0.0
    if interp_type == 'Simple':
        if bat_dic['lr_value'] == bat_dic['ds_dic']['DATAMISSINGVALUE']:
            background = get_value_average(stn_dic, vals_key)
        else:
            background = bat_dic['lr_value']
        base_wt = 1 / math.pow((bat_dic['back_eq_distance'] +
                                bat_dic['fuzzy_distance']),
                               bat_dic['wt_power'])
        base_wtd_val = background * base_wt

    for entry in stn_2_pixel_list:
        pixel = entry[0]
        # [value, distance] of the nearby stations that are present in
        # stn_dic and not masked, cut down to the closest max_stations
        candidates = [
            [stn_dic[near[0]][vals_key], near[1]]
            for near in entry[1]
            if near[0] in stn_dic and
            not numpy.ma.is_masked(stn_dic[near[0]][vals_key])]
        candidates = candidates[:bat_dic['max_stations']]
        if len(candidates) < bat_dic['min_stations']:
            # too few stations, pixel keeps the fill value
            continue

        sum_wts = base_wt
        sum_wtd_vals = base_wtd_val
        for stn_val, stn_dist in candidates:
            weight = 1 / math.pow((stn_dist + bat_dic['fuzzy_distance']),
                                  bat_dic['wt_power'])
            sum_wts = sum_wts + weight
            sum_wtd_vals = sum_wtd_vals + (stn_val * weight)
        if sum_wts != 0:
            dst_interp_array[pixel[0], pixel[1]] = sum_wtd_vals / sum_wts
    return dst_interp_array


def cointerpolate_stations_idw(bat_dic, stn_dic, interp_type, stn_2_stn_list,
                               src_vals_key, intrpltd_key, x_val_key):
    '''
    Cointerpolation IDW function.  This can be used for both ordinary and
    simple.  This routine is adapted from the Windows GeoCLIM version
    based on: http://mathworld.wolfram.com/SphericalTrigonometry.html). The
    interp type is passed as a parameter because it doesn't always use the
    batch dictionary value.
    For every station in stn_dic two values are written back into its entry:
    the IDW estimate from all of its nearby stations (intrpltd_key) and the
    cross validated estimate (x_val_key), which leaves out the first nearby
    station (the station itself) and any other station closer than 0.1 or
    MIN_CROSS_VAL_DIST.  The station missing value is stored when fewer than
    'min_stations' stations are available or the weights sum to zero.
    params(dic) - bat_dic - Batch parameters
    params(dic) - stn_dic - Station parameters (updated in place)
    params(string) - interp_type - Interp type, 'Simple' adds a background
                     term, anything else starts the sums at zero
    params(list) - stn_2_stn_list - Station to station list, each entry is
                   [[stn_id, ...], [[nearby_stn_id, distance], ...]]
    params(string) - src_vals_key - Source value key
    params(string) - intrpltd_key - Interpolated value key
    params(string) - x_val_key - Cross validated value key
    '''
    fuzz_dist = bat_dic['fuzzy_distance']

    # starting point for the sums, identical for every station
    if interp_type == 'Simple':
        if bat_dic['lr_value'] == bat_dic['stn_missing_val']:
            valid_avg = get_value_average(stn_dic, src_vals_key)
        else:
            valid_avg = bat_dic['lr_value']
        sum_wts_start =\
            1 / math.pow((bat_dic['back_eq_distance'] + fuzz_dist),
                         bat_dic['wt_power'])
        sum_wtd_val_start = valid_avg * sum_wts_start
    else:
        sum_wts_start = 0.0
        sum_wtd_val_start = 0.0

    def _idw_estimate(stations):
        '''
        IDW value for a list of [station_value, station_distance] pairs,
        or the station missing value when it cannot be computed.
        '''
        if len(stations) < bat_dic['min_stations']:
            return float(bat_dic['stn_missing_val'])
        sum_wts = sum_wts_start
        sum_wtd_val = sum_wtd_val_start
        for stn_val, stn_dist in stations:
            temp = 1 / math.pow((stn_dist + fuzz_dist), bat_dic['wt_power'])
            sum_wts = sum_wts + temp
            sum_wtd_val = sum_wtd_val + (stn_val * temp)
        if sum_wts != 0:
            return sum_wtd_val / sum_wts
        return float(bat_dic['stn_missing_val'])

    for key in stn_dic:
        # used_stn_array is a list of
        # [station_value(float), station_distance(float)]
        used_stn_array = []
        for entry in stn_2_stn_list:
            if key == entry[0][0]:
                for nearby_stn in entry[1]:
                    # nearby_stn[0] = station_id
                    # nearby_stn[1] = distance to station
                    if nearby_stn[0] in stn_dic:
                        stn_val = stn_dic[nearby_stn[0]][src_vals_key]
                        used_stn_array.append([stn_val, nearby_stn[1]])

        # the station itself is the first 'used stn' in the list so we
        # remove it for the cross validation, together with every other
        # station that is too close.  (The previous index based removal never
        # advanced its index, so only the first element could be dropped.)
        used_stns_1_removed_array = [
            stn for stn in used_stn_array[1:]
            if not (stn[1] < 0.1 or stn[1] < MIN_CROSS_VAL_DIST)]

        if len(used_stn_array) >= bat_dic['max_stations']:
            used_stn_array = used_stn_array[:bat_dic['max_stations']]
        if len(used_stns_1_removed_array) >= bat_dic['max_stations']:
            used_stns_1_removed_array =\
                used_stns_1_removed_array[:bat_dic['max_stations']]

        stn_dic[key][intrpltd_key] = _idw_estimate(used_stn_array)
        stn_dic[key][x_val_key] = _idw_estimate(used_stns_1_removed_array)


def plot_graph_dic(stn_dic, key_x, key_y, label_x, label_y, title, dst_file):
    '''
    Stats plotting function for a station dictionary.  Pulls one value per
    station for each axis and hands the two lists to plot_graph_list.
    params(dic) - stn_dic - Station dictionary
    params(string) - key_x - Key for values for x dimension
    params(string) - key_y - Key for values for y dimension
    params(string) - label_x - Label for x dimension
    params(string) - label_y - Label for y dimension
    params(string) - title - Title for plot
    params(string) - dst_file - Destination file(.jpg format)
    '''
    station_records = stn_dic.values()
    x_points = [record[key_x] for record in station_records]
    y_points = [record[key_y] for record in station_records]
    plot_graph_list(x_points, y_points, label_x, label_y, title, dst_file)


def plot_graph_list(vals_x, vals_y, label_x, label_y, title, dst_file):
    '''
    Stats plotting function.  Draws the points as blue dots on axes that
    both run from 0 to a shared upper limit and saves the figure.
    params(list) - vals_x - Values for x dimension
    params(list) - vals_y - Values for y dimension
    params(string) - label_x - Label for x dimension
    params(string) - label_y - Label for y dimension
    params(string) - title - Title for plot
    params(string) - dst_file - Destination file(.jpg format)
    '''
    peak_x = max(vals_x)
    peak_y = max(vals_y)

    # clear out any previous plot, a failure here is only logged
    try:
        if os.path.exists(dst_file):
            os.remove(dst_file)
    except OSError:
        QgsMessageLog.logMessage(
            'File open in another process:  ' + dst_file,
            level=Qgis.Critical)

    # shared axis limit: (int(largest / 100) + 1) * 100
    largest = max(peak_x, peak_y)
    axis_top = 100 * (1 + int(largest / 100))

    plt.plot(vals_x, vals_y, 'b.')
    plt.title(title)
    plt.xlabel(label_x)
    plt.xlim(0, axis_top)
    plt.ylabel(label_y)
    plt.ylim(0, axis_top)
    plt.savefig(dst_file)  # .png crashes Windows, use .jpg

    # release the figure
    plt.clf()
    plt.cla()
    plt.close()


def print_cross_validation(stn_dic, k_val_x, k_val_y, dst_stats_file):
    '''
    Function to print out stats results to text file.  For each station
    location this appends one 'Long, Lat, X, Y' row after a separator and a
    heading line.  Longitude/Latitude are rounded to 3 places, X and Y to 2.
    Any failure is logged rather than raised.
    Args:
        stn_dic(dic) - Station information
        k_val_x(string) - Key of the value written in the X column.
        k_val_y(string) - Key of the value written in the Y column.
        dst_stats_file(string) - Name of the output file.
    '''
    try:
        # append mode, existing output is never truncated here
        with open(dst_stats_file, 'a') as out_file:
            out_file.write('\n===================================\n')
            out_file.write('Long, Lat, X, Y\n')

            for stn_rec in stn_dic.values():
                fields = (round(stn_rec['Longitude'], 3),
                          round(stn_rec['Latitude'], 3),
                          round(stn_rec[k_val_x], 2),
                          round(stn_rec[k_val_y], 2))
                out_file.write(
                    ', '.join(str(field) for field in fields) + '\n')
    except BaseException:
        QgsMessageLog.logMessage('Exception - print_cross_validation failed',
                                 level=Qgis.Critical)


def print_least_squares_stats(stats_dic, dst_stats_filename,
                              title, append_flag=False):
    '''
    Function to print out a statistical summary to text file.  Writes the
    title, a heading, one rounded value per statistic and a time stamp.
    Any failure is logged rather than raised.
    params(dic) - stats_dic - Statistical results dictionary.
    params(string) - dst_stats_filename - Name of the output file.
    params(string) - title - Title
    params(boolean) - append_flag - Flag to append to existing output,
                      otherwise the file is overwritten.
    '''
    # (label, stats_dic key, rounding digits, line ending) in output order
    summary_rows = (
        ('R-squared:            \t', 'res_squared', 2, '\n'),
        ('RMSE:                 \t', 'std_err_est', 2, '\n'),
        ('Mean Absolute Error:  \t', 'mae', 2, '\n'),
        ('Mean Bias:            \t', 'mean_bias', 2, '\n'),
        ('Regression Slope:     \t', 'slope', 2, '\n'),
        ('Regression Intercept: \t', 'intercept', 2, '\n'),
        ('Number of valid obs:  \t', 'valid_obs_count', 0, '\n\n'))
    try:
        file_mode = 'a' if append_flag else 'w'
        with open(dst_stats_filename, file_mode) as out_file:
            printed_at = datetime.datetime.now().strftime(
                '%Y/%m/%d %I:%M:%S %p')
            out_file.write('\n' + title + '\n')
            out_file.write(
                'Statistical Summary For Validation Regression of X on Y:\n')
            for label, stat_key, digits, line_end in summary_rows:
                out_file.write(
                    label + str(round(stats_dic[stat_key], digits)) +
                    line_end)
            out_file.write('Data printed at:   \t' + printed_at + '\n')
    except BaseException:
        QgsMessageLog.logMessage(
            'Exception - print_least_squares_stats failed',
            level=Qgis.Critical)


def print_station_info(bat_dic, stn_dic, dst_file, sample_date=None):
    '''
    Function to append station results to a delimited text file.  For each
    station this prints the id, optional sample date, location, value,
    interpolated value and a last column that is the grid value, or the
    cross validated value when the analysis type is 3.  A separator and the
    column headings are written only when no sample date is given.
    Any failure is logged rather than raised.
        params(dic) - bat_dic - Batch parameters
        params(dic) - stn_dic - Station information
        dst_file(string) - Name of the output file.
        sample_date - can be none
    '''
    spc = bat_dic['delimiter']
    try:
        with open(dst_file, 'a') as out_file:
            date_field = (
                (sample_date + bat_dic['delimiter']) if sample_date else '')

            # the analysis type decides which value fills the last column
            if bat_dic['analysis_type'] != 3:
                last_element = 'Grid_val'
                last_heading = 'Grid_val\n'
            else:
                last_element = 'Xvalidated_stn_val'
                last_heading = 'Cross_validated_stn_val\n'

            if not sample_date:
                out_file.write('\n===================================\n')
                out_file.write(spc.join(
                    ['Stn_ID', 'Longitude', 'Latitude', 'Stn_value',
                     'Intrpltd_stn_val', last_heading]))

            for stn_id, stn_rec in stn_dic.items():
                columns = [str(round(stn_rec['Longitude'], 3)),
                           str(round(stn_rec['Latitude'], 3)),
                           str(round(stn_rec['Stn_val'], 2)),
                           str(round(stn_rec['Intrpltd_stn_val'], 2)),
                           str(round(stn_rec[last_element], 2))]
                out_file.write(
                    stn_id + spc + date_field + spc.join(columns) + '\n')
    except BaseException:
        QgsMessageLog.logMessage(
            'Exception - print_station_info failed', level=Qgis.Critical)


def print_station_info_cumulative(bat_dic, stn_dic,
                                  dst_file, sample_date):
    '''
    Function to append stats results to csv file.  For each station location
    this prints the id, sample date, location, value, grid value, final
    BASIICS value, and the two cross validated values
    ('Xvalidated_ratio_plus_anom_val' and 'Xvalidated_stn_val').
    The column headings must already be in the file.  Any failure is logged
    rather than raised.
        params(dic) - bat_dic - Batch parameters
        params(dic) - stn_dic - Station information
        dst_file(string) - Name of the output file.
        sample_date(string) - Date of run.
    '''
    spc = bat_dic['delimiter']
    # (station key, rounding digits) in output column order
    value_columns = (('Longitude', 3),
                     ('Latitude', 3),
                     ('Stn_val', 2),
                     ('Grid_val', 2),
                     ('final_basiics_val', 2),
                     ('Xvalidated_ratio_plus_anom_val', 2),
                     ('Xvalidated_stn_val', 2))
    try:
        with open(dst_file, 'a') as stats:
            # column headings must already be printed
            s_date = sample_date + bat_dic['delimiter']

            for key in stn_dic:
                values = [str(round(stn_dic[key][col], digits))
                          for col, digits in value_columns]
                stats.write(key + spc + s_date + spc.join(values) + '\n')
    except BaseException:
        # name this function in the log (it used to report
        # print_station_info, which sent debugging to the wrong place)
        QgsMessageLog.logMessage(
            'Exception - print_station_info_cumulative failed',
            level=Qgis.Critical)


def get_stations_dic_for_period(stn_filename, delimiter_val):
    '''
    Function to read the stations csv file for a single sample date.
    Each row is read as: station id, longitude, latitude, station value.
    A failure is logged rather than raised; the stations read before the
    failure (possibly none) are returned.
    params(string) - stn_filename - Filename of the sample date csv file.
    params(string) - delimiter_val - Delimiter used in the csv file.
    returns(dic) - stn_dic - Station parameters keyed by station id, each
                   value holds 'Longitude', 'Latitude' and 'Stn_val' floats
    '''
    # bound before the try so the return below can never hit an unbound name
    stn_dic = {}
    try:
        data_list = read_csv_file(stn_filename, delimiter_val, 0)
        for entry in data_list:
            # key is the station id
            stn_dic[entry[0]] = {'Longitude': float(entry[1]),
                                 'Latitude': float(entry[2]),
                                 'Stn_val': float(entry[3])}
    except BaseException:
        QgsMessageLog.logMessage(
            'Exception - CSV station file read failed', level=Qgis.Critical)
    return stn_dic


def read_station_file(bat_dic):
    '''
    Reads the station file, must be utf-8 encoding.  Every ',' inside a
    field is replaced by '.' so that decimal commas become decimal points.
    A failure is logged rather than raised.
    params(dic) - bat_dic - Batch parameters, uses 'station_filename',
                  'delimiter' and 'csv_hdr_row'
    returns(list) - data_list - List of rows (lists of strings), holds the
                    rows converted before any failure (possibly none)
    '''
    data_list = []
    try:
        # read in original station file, has all stations all sample dates
        temp_data_list = read_csv_file(bat_dic['station_filename'],
                                       bat_dic['delimiter'],
                                       bat_dic['csv_hdr_row'])
        for entry in temp_data_list:
            # deal with decimal point represented by a ','
            data_list.append([item.replace(',', '.') for item in entry])
    except BaseException:
        # the two sentences used to be joined without any separator
        QgsMessageLog.logMessage(
            ('Exception - Unable to read station file. ' +
             'Convert it to utf-8 format.'), level=Qgis.Critical)
    return data_list


def remove_nodata_keys(src_dic, key_val, bad_values):
    '''
    Remove, in place, every entry of src_dic whose key_val element is one
    of the bad values.
    params(dic) - src_dic - Dictionary of dictionaries to clean up
    params(string) - key_val - Key of the element to check in each entry
    params(list) - bad_values - Values that mark an entry for removal
    '''
    # collect first, the dictionary can't change size while it is iterated
    doomed_keys = [stn_key for stn_key, stn_rec in src_dic.items()
                   if stn_rec[key_val] in bad_values]
    for stn_key in doomed_keys:
        del src_dic[stn_key]


def station_split_csv(bat_dic, split_csv_path, region_csv_input_path):
    '''
    Function to split up a csv containing many samples of station data into
    separate files for each sample.
    One 'raingauge<sample>.csv' file is written into split_csv_path for each
    entry of bat_dic['dates_list']; split_csv_path is removed and recreated
    first.  Each output row is: station id, longitude, latitude, value.
    The input csv will require that the columns holding the rainfall amounts
    for the year must be in consecutive columns. i.e.  If we are looking at
    months and Jan = Col 8, then Feb = Col 9, Mar = Col 10, etc. It also
    requires that the 'WithinExtents' file has already been created.
    Rows whose value is the station missing value or negative, or whose year
    does not match the sample, are left out.  A value that is not a number is
    logged, flagged through the return value and its row is skipped.
    params(dic) - bat_dic - Batch parameters
    params(string) - split_csv_path - The split CSV path.
    params(string) - region_csv_input_path - The input csv file after removing
                     stations from outside of region extents.
    returns(boolean) - err - True if a bad station value has been found or
                       the split itself failed
    '''
    err = False
    try:
        if os.path.exists(split_csv_path):
            shutil.rmtree(split_csv_path)
        os.makedirs(split_csv_path)
    except BaseException:
        QgsMessageLog.logMessage(('Unable to access workspace {}.'
                                  .format(split_csv_path)),
                                 level=Qgis.Critical)
    try:
        # the batch dictionary columns are 1 based
        id_col = bat_dic['csv_stn_id_col'] - 1
        lat_col = bat_dic['csv_lat_col'] - 1
        long_col = bat_dic['csv_long_col'] - 1
        year_col = bat_dic['csv_year_col'] - 1
        beg_per_col = bat_dic['csv_beg_period_col'] - 1
        missing_val_string = str(bat_dic['stn_missing_val'])

        data_list = read_csv_file(region_csv_input_path,
                                  bat_dic['delimiter'], 0)
        interval_dic = util.get_interval_dic(bat_dic['ds_dic']['PERIODICITY'])
        for sample in bat_dic['dates_list']:
            per_of_year = sample[-2:]
            # if the date format is month and period of month, then we need
            # to convert it to period of year to get the right column index
            if 'MMK' in bat_dic['ds_dic']['DATADATEFORMAT'] or\
                    'MMP' in bat_dic['ds_dic']['DATADATEFORMAT']:
                per = sample[-3:]

                interval_key = util.get_key_from_period_value(
                    bat_dic['ds_dic'], per)
                per_of_year = interval_dic[interval_key]['PER_YEAR']

            elif '.' in bat_dic['ds_dic']['DATADATEFORMAT']:
                per = '.' + sample.split('.', 1)[1] if '.' in sample else ''
                interval_key = util.get_key_from_period_value(
                    bat_dic['ds_dic'], per)
                per_of_year = interval_dic[interval_key]['PER_YEAR']
            found = False
            dst_csv_filename = os.path.join(split_csv_path,
                                            'raingauge' + sample + '.csv')
            sample_col = beg_per_col + int(per_of_year) - 1
            with open(dst_csv_filename, 'w', newline='') as dst_csv_obj:
                writer = csv.writer(dst_csv_obj,
                                    delimiter=bat_dic['delimiter'])
                # ignore any missing or out of range data
                for entry in data_list:
                    try:
                        stn_val = float(entry[sample_col])
                    except ValueError:
                        err = True
                        QgsMessageLog.logMessage(
                            'Exception - CSV station file has bad '
                            'station values',
                            level=Qgis.Critical)
                        # skip the row; converting the bad value again
                        # below would abort the whole split through the
                        # generic handler
                        continue
                    if (sample[:4] == entry[year_col] and
                            entry[sample_col] != missing_val_string and
                            stn_val >= 0):
                        writer.writerow([entry[id_col],
                                         entry[long_col],
                                         entry[lat_col],
                                         entry[sample_col]])
                        found = True
            if not found:
                QgsMessageLog.logMessage('Missing station info '
                                         'for : ' + str(sample),
                                         level=Qgis.Info)
    except BaseException:
        QgsMessageLog.logMessage('Exception - station_split_csv failed',
                                 level=Qgis.Critical)
        err = True
    return err

def split_station_file(curr_out_path, bat_dic, region_data_filename):
    '''
    # Step 4 - split up the region data file into one for each sample date
    The files are written by station_split_csv into the
    <curr_out_path>/<config.TEMP>/RainStations directory.
    params(str) - curr_out_path - self.bat_dic['curr_output_path']
    params(dic) - bat_dic - Batch parameters
    params(str) - region_data_filename - Region data csv filename
    returns(list) - split_csv_file_list_l - List of dst csv files, empty
                    when the split failed
    returns(boolean) - err - True if no list produced, else False
    '''
    # bound up front so a failed split returns an empty list instead of
    # raising on an unassigned name
    split_csv_file_list_l = []
    split_dst_path_l =\
        os.path.join(curr_out_path,
                     config.TEMP, 'RainStations')

    err = station_split_csv(bat_dic, split_dst_path_l, region_data_filename)
    if not err:
        split_csv_file_list_l = glob.glob(
            split_dst_path_l + os.sep + '*.csv')
        if not split_csv_file_list_l:
            QgsMessageLog.logMessage('Split station files failure',
                                     level=Qgis.Critical)
            err = True
        else:
            # only report completion when files were really produced
            QgsMessageLog.logMessage('Completed split station files',
                                     level=Qgis.Info)
    return split_csv_file_list_l, err


def read_csv_file(csv_file, delim, start_row):
    '''
    Opens a utf-8 encoded csv file and returns a list of the contents.
    A failure (missing file, wrong encoding, csv error...) is logged rather
    than raised.
    params(string) - csv_file - Filename
    params(string) - delim - Delimiter
    params(integer) - start_row - Start row of data(dumps header rows)
    returns(list) - List of the data rows (lists of strings), or None when
                    the file could not be read
    '''
    rows = None
    try:
        # open with 'r' to gather the cumulative info
        with open(csv_file, 'r', encoding='utf-8') as csv_input:
            # the rows are only bound once the whole file has been parsed,
            # so a decode error part way through can't leave a reader on a
            # closed file as the return value
            rows = list(csv.reader(csv_input, delimiter=delim))[start_row:]
    except BaseException:
        QgsMessageLog.logMessage('Exception - csv read failed',
                                 level=Qgis.Critical)
    return rows
