Commit 04f81275 authored by Le Roux Erwan's avatar Le Roux Erwan
Browse files

[refactor] improve test coverage for dataset.

parent 31ea6f83
No related merge requests found
Showing with 2 additions and 55 deletions
+2 -55
......@@ -43,7 +43,7 @@ class AbstractEstimator(object):
@property
def train_split(self):
return self.dataset.train_split
return self.dataset.slicer.train_split
......
......@@ -158,9 +158,6 @@ class AbstractCoordinates(object):
def coordinates_values(self, split: Split = Split.all, transformed=True) -> np.ndarray:
return self.df_coordinates(split, transformed=transformed).values
def coordinates_index(self, split: Split = Split.all) -> pd.Index:
return self.df_coordinates(split).index
@property
def ind_train_spatial(self) -> pd.Series:
return ind_train_from_s_split(s_split=self.s_split_spatial)
......
......@@ -19,21 +19,6 @@ class AbstractDataset(object):
assert pd.Index.equals(observations.index, coordinates.index), '\n{}\n{}'.format(observations.index, coordinates.index)
self.observations = observations # type: AbstractSpatioTemporalObservations
self.coordinates = coordinates # type: AbstractCoordinates
self.subset_id_to_column_idxs = None # type: Dict[int, List[int]]
@classmethod
def from_csv(cls, csv_path: str):
assert op.exists(csv_path)
df = pd.read_csv(csv_path, index_col=0)
coordinates = AbstractCoordinates.from_df(df)
temporal_maxima = AbstractSpatioTemporalObservations.from_df(df)
return cls(temporal_maxima, coordinates)
def to_csv(self, csv_path: str):
dirname = op.dirname(csv_path)
if not op.exists(dirname):
os.makedirs(dirname)
self.df_dataset.to_csv(csv_path)
@property
def df_dataset(self) -> pd.DataFrame:
......@@ -83,49 +68,13 @@ class AbstractDataset(object):
def coordinates_values(self, split: Split = Split.all) -> np.ndarray:
return self.coordinates.coordinates_values(split=split)
def coordinates_index(self, split: Split = Split.all) -> pd.Index:
return self.coordinates.coordinates_index(split=split)
# Slicer wrapper
@property
def slicer(self) -> AbstractSlicer:
return self.coordinates.slicer
@property
def train_split(self) -> Split:
return self.slicer.train_split
@property
def test_split(self) -> Split:
return self.slicer.test_split
@property
def splits(self) -> List[Split]:
return self.slicer.splits
# Dataset subsets
def create_subsets(self, nb_subsets):
self.subset_id_to_column_idxs = {}
for subset_id in range(nb_subsets):
column_idxs = [idx for idx in range(self.observations.nb_obs) if idx % nb_subsets == subset_id]
self.subset_id_to_column_idxs[subset_id] = column_idxs
# Special methods
def __str__(self) -> str:
return 'coordinates:\n{}\nobservations:\n{}'.format(self.coordinates.__str__(), self.observations.__str__())
def get_subset_dataset(dataset: AbstractDataset, subset_id) -> AbstractDataset:
columns_idxs = dataset.subset_id_to_column_idxs[subset_id]
assert dataset.subset_id_to_column_idxs is not None, 'You need to create subsets'
assert subset_id in dataset.subset_id_to_column_idxs.keys()
subset_dataset = copy.deepcopy(dataset)
observations = subset_dataset.observations
if observations.df_maxima_gev is not None:
observations.df_maxima_gev = observations.df_maxima_gev.iloc[:, columns_idxs]
if observations.df_maxima_frech is not None:
observations.df_maxima_frech = observations.df_maxima_frech.iloc[:, columns_idxs]
return subset_dataset
......@@ -63,6 +63,7 @@ class TestSpatioTemporalDataset(unittest.TestCase):
self.dataset = MarginDataset.from_sampling(nb_obs=nb_obs,
margin_model=smooth_margin_model,
coordinates=self.coordinates)
print(self.dataset.__str__())
def test_spatio_temporal_array_wrt_time(self):
# The test could have been on a given station. But we decided to do it for a given time step.
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment