Commit 3c346ef1 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[ABSTRACT STUDY] add more refactoring to abstract study

parent 772ba818
No related merge requests found
Showing with 24 additions and 18 deletions
+24 -18
...@@ -35,6 +35,11 @@ with redirect_stdout(f): ...@@ -35,6 +35,11 @@ with redirect_stdout(f):
class AbstractStudy(object): class AbstractStudy(object):
""" """
A Study is defined by:
- a variable class that corresponds to the meteorological quantity of interest
- an altitude of interest
- a start and an end year
Les fichiers netcdf de SAFRAN et CROCUS sont autodocumentés (on peut les comprendre avec ncdump -h notamment). Les fichiers netcdf de SAFRAN et CROCUS sont autodocumentés (on peut les comprendre avec ncdump -h notamment).
""" """
REANALYSIS_FOLDER = 'alp_flat/reanalysis' REANALYSIS_FOLDER = 'alp_flat/reanalysis'
...@@ -96,13 +101,10 @@ class AbstractStudy(object): ...@@ -96,13 +101,10 @@ class AbstractStudy(object):
@property @property
def _year_to_daily_time_serie_array(self) -> OrderedDict: def _year_to_daily_time_serie_array(self) -> OrderedDict:
# Map each year to a matrix of size 365-nb_days_consecutive+1 x nb_massifs # Map each year to a matrix of size 365-nb_days_consecutive+1 x nb_massifs
variables = [self.instantiate_variable_object(variable_array) for variable_array in
self.year_to_variable_array.values()]
year_to_variable = dict(zip(self.ordered_years, variables))
year_to_daily_time_serie_array = OrderedDict() year_to_daily_time_serie_array = OrderedDict()
for year in self.ordered_years: for year in self.ordered_years:
# Check daily data # Check daily data
daily_time_serie = year_to_variable[year].daily_time_serie_array daily_time_serie = self.year_to_variable_object[year].daily_time_serie_array
assert daily_time_serie.shape[0] in [365, 366] assert daily_time_serie.shape[0] in [365, 366]
assert daily_time_serie.shape[1] == len(ZS_INT_MASK) assert daily_time_serie.shape[1] == len(ZS_INT_MASK)
# Filter only the data corresponding to the altitude of interest # Filter only the data corresponding to the altitude of interest
...@@ -110,29 +112,31 @@ class AbstractStudy(object): ...@@ -110,29 +112,31 @@ class AbstractStudy(object):
year_to_daily_time_serie_array[year] = daily_time_serie year_to_daily_time_serie_array[year] = daily_time_serie
return year_to_daily_time_serie_array return year_to_daily_time_serie_array
def instantiate_variable_object(self, variable_array) -> AbstractVariable:
return self.variable_class(variable_array)
""" Load Variables and Datasets """ """ Load Variables and Datasets """
@cached_property @cached_property
def year_to_variable_array(self) -> OrderedDict: def year_to_variable_object(self) -> OrderedDict:
# Map each year to the variable array # Map each year to the variable array
path_files, ordered_years = self.ordered_years_and_path_files path_files, ordered_years = self.ordered_years_and_path_files
if self.multiprocessing: if self.multiprocessing:
with Pool(NB_CORES) as p: with Pool(NB_CORES) as p:
variables = p.map(self.load_variables, path_files) variables = p.map(self.load_variable_object, path_files)
else: else:
variables = [self.load_variables(path_file) for path_file in path_files] variables = [self.load_variable_object(path_file) for path_file in path_files]
return OrderedDict(zip(ordered_years, variables)) return OrderedDict(zip(ordered_years, variables))
def load_variables(self, path_file): def instantiate_variable_object(self, variable_array) -> AbstractVariable:
return self.variable_class(variable_array)
def load_variable_array(self, dataset):
return np.array(dataset.variables[self.load_keyword()])
def load_variable_object(self, path_file):
dataset = Dataset(path_file) dataset = Dataset(path_file)
keyword = self.load_keyword() variable_array = self.load_variable_array(dataset)
if isinstance(keyword, str): return self.instantiate_variable_object(variable_array)
return np.array(dataset.variables[keyword])
else:
return [np.array(dataset.variables[k]) for k in keyword]
def load_keyword(self): def load_keyword(self):
return self.variable_class.keyword() return self.variable_class.keyword()
......
...@@ -42,6 +42,9 @@ class SafranTotalPrecip(CumulatedStudy, Safran): ...@@ -42,6 +42,9 @@ class SafranTotalPrecip(CumulatedStudy, Safran):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(SafranTotalPrecipVariable, *args, **kwargs) super().__init__(SafranTotalPrecipVariable, *args, **kwargs)
def load_variable_array(self, dataset):
return [np.array(dataset.variables[k]) for k in self.load_keyword()]
def instantiate_variable_object(self, variable_array) -> AbstractVariable: def instantiate_variable_object(self, variable_array) -> AbstractVariable:
variable_array_snowfall, variable_array_rainfall = variable_array variable_array_snowfall, variable_array_rainfall = variable_array
return self.variable_class(variable_array_snowfall, variable_array_rainfall, self.nb_consecutive_days) return self.variable_class(variable_array_snowfall, variable_array_rainfall, self.nb_consecutive_days)
......
...@@ -54,7 +54,7 @@ def full_quantity_altitude_hypercube(): ...@@ -54,7 +54,7 @@ def full_quantity_altitude_hypercube():
def fast_altitude_hypercube(): def fast_altitude_hypercube():
save_to_file = False save_to_file = True
only_first_one = False only_first_one = False
fast = True fast = True
altitudes = ALL_ALTITUDES[2:4] altitudes = ALL_ALTITUDES[2:4]
......
...@@ -33,8 +33,7 @@ class TestSCMAllStudy(unittest.TestCase): ...@@ -33,8 +33,7 @@ class TestSCMAllStudy(unittest.TestCase):
altitudes=sample(set(ALL_ALTITUDES), k=nb_sample), nb_days=nb_days): altitudes=sample(set(ALL_ALTITUDES), k=nb_sample), nb_days=nb_days):
self.assertTrue('day' in study.variable_name) self.assertTrue('day' in study.variable_name)
first_path_file = study.ordered_years_and_path_files[0][0] first_path_file = study.ordered_years_and_path_files[0][0]
variable_array = study.load_variables(path_file=first_path_file) variable_object = study.load_variable_object(path_file=first_path_file)
variable_object = study.instantiate_variable_object(variable_array)
self.assertEqual((365, 263), variable_object.daily_time_serie_array.shape, self.assertEqual((365, 263), variable_object.daily_time_serie_array.shape,
msg='{} days for type {}'.format(nb_days, get_display_name_from_object_type(type(variable_object)))) msg='{} days for type {}'.format(nb_days, get_display_name_from_object_type(type(variable_object))))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment