Commit 8f4e79a4 authored by Ienco Dino's avatar Ienco Dino
Browse files

Update BaseModels.py

parent 511eefe8
No related merge requests found
Showing with 19 additions and 7 deletions
+19 -7
...@@ -26,19 +26,19 @@ class CNN1D(tf.keras.Model): ...@@ -26,19 +26,19 @@ class CNN1D(tf.keras.Model):
@tf.function @tf.function
def call(self, inputs, training=False): def call(self, inputs, training=False):
conv1 = self.conv1(inputs) conv1 = self.conv1(inputs)
conv1 = self.bn1(conv1) #conv1 = self.bn1(conv1)#, training=training)
conv1 = self.do1(conv1, training=training) conv1 = self.do1(conv1, training=training)
conv2 = self.conv2(conv1) conv2 = self.conv2(conv1)
conv2 = self.bn2(conv2) #conv2 = self.bn2(conv2)#, training=training)
conv2 = self.do2(conv2, training=training) conv2 = self.do2(conv2, training=training)
conv3 = self.conv3(conv2) conv3 = self.conv3(conv2)
conv3 = self.bn3(conv3) #conv3 = self.bn3(conv3)#, training=training)
conv3 = self.do3(conv3, training=training) conv3 = self.do3(conv3, training=training)
conv4 = self.conv4(conv3) conv4 = self.conv4(conv3)
conv4 = self.bn4(conv4) #conv4 = self.bn4(conv4)#, training=training)
conv4 = self.do4(conv4, training=training) conv4 = self.do4(conv4, training=training)
pool = self.pool(conv4) pool = self.pool(conv4)
...@@ -51,17 +51,29 @@ class TwoBranchModel(tf.keras.Model): ...@@ -51,17 +51,29 @@ class TwoBranchModel(tf.keras.Model):
self.PixelBranch = CNN1D(n_filters, suffix, dropout_rate=dropout_rate) self.PixelBranch = CNN1D(n_filters, suffix, dropout_rate=dropout_rate)
self.ObjBranch = CNN1D(n_filters, suffix, dropout_rate=dropout_rate) self.ObjBranch = CNN1D(n_filters, suffix, dropout_rate=dropout_rate)
self.dense1 = tfk.layers.Dense(512, activation='relu') self.dense1 = tfk.layers.Dense(256, activation='relu')
self.dense2 = tfk.layers.Dense(512, activation='relu') self.drop1 = tfk.layers.Dropout(rate=dropout_rate)
self.bn1 = tfk.layers.BatchNormalization()
self.dense2 = tfk.layers.Dense(256, activation='relu')
self.classif = tfk.layers.Dense(nb_classes, activation='softmax') self.classif = tfk.layers.Dense(nb_classes, activation='softmax')
self.classifPix = tfk.layers.Dense(nb_classes, activation='softmax')
self.classifObj = tfk.layers.Dense(nb_classes, activation='softmax')
@tf.function @tf.function
def call(self, inputs, training=False): def call(self, inputs, training=False):
pixel_inputs, obj_inputs = inputs pixel_inputs, obj_inputs = inputs
branchP = self.PixelBranch(pixel_inputs, training=training) branchP = self.PixelBranch(pixel_inputs, training=training)
branchO = self.ObjBranch(obj_inputs, training=training) branchO = self.ObjBranch(obj_inputs, training=training)
classifP = self.classifPix(branchP)
classifO = self.classifObj(branchO)
#feat = branchP + branchO
feat = tf.concat([branchP, branchO], axis=1) feat = tf.concat([branchP, branchO], axis=1)
output = self.dense1(feat) output = self.dense1(feat)
output = self.bn1(output, training=training)
output = self.drop1(output, training=training)
output = self.dense2(output) output = self.dense2(output)
classif = self.classif(output) classif = self.classif(output)
return classif return classif, classifP, classifO
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment