Commit 70a70fee authored by Cresson Remi's avatar Cresson Remi
Browse files

FIX: replace keras.layers.Concatenate() with tf.concat

1 merge request!10Update Dockerfile
Pipeline #49335 failed with stages
in 93 minutes and 49 seconds
Showing with 49 additions and 36 deletions
+49 -36
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from tensorflow.keras import layers from tensorflow.keras import layers
import decloud.preprocessing.constants as constants import decloud.preprocessing.constants as constants
from decloud.models.crga_os1_base import crga_os1_base from decloud.models.crga_os1_base import crga_os1_base
from tensorflow import concat
class crga_os1_unet(crga_os1_base): class crga_os1_unet(crga_os1_base):
...@@ -57,14 +58,14 @@ class crga_os1_unet(crga_os1_base): ...@@ -57,14 +58,14 @@ class crga_os1_unet(crga_os1_base):
if input_image == "current": if input_image == "current":
net = conv1_s1(input_dict[input_image]) # 256 net = conv1_s1(input_dict[input_image]) # 256
else: else:
net = layers.concatenate(input_dict[input_image], axis=-1) net = concat(input_dict[input_image], axis=-1)
net = conv1_s1s2(net) # 256 net = conv1_s1s2(net) # 256
features[1].append(net) features[1].append(net)
net = conv2(net) # 128 net = conv2(net) # 128
if self.has_dem(): if self.has_dem():
net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY]) net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
net = layers.concatenate([net, net_dem], axis=-1) net = concat([net, net_dem], axis=-1)
features[2].append(net) features[2].append(net)
net = conv3(net) # 64 net = conv3(net) # 64
features[4].append(net) features[4].append(net)
...@@ -79,7 +80,7 @@ class crga_os1_unet(crga_os1_base): ...@@ -79,7 +80,7 @@ class crga_os1_unet(crga_os1_base):
def _combine(factor, x=None): def _combine(factor, x=None):
if x is not None: if x is not None:
features[factor].append(x) features[factor].append(x)
return layers.concatenate(features[factor], axis=-1) return concat(features[factor], axis=-1)
net = _combine(factor=32) net = _combine(factor=32)
net = deconv1(net) # 16 net = deconv1(net) # 16
......
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from tensorflow.keras import layers from tensorflow.keras import layers
import decloud.preprocessing.constants as constants import decloud.preprocessing.constants as constants
from decloud.models.crga_os1_base_all_bands import crga_os1_base_all_bands from decloud.models.crga_os1_base_all_bands import crga_os1_base_all_bands
from tensorflow import concat
class crga_os1_unet_all_bands(crga_os1_base_all_bands): class crga_os1_unet_all_bands(crga_os1_base_all_bands):
...@@ -69,7 +70,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands): ...@@ -69,7 +70,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
features[1].append(net_10m) features[1].append(net_10m)
net = conv2(net_10m) # 128 net = conv2(net_10m) # 128
else: # for post & ante, the is s1, s2 and s2_20m else: # for post & ante, the is s1, s2 and s2_20m
net_10m = layers.concatenate(input_dict[input_image][:2], axis=-1) net_10m = concat(input_dict[input_image][:2], axis=-1)
net_10m = conv1_s1s2(net_10m) # 256 net_10m = conv1_s1s2(net_10m) # 256
features[1].append(net_10m) features[1].append(net_10m)
net_10m = conv2(net_10m) # 128 net_10m = conv2(net_10m) # 128
...@@ -77,7 +78,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands): ...@@ -77,7 +78,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
features_20m = [net_10m, net_20m] features_20m = [net_10m, net_20m]
if self.has_dem(): if self.has_dem():
features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY])) features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
net = layers.concatenate(features_20m, axis=-1) net = concat(features_20m, axis=-1)
net = conv2_20m(net) # 128 net = conv2_20m(net) # 128
features[2].append(net) features[2].append(net)
...@@ -94,7 +95,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands): ...@@ -94,7 +95,7 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
def _combine(factor, x=None): def _combine(factor, x=None):
if x is not None: if x is not None:
features[factor].append(x) features[factor].append(x)
return layers.concatenate(features[factor], axis=-1) return concat(features[factor], axis=-1)
net = _combine(factor=32) net = _combine(factor=32)
net = deconv1(net) # 16 net = deconv1(net) # 16
...@@ -114,6 +115,6 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands): ...@@ -114,6 +115,6 @@ class crga_os1_unet_all_bands(crga_os1_base_all_bands):
# 10m-resampled stack that will be the output for inference (not used for training) # 10m-resampled stack that will be the output for inference (not used for training)
s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out) s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1) s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
return {"s2_t": s2_out, "s2_20m_t": s2_20m_out, 's2_all_bands_estim': s2_all_bands} return {"s2_t": s2_out, "s2_20m_t": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from tensorflow.keras import layers from tensorflow.keras import layers
from decloud.models.crga_os2_base import crga_os2_base from decloud.models.crga_os2_base import crga_os2_base
import decloud.preprocessing.constants as constants import decloud.preprocessing.constants as constants
from tensorflow import concat
class crga_os2_david(crga_os2_base): class crga_os2_david(crga_os2_base):
...@@ -50,16 +51,16 @@ class crga_os2_david(crga_os2_base): ...@@ -50,16 +51,16 @@ class crga_os2_david(crga_os2_base):
deconv2 = layers.Conv2DTranspose(64, 3, 2, activation='relu', name="deconv2_bn_relu", padding="same") deconv2 = layers.Conv2DTranspose(64, 3, 2, activation='relu', name="deconv2_bn_relu", padding="same")
conv4 = layers.Conv2D(4, 5, 1, activation='relu', name="s2_estim", padding="same") conv4 = layers.Conv2D(4, 5, 1, activation='relu', name="s2_estim", padding="same")
for input_image in input_dict: for input_image in input_dict:
net = layers.concatenate(input_dict[input_image], axis=-1) net = concat(input_dict[input_image], axis=-1)
net = conv1(net) # 256 net = conv1(net) # 256
net = conv2(net) # 128 net = conv2(net) # 128
if self.has_dem(): if self.has_dem():
net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY]) net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
net = layers.concatenate([net, net_dem], axis=-1) net = concat([net, net_dem], axis=-1)
net = conv3(net) # 64 net = conv3(net) # 64
features.append(net) features.append(net)
net = layers.concatenate(features, axis=-1) net = concat(features, axis=-1)
net = deconv1(net) # 128 net = deconv1(net) # 128
net = deconv2(net) # 256 net = deconv2(net) # 256
s2_out = conv4(net) # 256 s2_out = conv4(net) # 256
......
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from tensorflow.keras import layers from tensorflow.keras import layers
from decloud.models.crga_os2_base_all_bands import crga_os2_base_all_bands from decloud.models.crga_os2_base_all_bands import crga_os2_base_all_bands
import decloud.preprocessing.constants as constants import decloud.preprocessing.constants as constants
from tensorflow import concat
class crga_os2_david_all_bands(crga_os2_base_all_bands): class crga_os2_david_all_bands(crga_os2_base_all_bands):
...@@ -56,7 +57,7 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands): ...@@ -56,7 +57,7 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands):
conv4 = layers.Conv2D(4, 5, 1, activation='relu', name="s2_estim", padding="same") conv4 = layers.Conv2D(4, 5, 1, activation='relu', name="s2_estim", padding="same")
conv4_20m = layers.Conv2D(6, 3, 1, activation='relu', name="s2_20m_estim", padding="same") conv4_20m = layers.Conv2D(6, 3, 1, activation='relu', name="s2_20m_estim", padding="same")
for input_image in input_dict: for input_image in input_dict:
net_10m = layers.concatenate(input_dict[input_image][:2], axis=-1) net_10m = concat(input_dict[input_image][:2], axis=-1)
net_10m = conv1(net_10m) # 256 net_10m = conv1(net_10m) # 256
net_10m = conv2(net_10m) # 128 net_10m = conv2(net_10m) # 128
net_20m = conv1_20m(input_dict[input_image][2]) # 128 net_20m = conv1_20m(input_dict[input_image][2]) # 128
...@@ -64,11 +65,11 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands): ...@@ -64,11 +65,11 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands):
features_20m = [net_10m, net_20m] features_20m = [net_10m, net_20m]
if self.has_dem(): if self.has_dem():
features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY])) features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
net = layers.concatenate(features_20m, axis=-1) # 128 net = concat(features_20m, axis=-1) # 128
net = conv3(net) # 64 net = conv3(net) # 64
features.append(net) features.append(net)
net = layers.concatenate(features, axis=-1) net = concat(features, axis=-1)
net = deconv1(net) # 128 net = deconv1(net) # 128
net_10m = deconv2(net) # 256 net_10m = deconv2(net) # 256
net_20m = deconv2_20m(net) # 128 net_20m = deconv2_20m(net) # 128
...@@ -78,6 +79,6 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands): ...@@ -78,6 +79,6 @@ class crga_os2_david_all_bands(crga_os2_base_all_bands):
# 10m-resampled stack that will be the output for inference (not used for training) # 10m-resampled stack that will be the output for inference (not used for training)
s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out) s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1) s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands} return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from tensorflow.keras import layers, initializers from tensorflow.keras import layers, initializers
from decloud.models.crga_os2_base import crga_os2_base from decloud.models.crga_os2_base import crga_os2_base
import decloud.preprocessing.constants as constants import decloud.preprocessing.constants as constants
from tensorflow import concat
class crga_os2_unet(crga_os2_base): class crga_os2_unet(crga_os2_base):
...@@ -67,13 +68,13 @@ class crga_os2_unet(crga_os2_base): ...@@ -67,13 +68,13 @@ class crga_os2_unet(crga_os2_base):
kernel_initializer=initializers.VarianceScaling()) kernel_initializer=initializers.VarianceScaling())
for input_image in input_dict: for input_image in input_dict:
net = layers.concatenate(input_dict[input_image], axis=-1) net = concat(input_dict[input_image], axis=-1)
net = conv1(net) # 256 net = conv1(net) # 256
features[1].append(net) features[1].append(net)
net = conv2(net) # 128 net = conv2(net) # 128
if self.has_dem(): if self.has_dem():
net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY]) net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
net = layers.concatenate([net, net_dem], axis=-1) net = concat([net, net_dem], axis=-1)
features[2].append(net) features[2].append(net)
net = conv3(net) # 64 net = conv3(net) # 64
features[4].append(net) features[4].append(net)
...@@ -88,7 +89,7 @@ class crga_os2_unet(crga_os2_base): ...@@ -88,7 +89,7 @@ class crga_os2_unet(crga_os2_base):
def _combine(factor, x=None): def _combine(factor, x=None):
if x is not None: if x is not None:
features[factor].append(x) features[factor].append(x)
return layers.concatenate(features[factor], axis=-1) return concat(features[factor], axis=-1)
net = _combine(factor=32) net = _combine(factor=32)
net = deconv1(net) # 16 net = deconv1(net) # 16
......
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from tensorflow.keras import layers from tensorflow.keras import layers
from decloud.models.crga_os2_base_all_bands import crga_os2_base_all_bands from decloud.models.crga_os2_base_all_bands import crga_os2_base_all_bands
import decloud.preprocessing.constants as constants import decloud.preprocessing.constants as constants
from tensorflow import concat
class crga_os2_unet_all_bands(crga_os2_base_all_bands): class crga_os2_unet_all_bands(crga_os2_base_all_bands):
...@@ -65,7 +66,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands): ...@@ -65,7 +66,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
# The network # The network
features = {factor: [] for factor in [1, 2, 4, 8, 16, 32]} features = {factor: [] for factor in [1, 2, 4, 8, 16, 32]}
for input_image in input_dict: for input_image in input_dict:
net_10m = layers.concatenate(input_dict[input_image][:2], axis=-1) net_10m = concat(input_dict[input_image][:2], axis=-1)
net_10m = conv1(net_10m) # 256 net_10m = conv1(net_10m) # 256
features[1].append(net_10m) features[1].append(net_10m)
net_10m = conv2(net_10m) # 128 net_10m = conv2(net_10m) # 128
...@@ -74,7 +75,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands): ...@@ -74,7 +75,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
features_20m = [net_10m, net_20m] features_20m = [net_10m, net_20m]
if self.has_dem(): if self.has_dem():
features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY])) features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
net = layers.concatenate(features_20m, axis=-1) # 128 net = concat(features_20m, axis=-1) # 128
features[2].append(net) features[2].append(net)
net = conv3(net) # 64 net = conv3(net) # 64
features[4].append(net) features[4].append(net)
...@@ -89,7 +90,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands): ...@@ -89,7 +90,7 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
def _combine(factor, x=None): def _combine(factor, x=None):
if x is not None: if x is not None:
features[factor].append(x) features[factor].append(x)
return layers.concatenate(features[factor], axis=-1) return concat(features[factor], axis=-1)
net = _combine(factor=32) net = _combine(factor=32)
net = deconv1(net) # 16 net = deconv1(net) # 16
...@@ -109,6 +110,6 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands): ...@@ -109,6 +110,6 @@ class crga_os2_unet_all_bands(crga_os2_base_all_bands):
# 10m-resampled stack that will be the output for inference (not used for training) # 10m-resampled stack that will be the output for inference (not used for training)
s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out) s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1) s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands} return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
...@@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""Implementation of the Meraner et al. original network""" """Implementation of the Meraner et al. original network"""
from tensorflow.keras import layers from tensorflow.keras import layers
from decloud.models.model import Model from decloud.models.model import Model
from tensorflow import concat
class meraner_original(Model): class meraner_original(Model):
...@@ -50,7 +51,7 @@ class meraner_original(Model): ...@@ -50,7 +51,7 @@ class meraner_original(Model):
# The network # The network
conv1 = layers.Conv2D(resblocks_dim, 3, 1, activation='relu', name="conv1_relu", padding="same") conv1 = layers.Conv2D(resblocks_dim, 3, 1, activation='relu', name="conv1_relu", padding="same")
net = layers.concatenate([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1) net = concat([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
net = conv1(net) net = conv1(net)
for i in range(n_resblocks): for i in range(n_resblocks):
net = _resblock(net, i) net = _resblock(net, i)
......
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from tensorflow.keras import layers from tensorflow.keras import layers
from decloud.models.model import Model from decloud.models.model import Model
import decloud.preprocessing.constants as constants import decloud.preprocessing.constants as constants
from tensorflow import concat
class meraner_unet(Model): class meraner_unet(Model):
...@@ -56,13 +57,13 @@ class meraner_unet(Model): ...@@ -56,13 +57,13 @@ class meraner_unet(Model):
deconv5 = layers.Conv2DTranspose(64, 3, 2, activation='relu', name="deconv5_bn_relu", padding="same") deconv5 = layers.Conv2DTranspose(64, 3, 2, activation='relu', name="deconv5_bn_relu", padding="same")
conv_final = layers.Conv2D(4, 5, 1, name="s2_estim", padding="same") conv_final = layers.Conv2D(4, 5, 1, name="s2_estim", padding="same")
net = layers.concatenate([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1) net = concat([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
net = conv1(net) # 256 net = conv1(net) # 256
features[1].append(net) features[1].append(net)
net = conv2(net) # 128 net = conv2(net) # 128
if self.has_dem(): if self.has_dem():
net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY]) net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
net = layers.concatenate([net, net_dem], axis=-1) net = concat([net, net_dem], axis=-1)
features[2].append(net) features[2].append(net)
net = conv3(net) # 64 net = conv3(net) # 64
features[4].append(net) features[4].append(net)
...@@ -76,7 +77,7 @@ class meraner_unet(Model): ...@@ -76,7 +77,7 @@ class meraner_unet(Model):
def _combine(factor, x=None): def _combine(factor, x=None):
if x is not None: if x is not None:
features[factor].append(x) features[factor].append(x)
return layers.concatenate(features[factor], axis=-1) return concat(features[factor], axis=-1)
net = deconv1(net) # 16 net = deconv1(net) # 16
net = _combine(factor=16, x=net) net = _combine(factor=16, x=net)
......
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
import decloud.preprocessing.constants as constants import decloud.preprocessing.constants as constants
from decloud.models.model import Model from decloud.models.model import Model
from tensorflow.keras import layers from tensorflow.keras import layers
from tensorflow import concat
class meraner_unet_all_bands(Model): class meraner_unet_all_bands(Model):
...@@ -58,7 +59,7 @@ class meraner_unet_all_bands(Model): ...@@ -58,7 +59,7 @@ class meraner_unet_all_bands(Model):
conv_final = layers.Conv2D(4, 5, 1, name="s2_estim", padding="same") conv_final = layers.Conv2D(4, 5, 1, name="s2_estim", padding="same")
conv_20m_final = layers.Conv2D(6, 3, 1, name="s2_20m_estim", padding="same") conv_20m_final = layers.Conv2D(6, 3, 1, name="s2_20m_estim", padding="same")
net_10m = layers.concatenate([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1) net_10m = concat([normalized_inputs["s1_t"], normalized_inputs["s2_t"]], axis=-1)
net_10m = conv1(net_10m) # 256 net_10m = conv1(net_10m) # 256
features[1].append(net_10m) features[1].append(net_10m)
net_10m = conv2(net_10m) # 128 net_10m = conv2(net_10m) # 128
...@@ -67,7 +68,7 @@ class meraner_unet_all_bands(Model): ...@@ -67,7 +68,7 @@ class meraner_unet_all_bands(Model):
features_20m = [net_10m, net_20m] features_20m = [net_10m, net_20m]
if self.has_dem(): if self.has_dem():
features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY])) features_20m.append(conv1_dem(normalized_inputs[constants.DEM_KEY]))
net = layers.concatenate(features_20m, axis=-1) # 128 net = concat(features_20m, axis=-1) # 128
features[2].append(net) features[2].append(net)
net = conv3(net) # 64 net = conv3(net) # 64
...@@ -83,7 +84,7 @@ class meraner_unet_all_bands(Model): ...@@ -83,7 +84,7 @@ class meraner_unet_all_bands(Model):
def _combine(factor, x=None): def _combine(factor, x=None):
if x is not None: if x is not None:
features[factor].append(x) features[factor].append(x)
return layers.concatenate(features[factor], axis=-1) return concat(features[factor], axis=-1)
net = _combine(factor=32) net = _combine(factor=32)
net = deconv1(net) # 16 net = deconv1(net) # 16
...@@ -102,6 +103,6 @@ class meraner_unet_all_bands(Model): ...@@ -102,6 +103,6 @@ class meraner_unet_all_bands(Model):
# 10m-resampled stack that will be the output for inference (not used for training) # 10m-resampled stack that will be the output for inference (not used for training)
s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out) s2_20m_resampled = layers.UpSampling2D(size=(2, 2))(s2_20m_out)
s2_all_bands = layers.concatenate([s2_out, s2_20m_resampled], axis=-1) s2_all_bands = concat([s2_out, s2_20m_resampled], axis=-1)
return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands} return {"s2_target": s2_out, "s2_20m_target": s2_20m_out, 's2_all_bands_estim': s2_all_bands}
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from tensorflow.keras import layers from tensorflow.keras import layers
from decloud.models.model import Model from decloud.models.model import Model
import decloud.preprocessing.constants as constants import decloud.preprocessing.constants as constants
from tensorflow import concat
class monthly_synthesis_6_s2_images(Model): class monthly_synthesis_6_s2_images(Model):
...@@ -55,7 +56,7 @@ class monthly_synthesis_6_s2_images(Model): ...@@ -55,7 +56,7 @@ class monthly_synthesis_6_s2_images(Model):
features[1].append(net) features[1].append(net)
net = conv2(net) # 128 net = conv2(net) # 128
if self.has_dem(): if self.has_dem():
net = layers.concatenate([net, normalized_inputs[constants.DEM_KEY]], axis=-1) net = concat([net, normalized_inputs[constants.DEM_KEY]], axis=-1)
features[2].append(net) features[2].append(net)
net = conv3(net) # 64 net = conv3(net) # 64
features[4].append(net) features[4].append(net)
...@@ -70,7 +71,7 @@ class monthly_synthesis_6_s2_images(Model): ...@@ -70,7 +71,7 @@ class monthly_synthesis_6_s2_images(Model):
def _combine(factor, x=None): def _combine(factor, x=None):
if x is not None: if x is not None:
features[factor].append(x) features[factor].append(x)
return layers.concatenate(features[factor], axis=-1) return concat(features[factor], axis=-1)
net = _combine(factor=32) net = _combine(factor=32)
net = deconv1(net) # 16 net = deconv1(net) # 16
......
...@@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""David model implementation (monthly synthesis of 6 optical images)""" """David model implementation (monthly synthesis of 6 optical images)"""
from tensorflow.keras import layers from tensorflow.keras import layers
from decloud.models.model import Model from decloud.models.model import Model
from tensorflow import concat
class monthly_synthesis_6_s2_images_david(Model): class monthly_synthesis_6_s2_images_david(Model):
...@@ -48,7 +49,7 @@ class monthly_synthesis_6_s2_images_david(Model): ...@@ -48,7 +49,7 @@ class monthly_synthesis_6_s2_images_david(Model):
net = conv3(net) # 64 net = conv3(net) # 64
features.append(net) features.append(net)
net = layers.concatenate(features, axis=-1) net = concat(features, axis=-1)
net = deconv1(net) # 128 net = deconv1(net) # 128
net = deconv2(net) # 256 net = deconv2(net) # 256
s2_out = conv4(net) # 256 s2_out = conv4(net) # 256
......
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from tensorflow.keras import layers from tensorflow.keras import layers
from decloud.models.model import Model from decloud.models.model import Model
import decloud.preprocessing.constants as constants import decloud.preprocessing.constants as constants
from tensorflow import concat
class monthly_synthesis_6_s2s1_images(Model): class monthly_synthesis_6_s2s1_images(Model):
...@@ -64,7 +65,7 @@ class monthly_synthesis_6_s2s1_images(Model): ...@@ -64,7 +65,7 @@ class monthly_synthesis_6_s2s1_images(Model):
features[1].append(net) features[1].append(net)
net = conv2_s2(net) net = conv2_s2(net)
if self.has_dem(): if self.has_dem():
net = layers.concatenate([net, normalized_inputs[constants.DEM_KEY]], axis=-1) net = concat([net, normalized_inputs[constants.DEM_KEY]], axis=-1)
features[2].append(net) features[2].append(net)
net = conv3(net) # 64 net = conv3(net) # 64
features[4].append(net) features[4].append(net)
...@@ -79,7 +80,7 @@ class monthly_synthesis_6_s2s1_images(Model): ...@@ -79,7 +80,7 @@ class monthly_synthesis_6_s2s1_images(Model):
def _combine(factor, x=None): def _combine(factor, x=None):
if x is not None: if x is not None:
features[factor].append(x) features[factor].append(x)
return layers.concatenate(features[factor], axis=-1) return concat(features[factor], axis=-1)
net = _combine(factor=32) net = _combine(factor=32)
net = deconv1(net) # 16 net = deconv1(net) # 16
......
...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. ...@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from tensorflow.keras import layers from tensorflow.keras import layers
from decloud.models.model import Model from decloud.models.model import Model
from decloud.preprocessing import constants from decloud.preprocessing import constants
from tensorflow import concat
class monthly_synthesis_6_s2s1_images_david(Model): class monthly_synthesis_6_s2s1_images_david(Model):
...@@ -56,11 +57,11 @@ class monthly_synthesis_6_s2s1_images_david(Model): ...@@ -56,11 +57,11 @@ class monthly_synthesis_6_s2s1_images_david(Model):
net = conv2(net) # 128 net = conv2(net) # 128
if self.has_dem(): if self.has_dem():
net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY]) net_dem = conv1_dem(normalized_inputs[constants.DEM_KEY])
net = layers.concatenate([net, net_dem], axis=-1) net = concat([net, net_dem], axis=-1)
net = conv3(net) # 64 net = conv3(net) # 64
features.append(net) features.append(net)
net = layers.concatenate(features, axis=-1) net = concat(features, axis=-1)
net = deconv1(net) # 128 net = deconv1(net) # 128
net = deconv2(net) # 256 net = deconv2(net) # 256
s2_out = conv4(net) # 256 s2_out = conv4(net) # 256
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment