Commit 937c09a0 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[contrasting] fix spatio temporal model with cross terms

parent ee91dd99
No related merge requests found
Showing with 23 additions and 14 deletions
+23 -14
from extreme_fit.distribution.gev.gev_params import GevParams
from extreme_fit.model.margin_model.polynomial_margin_model.polynomial_margin_model import PolynomialMarginModel
from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_polynomial_model import \
AbstractSpatioTemporalPolynomialModel
......@@ -6,7 +7,11 @@ from extreme_fit.model.margin_model.polynomial_margin_model.spatio_temporal_poly
class AbstractAltitudinalModel(AbstractSpatioTemporalPolynomialModel):
def load_margin_function(self, param_name_to_dims=None):
return super().load_margin_function(self.param_name_to_list_dim_and_degree)
return super().load_margin_function(self.param_name_to_list_dim_and_degree_for_margin_function)
@property
def param_name_to_list_dim_and_degree_for_margin_function(self):
return self.param_name_to_list_dim_and_degree
@property
def param_name_to_list_dim_and_degree(self):
......@@ -67,10 +72,13 @@ class NonStationaryAltitudinalLocationQuadraticScaleLinear(AbstractAltitudinalMo
class AbstractAddCrossTermForLocation(AbstractAltitudinalModel):
def load_margin_function(self, param_name_to_dims=None):
@property
def param_name_to_list_dim_and_degree_for_margin_function(self):
d = self.param_name_to_list_dim_and_degree
d[GevParams.LOC] += ((self.coordinates.idx_x_coordinates, self.coordinates.idx_temporal_coordinates), 1)
return super().load_margin_function(d)
assert 1 <= len(d[GevParams.LOC]) <= 2
assert self.coordinates.idx_x_coordinates == d[GevParams.LOC][0][0]
d[GevParams.LOC].insert(1, ((self.coordinates.idx_x_coordinates, self.coordinates.idx_temporal_coordinates), 1))
return d
class NonStationaryCrossTermForLocation(AbstractAddCrossTermForLocation, StationaryAltitudinal):
......
......@@ -8,23 +8,23 @@ from projects.altitude_spatial_model.altitudes_fit.one_fold_analysis.altitudes_s
AltitudesStudiesVisualizerForNonStationaryModels
def plot_altitudinal_fit(studies):
def plot_altitudinal_fit(studies, massif_names=None):
visualizer = AltitudesStudiesVisualizerForNonStationaryModels(studies=studies,
model_classes=ALTITUDINAL_MODELS,
massif_names=['Belledonne'],
show=True)
massif_names=massif_names,
show=False)
visualizer.plot_mean()
visualizer.plot_relative_change()
def plot_time_series(studies):
studies.plot_maxima_time_series()
def plot_time_series(studies, massif_names=None):
studies.plot_maxima_time_series(massif_names=massif_names)
def plot_moments(studies):
def plot_moments(studies, massif_names=None):
for std in [True, False][1:]:
for change in [True, False, None]:
studies.plot_mean_maxima_against_altitude(std=std, change=change)
studies.plot_mean_maxima_against_altitude(massif_names=massif_names, std=std, change=change)
def main():
......@@ -34,12 +34,13 @@ def main():
study_classes = [SafranPrecipitation1Day, SafranPrecipitation3Days, SafranPrecipitation5Days,
SafranPrecipitation7Days][:]
study_classes = [SafranPrecipitation1Day, SafranSnowfall1Day, SafranSnowfall3Days, SafranPrecipitation3Days][:1]
massif_names = ['Belledonne']
for study_class in study_classes:
studies = AltitudesStudies(study_class, altitudes, season=Season.winter_extended)
# plot_time_series(studies)
# plot_moments(studies)
plot_altitudinal_fit(studies)
plot_time_series(studies, massif_names)
plot_moments(studies, massif_names)
plot_altitudinal_fit(studies, massif_names)
if __name__ == '__main__':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment