diff --git a/decloud/models/meraner_unet_all_bands.py b/decloud/models/meraner_unet_all_bands.py index 206f16a4ff54f4a2101c4fa8cf8a20102b89f824..76e2389ea0033188bf7d7733af0b80a29b79b515 100644 --- a/decloud/models/meraner_unet_all_bands.py +++ b/decloud/models/meraner_unet_all_bands.py @@ -31,12 +31,25 @@ class meraner_unet_all_bands(Model): """ Implementation of a variant of the Meraner et al. network (all bands) """ - def __init__(self, dataset_shapes, - dataset_input_keys=["s1_t", "s2_t", "s2_20m_t", - constants.DEM_KEY], - model_output_keys=["s2_target"]): - super().__init__(dataset_input_keys=dataset_input_keys, model_output_keys=model_output_keys, - dataset_shapes=dataset_shapes) + def __init__( + self, + dataset_shapes, + dataset_input_keys=[ + "s1_t", + "s2_t", + "s2_20m_t", + constants.DEM_KEY + ], + model_output_keys=[ + "s2_target", + "s2_20m_target" + ] + ): + super().__init__( + dataset_input_keys=dataset_input_keys, + model_output_keys=model_output_keys, + dataset_shapes=dataset_shapes + ) def get_outputs(self, normalized_inputs): # The network @@ -105,4 +118,8 @@ class meraner_unet_all_bands(Model): s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out) s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1) - return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands} + return { + "s2_target": s2_out, + "s2_20m_target": s2_20m_out, + "s2_all_bands_estim": s2_all_bands + }