Commit e9a5abea authored by Remi Cresson's avatar Remi Cresson
Browse files

FIX: bug in meraner unet all bands (no gradient for 20m bands)

1 merge request!10Update Dockerfile
Showing with 24 additions and 7 deletions
+24 -7
......@@ -31,12 +31,25 @@ class meraner_unet_all_bands(Model):
"""
Implementation of a variant of the Meraner et al. network (all bands)
"""
def __init__(self, dataset_shapes,
dataset_input_keys=["s1_t", "s2_t", "s2_20m_t",
constants.DEM_KEY],
model_output_keys=["s2_target"]):
super().__init__(dataset_input_keys=dataset_input_keys, model_output_keys=model_output_keys,
dataset_shapes=dataset_shapes)
def __init__(
self,
dataset_shapes,
dataset_input_keys=[
"s1_t",
"s2_t",
"s2_20m_t",
constants.DEM_KEY
],
model_output_keys=[
"s2_target",
"s2_20m_target"
]
):
super().__init__(
dataset_input_keys=dataset_input_keys,
model_output_keys=model_output_keys,
dataset_shapes=dataset_shapes
)
def get_outputs(self, normalized_inputs):
# The network
......@@ -105,4 +118,8 @@ class meraner_unet_all_bands(Model):
s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
return {
"s2_target": s2_out,
"s2_20m_target": s2_20m_out,
"s2_all_bands_estim": s2_all_bands
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment