import datetime

from utils import cached_property

import io
import os
import os.path as op
from collections import OrderedDict
from contextlib import redirect_stdout
from itertools import chain
from multiprocessing.pool import Pool
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
from netCDF4 import Dataset

from s2m.abstract_variable import AbstractVariable
from s2m.dataset.coordinates.abstract_coordinates import AbstractCoordinates
from s2m.dataset.observations.annual_maxima_observations import AnnualMaxima
from s2m.utils import Season, FrenchRegion, ALTITUDES, ORIENTATIONS, SLOPES, first_day_and_last_day, date_to_str, \
from utils import NB_CORES, classproperty

# Import simpledbf with stdout redirected into an in-memory buffer, so anything
# printed while the package is being imported ends up in ``f`` instead of the console.
f = io.StringIO()
with redirect_stdout(f):
    from simpledbf import Dbf5

class AbstractStudy(object):
    """A study is defined by:
        - a variable class that corresponds to the meteorological quantity of interest
        - an altitude of interest
        - a start year and an end year

    The SAFRAN and CROCUS netCDF files are self-documenting (they can be inspected with `ncdump -h`, for instance).
    """

    def __init__(self, variable_class: type, altitude: int = 1800,
                 year_min=None, year_max=None,
                 multiprocessing=True, orientation=None, slope=20.0,
                 # NOTE(review): the end of this signature is missing from this copy of the
                 # file. The body reads ``french_region``, ``season`` and ``split_years``, so
                 # those parameters and the closing ``):`` must be restored - TODO confirm
                 # their order and defaults (``split_years`` is treated as optional below).
        # Validate the altitude against the known altitude levels (asserts are
        # stripped under ``python -O``).
        assert isinstance(altitude, int), type(altitude)
        assert altitude in ALTITUDES, altitude
        # french_region is stored before the year bounds are resolved: YEAR_MAX reads
        # self.reanalysis_path, which depends on the region through reanalysisfolder.
        self.french_region = french_region
        self.altitude = altitude
        # NOTE(review): study_full_path asserts model_name is 'Safran' or 'Crocus', so it
        # is presumably set by subclasses - TODO confirm.
        self.model_name = None
        self.variable_class = variable_class
        # Add some attributes, for the "allslopes" reanalysis
        assert orientation is None or orientation in ORIENTATIONS
        assert slope in SLOPES
        self.orientation = orientation
        self.slope = slope

        # Default to the full available period, then check that the requested
        # period lies within it.
        if year_min is None:
            year_min = self.YEAR_MIN
        if year_max is None:
            year_max = self.YEAR_MAX
        assert self.YEAR_MIN <= year_min <= year_max <= self.YEAR_MAX
        self.year_min = year_min
        self.year_max = year_max
        self.multiprocessing = multiprocessing
        self.season = season
        # Years kept when listing files (see ordered_years_and_path_files);
        # by default every year of [year_min, year_max].
        if split_years is None:
            split_years = list(range(year_min, year_max + 1))
        self.split_years = set(split_years)

        # Add some cache for computation
        self._cache_for_pointwise_fit = {}
        self._massif_names_for_cache = None

    """ Time """

    def year_to_first_index_and_last_index(self):
        """Map each year to the (first, last) positions of the season in that year's days.

        Positions index into ``self.year_to_all_days[year]``; the boundary day
        strings come from ``first_day_and_last_day(self.season)``. The season
        starts in ``year - 1`` except for spring, and ends in ``year - 1`` only
        for autumn.
        """
        first_day, last_day = first_day_and_last_day(self.season)
        mapping = OrderedDict()
        for year, days_of_year in self.year_to_all_days.items():
            start_year = year if self.season is Season.spring else year - 1
            end_year = year - 1 if self.season is Season.automn else year
            start = days_of_year.index('{}-{}'.format(start_year, first_day))
            end = days_of_year.index('{}-{}'.format(end_year, last_day))
            mapping[year] = (start, end)
        return mapping

    def year_to_days(self) -> OrderedDict:
        year_to_days = OrderedDict()
        for year, (start_index, last_index) in self.year_to_first_index_and_last_index.items():
            year_to_days[year] = self.year_to_all_days[year][start_index:last_index + 1]
        return year_to_days

    def year_to_all_days(self) -> OrderedDict:
        """Map each year to the list of its day strings.

        For ``year``, days start on ``(year - 1)-08-01 06:00:00`` and stop just
        before the next August 1st (365 or 366 entries). Each day is formatted
        with ``date_to_str``.
        """
        year_to_days = OrderedDict()
        for year in self.ordered_years:
            # Load days for the full year, starting on August 1st of the previous year
            date = datetime.datetime(year=year - 1, month=8, day=1, hour=6, minute=0, second=0)
            days = []
            # 366 is only an upper bound: stop as soon as the next August 1st is reached
            for _ in range(366):
                days.append(date_to_str(date))
                date += datetime.timedelta(days=1)
                if date.month == 8 and date.day == 1:
                    break
            year_to_days[year] = days
        return year_to_days

    def all_days(self):
        return list(chain(*list(self.year_to_days.values())))

    def all_daily_series(self) -> np.ndarray:
        """Return an array of approximate shape (total_number_of_days, 23) x """
        all_daily_series = np.concatenate([time_serie_array
                                           for time_serie_array in self.year_to_daily_time_serie_array.values()])
        assert len(all_daily_series) == len(self.all_days)
        return all_daily_series

    """ Annual maxima """

    def observations_annual_maxima(self) -> AnnualMaxima:
        """Wrap the annual maxima (rows: study massifs, columns: years) in an AnnualMaxima."""
        df_maxima = pd.DataFrame(self.year_to_annual_maxima, index=self.study_massif_names)
        return AnnualMaxima(df_maxima_gev=df_maxima)

    def observations_annual_mean(self) -> pd.DataFrame:
        return pd.DataFrame(self.year_to_annual_mean, index=self.study_massif_names)

    def annual_maxima_and_years(self, massif_name) -> Tuple[np.ndarray, np.ndarray]:
        df = self.observations_annual_maxima.df_maxima_gev
        return df.loc[massif_name].values, np.array(df.columns)

    def year_to_annual_maxima(self) -> OrderedDict:
        # Map each year to an array of size nb_massif
        year_to_annual_maxima = OrderedDict()
        for year, time_serie in self._year_to_max_daily_time_serie.items():
            year_to_annual_maxima[year] = time_serie.max(axis=0)
        return year_to_annual_maxima

    def year_to_annual_mean(self) -> OrderedDict:
        # Map each year to an array of size nb_massif
        year_to_annual_mean = OrderedDict()
        for year, time_serie in self._year_to_max_daily_time_serie.items():
            year_to_annual_mean[year] = time_serie.mean(axis=0)
        return year_to_annual_mean

    def year_to_annual_maxima_index(self) -> OrderedDict:
        # Map each year to an array of size nb_massif
        year_to_annual_maxima = OrderedDict()
        for year, time_serie in self._year_to_max_daily_time_serie.items():
            year_to_annual_maxima[year] = time_serie.argmax(axis=0)
        return year_to_annual_maxima

    def massif_name_to_annual_maxima_index(self):
        massif_name_to_annual_maxima_index = OrderedDict()
        for i, massif_name in enumerate(self.study_massif_names):
            index = [self.year_to_annual_maxima_index[year][i] for year in self.ordered_years]
            massif_name_to_annual_maxima_index[massif_name] = index
        return massif_name_to_annual_maxima_index

    def massif_name_to_annual_maxima_angle(self):
        normalization_denominator = [366 if year % 4 == 0 else 365 for year in self.ordered_years]
        massif_name_to_annual_maxima_angle = OrderedDict()
        for massif_name, annual_maxima_index in self.massif_name_to_annual_maxima_index.items():
            angle = 2 * np.pi * np.array(annual_maxima_index) / np.array(normalization_denominator)
            massif_name_to_annual_maxima_angle[massif_name] = angle
        return massif_name_to_annual_maxima_angle

    def massif_name_to_annual_maxima(self):
        massif_name_to_annual_maxima = OrderedDict()
        for i, massif_name in enumerate(self.study_massif_names):
            maxima = np.array([self.year_to_annual_maxima[year][i] for year in self.ordered_years])
            massif_name_to_annual_maxima[massif_name] = maxima
        return massif_name_to_annual_maxima

    def massif_name_to_daily_time_series(self):
        massif_name_to_daily_time_series = OrderedDict()
        for i, massif_name in enumerate(self.study_massif_names):
            a = [self.year_to_daily_time_serie_array[year][:, i] for year in self.ordered_years]
            daily_time_series = np.array(list(chain.from_iterable(a)))
            massif_name_to_daily_time_series[massif_name] = daily_time_series
        return massif_name_to_daily_time_series

    def massif_name_to_annual_maxima_ordered_years(self):
        massif_name_to_annual_maxima_ordered_years = OrderedDict()
        for massif_name in self.study_massif_names:
            maxima = self.massif_name_to_annual_maxima[massif_name]
            annual_maxima_ordered_index = np.argsort(maxima)
            annual_maxima_ordered_years = [self.ordered_years[idx] for idx in annual_maxima_ordered_index]
            massif_name_to_annual_maxima_ordered_years[massif_name] = annual_maxima_ordered_years
        return massif_name_to_annual_maxima_ordered_years

    """ Annual total """

    def df_annual_total(self) -> pd.DataFrame:
        return pd.DataFrame(self.year_to_annual_total, index=self.study_massif_names).transpose()

    def annual_aggregation_function(self, *args, **kwargs):
        raise NotImplementedError()

    def year_to_annual_total(self) -> OrderedDict:
        # Map each year to an array of size nb_massif
        year_to_annual_mean = OrderedDict()
        for year, time_serie in self._year_to_daily_time_serie_array.items():
            year_to_annual_mean[year] = self.apply_annual_aggregation(time_serie)
        return year_to_annual_mean

    def massif_name_to_annual_total(self):
        # Map each massif to an array of size nb_years
        massif_name_to_annual_total = OrderedDict()
        for i, massif_name in enumerate(self.study_massif_names):
            maxima = np.array([self.year_to_annual_total[year][i] for year in self.ordered_years])
            massif_name_to_annual_total[massif_name] = maxima
        return massif_name_to_annual_total

    def apply_annual_aggregation(self, time_serie):
        return self.annual_aggregation_function(time_serie, axis=0)

    """ Load daily observations """

    def year_to_daily_time_serie_array(self) -> OrderedDict:
        return self._year_to_daily_time_serie_array

    def _year_to_max_daily_time_serie(self) -> OrderedDict:
        return self._year_to_daily_time_serie_array

    def _year_to_daily_time_serie_array(self) -> OrderedDict:
        # Map each year to a matrix of size 365-nb_days_consecutive+1 x nb_massifs
        year_to_daily_time_serie_array = OrderedDict()
        for year in self.ordered_years:
            # Check daily data
            daily_time_serie = self.daily_time_series(year)
            # Filter only the data corresponding:
            # 1: to treturnhe start_index and last_index of the season
            # 2: to the massifs for the altitude of interest
            assert daily_time_serie.shape == (len(self.year_to_days[year]), len(self.study_massif_names))
            year_to_daily_time_serie_array[year] = daily_time_serie
        return year_to_daily_time_serie_array

    def daily_time_series(self, year):
        daily_time_serie = self.year_to_variable_object[year].daily_time_serie_array
        nb_days = daily_time_serie.shape[0]
        assert nb_days == 365 or (nb_days == 366 and year % 4 == 0)
        assert daily_time_serie.shape[1] == len(self.column_mask)
        first_index, last_index = self.year_to_first_index_and_last_index[year]
        daily_time_serie = daily_time_serie[first_index:last_index + 1, self.column_mask]
        return daily_time_serie

    """ Load Variables and Datasets """

    def year_to_variable_object(self) -> OrderedDict:
        # Map each year to the variable array
        path_files, ordered_years = self.ordered_years_and_path_files
        return self.efficient_variable_loading(ordered_years, path_files, multiprocessing=self.multiprocessing)

    def efficient_variable_loading(self, ordered_years, arguments, multiprocessing):
        if multiprocessing:
            with Pool(NB_CORES) as p:
                variables = p.map(self.load_variable_object, arguments)
            variables = [self.load_variable_object(argument) for argument in arguments]
        return OrderedDict(zip(ordered_years, variables))

    def instantiate_variable_object(self, variable_array) -> AbstractVariable:
        return self.variable_class(variable_array)

    def load_variable_array(self, dataset):
        return np.array(dataset.variables[self.load_keyword()])

    def load_variable_object(self, path_file):
        """Read the study variable from one netCDF file and wrap it in ``variable_class``.

        The Dataset is closed right after ``load_variable_array`` has produced
        the array, so no file handle stays open per loaded file. NOTE(review):
        this relies on ``load_variable_array`` returning an in-memory copy (the
        base implementation uses ``np.array``); confirm that subclass overrides
        do too.
        """
        with Dataset(path_file) as dataset:
            variable_array = self.load_variable_array(dataset)
        return self.instantiate_variable_object(variable_array)

    def load_keyword(self):
        return self.variable_class.keyword()

    def year_to_dataset_ordered_dict(self) -> OrderedDict:
        """Map each ordered year to its netCDF4 Dataset (slow: opens every file).

        The datasets are returned open; they are not closed here.
        """
        print('This code is quite long... '
              'You should consider year_to_variable which is way faster when multiprocessing=True')
        path_files, ordered_years = self.ordered_years_and_path_files
        opened = [Dataset(path) for path in path_files]
        return OrderedDict((year, dataset) for year, dataset in zip(ordered_years, opened))

    def ordered_years_and_path_files(self):
        nc_files = [(int(f.split('_')[-2][:4]) + 1, f) for f in os.listdir(self.study_full_path) if f.endswith('.nc')]
        assert op.exists(self.study_full_path)
        assert len(nc_files) > 0
        ordered_years, path_files = zip(*[(year, op.join(self.study_full_path, nc_file))
                                          for year, nc_file in sorted(nc_files, key=lambda t: t[0])
                                          if (self.year_min <= year <= self.year_max)
                                          and (year in self.split_years)])
        return path_files, ordered_years

    """ Temporal properties """

    def nb_years(self):
        return len(self.ordered_years)

    def ordered_years(self):
        return self.ordered_years_and_path_files[1]

    def start_year_and_stop_year(self) -> Tuple[int, int]:
        ordered_years = self.ordered_years
        return ordered_years[0], ordered_years[-1]

    """ Spatial properties """

    def study_massif_names(self) -> List[str]:
        # Massif names that are present in the current study (i.e. for the current altitude)
        return self.altitude_to_massif_names[self.altitude]

    def df_latitude_longitude(self):
        """Return a DataFrame of massif coordinates, indexed by ``study_massif_names``.

        'Longitude' and 'Latitude' columns are read from the 'LON' and 'LAT'
        variables of the first study file, restricted with ``flat_mask``.
        """
        # The coordinates only need one file: open just the first one (the same
        # file year_to_dataset_ordered_dict lists first) and close it once read.
        path_files, _ = self.ordered_years_and_path_files
        with Dataset(path_files[0]) as dataset:
            longitude = np.array(dataset.variables['LON'])[self.flat_mask]
            latitude = np.array(dataset.variables['LAT'])[self.flat_mask]
        data = [longitude, latitude]
        df = pd.DataFrame(data=data, index=['Longitude', 'Latitude'], columns=self.study_massif_names).transpose()
        return df

    def column_mask(self):
        return self.allslopes_mask if self.has_orientation else self.flat_mask

    def allslopes_mask(self):
        """Boolean mask of the allslopes columns matching altitude, orientation and slope.

        Columns whose massif number is 30 are always excluded.
        """
        selected = (
            (np.array(ORDERED_ALLSLOPES_ALTITUDES) == self.altitude)
            & (np.array(ORDERED_ALLSLOPES_ORIENTATIONS) == self.orientation)
            & (np.array(ORDERED_ALLSLOPES_SLOPES) == self.slope)
        )
        # Drop the columns of massif number 30 (described in the original code as
        # "the 24th massif" - NOTE(review): confirm the numbering).
        excluded = np.array(ORDERED_ALLSLOPES_MASSIFNUM) == 30
        return selected & ~excluded

    def flat_mask(self):
        """Boolean mask of the flat-terrain columns at ``self.altitude`` for the current region.

        Raises:
            ValueError: if ``self.french_region`` is neither the Alps nor the Pyrenees.
        """
        if self.french_region is FrenchRegion.alps:
            altitude_mask = ZS_INT_MASK == self.altitude
        elif self.french_region is FrenchRegion.pyrenees:
            altitude_mask = ZS_INT_MASK_PYRENNES == self.altitude
        else:
            raise ValueError('{}'.format(self.french_region))
        # Sanity check: one selected column per massif available at this altitude.
        assert np.sum(altitude_mask) == len(self.altitude_to_massif_names[self.altitude])
        return altitude_mask

    """ Path properties """

    def variable_name(self):
        return self.variable_class.NAME + ' ({})'.format(self.variable_unit)

    def variable_unit(self):
        return self.variable_class.UNIT

    # Note: these settings are written as object attributes/methods for simplicity.

    """ Path properties """

    def data_path(self) -> str:
        """Root data folder (module-level ``DATA_PATH``)."""
        data_root = DATA_PATH
        return data_root

    def map_full_path(self) -> str:
        """NOTE(review): the body of this method is missing from this copy of the file,
        so it currently returns None. Restore the intended path (presumably built
        from ``self.data_path``) - TODO confirm.
        """
    def study_full_path(self) -> str:
        assert self.model_name in ['Safran', 'Crocus']
        study_folder = 'meteo' if self.model_name is 'Safran' else 'pro'
        return op.join(self.reanalysis_path, study_folder)

    def reanalysis_path(self):
        reanalysis_folder = self.reanalysisfolder
        assert len(reanalysis_folder) > 0, 'please specify the name of reanalysis folder in configuration.py'
        return op.join(self.data_path, reanalysis_folder)

    def reanalysisfolder(self):
        """Name of the reanalysis folder for the current region and orientation setting.

        Alps: allslopes folder with an orientation, flat folder otherwise.
        Pyrenees: flat folder only.

        Raises:
            ValueError: for any other region, or for the Pyrenees with an orientation.
        """
        if self.french_region is FrenchRegion.alps:
            if self.has_orientation:
                reanalysis_folder = ALPS_ALLSLOPES_FOLDER
            else:
                reanalysis_folder = ALPS_FLAT_FOLDER
        elif self.french_region is FrenchRegion.pyrenees and not self.has_orientation:
            reanalysis_folder = PYRENEES_FLAT_FOLDER
        else:
            raise ValueError(
                'french_region = {}, has_orientation = {}'.format(self.french_region, self.has_orientation))
        return reanalysis_folder

    def YEAR_MIN(self):
        return 1959

    def YEAR_MAX(self):
        nb_files_list = [len(os.listdir(op.join(self.reanalysis_path, model_name))) for model_name in ['meteo', 'pro']]
        nb_files_list = [n for n in nb_files_list if n > 0]
        assert len(nb_files_list) > 0, 'please download some files for {}'.format(self.reanalysisfolder)
        if len(nb_files_list) == 2:
            assert nb_files_list[0] == nb_files_list[1]
        return self.YEAR_MIN + nb_files_list[0]

    def dbf_filename(self) -> str:
        """Base name (without extension) of the massif ``.dbf`` metadata file for the region.

        NOTE(review): another ``dbf_filename`` is defined later in this class body
        and, being defined last, is the one Python keeps.

        Raises:
            ValueError: for a region other than the Alps or the Pyrenees.
        """
        if self.french_region is FrenchRegion.alps:
            return 'massifs_alpes'
        elif self.french_region is FrenchRegion.pyrenees:
            return 'massifs_pyrenees'
        else:
            raise ValueError('{}'.format(self.french_region))

    def has_orientation(self):
        return self.orientation is not None

    def season_name(self):
        """Text label of the study season, as produced by ``season_to_str``."""
        season = self.season
        return season_to_str(season)

    """  Spatial properties """

    def massif_name_to_massif_id(self):
        return {name: i for i, name in enumerate(self.study_massif_names)}

    def dbf_filename(self):
        """Base name (without extension) of the massif ``.dbf`` metadata file for the region.

        NOTE(review): this definition comes after an earlier ``dbf_filename`` in
        the class body, so it is the one Python keeps. Like that earlier
        definition, it handles both the Alps and the Pyrenees.

        Raises:
            NotImplementedError: for any other region.
        """
        if self.french_region is FrenchRegion.alps:
            return 'massifs_alpes'
        elif self.french_region is FrenchRegion.pyrenees:
            return 'massifs_pyrenees'
        else:
            raise NotImplementedError

    def all_massif_names(self) -> List[str]:
        """Names of all massifs of the region, read from the ``.dbf`` metadata file.

        For massif identification, the number of the ``massif_num`` variable
        corresponds to the ``num_opp`` attribute (used as the sort key for the
        Alps); for other reanalyses the ``massif_num`` attribute is used.
        Rows are sorted by that key before names are extracted.
        """
        if ALPS_FLAT_FOLDER in self.reanalysis_path or ALPS_ALLSLOPES_FOLDER in self.reanalysis_path:
            french_region = FrenchRegion.alps
            key = 'num_opp'
        else:
            french_region = FrenchRegion.pyrenees
            key = 'massif_num'

        metadata_path = op.join(self.data_path, 'metadata')
        dbf = Dbf5(op.join(metadata_path, '{}.dbf'.format(self.dbf_filename)))
        df = dbf.to_dataframe().copy()  # type: pd.DataFrame
        # Important: for the Alps and the Pyrenees, all data is ordered from the
        # smallest massif number to the largest, so names must follow that order.
        df.sort_values(by=key, inplace=True)
        all_massif_names = list(df['nom'])
        # Correct the spelling of one massif name in the metadata
        if french_region is FrenchRegion.alps:
            all_massif_names[all_massif_names.index('Beaufortin')] = 'Beaufortain'
        return all_massif_names

    def massif_name_to_altitudes(self) -> Dict[str, List[int]]:
        """Map each massif name (in ``all_massif_names`` order) to its list of altitudes.

        The region's flat altitude list (``ZS_INT_23`` for the Alps,
        ``ZS_INT_MASK_PYRENNES_LIST`` otherwise) is cut into consecutive groups.
        A new group starts whenever an altitude is lower than the previous one,
        so each massif is assumed to list its altitudes in increasing order
        (TODO confirm against the constants' definition). A trailing 0 is
        appended so that the last group is closed too.
        """
        zs = ZS_INT_23 if self.french_region is FrenchRegion.alps else ZS_INT_MASK_PYRENNES_LIST
        s = zs + [0]
        zs_list = []
        zs_all_list = []
        for a, b in zip(s[:-1], s[1:]):
            zs_list.append(a)
            if a > b:
                # Altitude drops: the current massif's altitudes are complete
                zs_all_list.append(zs_list)
                zs_list = []
        all_massif_names = self.all_massif_names
        return OrderedDict(zip(all_massif_names, zs_all_list))

    def altitude_to_massif_names(self) -> Dict[int, List[str]]:
        """Map every altitude of ``ALTITUDES`` to the massifs having data at that altitude.

        Raises KeyError if a massif lists an altitude missing from ``ALTITUDES``.
        """
        altitude_to_massif_names = {altitude: [] for altitude in ALTITUDES}
        for massif_name in self.massif_name_to_altitudes.keys():
            for altitude in self.massif_name_to_altitudes[massif_name]:
                altitude_to_massif_names[altitude].append(massif_name)
        # massif_names are ordered in the same way as all_massif_names
        return altitude_to_massif_names

    def csv_file(self):
        """Name of the massif CSV file; only defined for the Alps.

        Raises:
            NotImplementedError: for any region other than the Alps.
        """
        if self.french_region is FrenchRegion.alps:
            return 'massifsalpes.csv'
        else:
            raise NotImplementedError