# Forked from HYCAR-Hydro / airGR.
# The source project has limited visibility.
# File: BaseModels.py (3.57 KiB)
import tensorflow.keras as tfk
import tensorflow as tf

class CNN1D(tf.keras.Model):
    """Four-layer 1D convolutional encoder followed by global average pooling.

    Layer stack (every convolution uses ``padding='valid'``):

    * ``conv1``: ``n_filters`` filters, kernel size 3
    * ``conv2``: ``n_filters`` filters, kernel size 3
    * ``conv3``: ``2 * n_filters`` filters, kernel size 3
    * ``conv4``: ``2 * n_filters`` filters, kernel size 1

    Each convolution is followed by an optional BatchNormalization layer
    (off by default) and a Dropout layer. The three kernel-3 'valid'
    convolutions shorten the sequence axis by 6 steps in total, so inputs
    need at least 7 steps. The pooled output has ``2 * n_filters`` features.

    Args:
        n_filters: Filter count of ``conv1``/``conv2``; ``conv3``/``conv4``
            use twice as many.
        suffix: String appended to the names of the conv / batch-norm /
            dropout sub-layers (e.g. ``"conv1_" + suffix``).
        dropout_rate: Rate of the Dropout layer after each convolution.
        hidden_activation: Activation of every convolution. Defaults to
            ``'relu'``, which is the activation previously hard-coded here.
        name: Keras name of this model.
        use_batchnorm: Keyword-only. If True, apply ``bn1``..``bn4`` between
            each convolution and its dropout. Defaults to False, which matches
            the previous behaviour where those calls were commented out.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, n_filters, suffix, dropout_rate=0.0, hidden_activation='relu',
                 name='CNN1D', *, use_batchnorm=False, **kwargs):
        super(CNN1D, self).__init__(name=name, **kwargs)
        self.use_batchnorm = use_batchnorm

        # NOTE: the attribute names (conv1, bn1, do1, ...) are kept unchanged
        # on purpose: object-based TF checkpoints are keyed on them.
        self.conv1 = tfk.layers.Conv1D(filters=n_filters, kernel_size=3, padding='valid',
                                       name="conv1_" + suffix, activation=hidden_activation)
        self.bn1 = tfk.layers.BatchNormalization(name="bn1_" + suffix)
        self.do1 = tfk.layers.Dropout(rate=dropout_rate, name="dropOut1_" + suffix)

        self.conv2 = tfk.layers.Conv1D(filters=n_filters, kernel_size=3, padding='valid',
                                       name="conv2_" + suffix, activation=hidden_activation)
        self.bn2 = tfk.layers.BatchNormalization(name="bn2_" + suffix)
        self.do2 = tfk.layers.Dropout(rate=dropout_rate, name="dropOut2_" + suffix)

        self.conv3 = tfk.layers.Conv1D(filters=n_filters * 2, kernel_size=3, padding='valid',
                                       name="conv3_" + suffix, activation=hidden_activation)
        self.bn3 = tfk.layers.BatchNormalization(name="bn3_" + suffix)
        self.do3 = tfk.layers.Dropout(rate=dropout_rate, name="dropOut3_" + suffix)

        # Kernel size 1: a per-step channel mixing layer, no further length reduction.
        self.conv4 = tfk.layers.Conv1D(filters=n_filters * 2, kernel_size=1, padding='valid',
                                       name="conv4_" + suffix, activation=hidden_activation)
        self.bn4 = tfk.layers.BatchNormalization(name="bn4_" + suffix)
        self.do4 = tfk.layers.Dropout(rate=dropout_rate, name="dropOut4_" + suffix)

        self.pool = tfk.layers.GlobalAveragePooling1D()

    @tf.function
    def call(self, inputs, training=False):
        """Encode ``inputs`` and return the globally average-pooled features.

        Args:
            inputs: Tensor fed to ``conv1``. With Keras' default channels-last
                layout this is presumably ``(batch, steps, channels)`` with
                ``steps >= 7`` — confirm against the data pipeline.
            training: Forwarded to the Dropout layers (and to BatchNormalization
                when ``use_batchnorm`` is set) so they switch between training
                and inference behaviour.

        Returns:
            Tensor with ``2 * n_filters`` features per example.
        """
        stages = (
            (self.conv1, self.bn1, self.do1),
            (self.conv2, self.bn2, self.do2),
            (self.conv3, self.bn3, self.do3),
            (self.conv4, self.bn4, self.do4),
        )
        x = inputs
        # Python loop: unrolled once at trace time by tf.function.
        for conv, bn, drop in stages:
            x = conv(x)
            if self.use_batchnorm:
                x = bn(x, training=training)
            x = drop(x, training=training)
        return self.pool(x)

class TwoBranchModel(tf.keras.Model):
    """Classifier fusing a pixel-level and an object-level CNN1D encoder.

    Each input goes through its own ``CNN1D`` branch. Each branch gets an
    auxiliary softmax head. The two branch feature vectors are concatenated
    and passed through Dense -> BatchNormalization -> Dropout -> Dense ->
    softmax to produce the fused prediction.

    Args:
        n_filters: Passed to both CNN1D branches.
        suffix: Layer-name suffix passed to both branches; also used to build
            the branch model names.
        nb_classes: Number of outputs of each softmax head.
        dropout_rate: Dropout rate inside both branches and after ``dense1``
            (following ``bn1``).
        hidden_activation: Activation of the branch convolutions and of the
            two hidden Dense layers. Defaults to ``'relu'``, which is the
            activation previously hard-coded here.
        name: Keras name of this model.
        hidden_units: Keyword-only width of ``dense1`` and ``dense2``
            (default 256, as previously hard-coded).
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, n_filters, suffix, nb_classes, dropout_rate=0.0, hidden_activation='relu',
                 name='TwoBranchModel', *, hidden_units=256, **kwargs):
        super(TwoBranchModel, self).__init__(name=name, **kwargs)

        # Give each branch its own Keras name. With the shared default 'CNN1D',
        # both sub-models had the same layer name under this model. That shows
        # up as duplicates in summary() and collides in name-keyed (HDF5)
        # weight files. Attribute names are unchanged, so object-based TF
        # checkpoints still resolve to the same variables.
        self.PixelBranch = CNN1D(n_filters, suffix, dropout_rate=dropout_rate,
                                 hidden_activation=hidden_activation,
                                 name="PixelBranch_" + suffix)
        self.ObjBranch = CNN1D(n_filters, suffix, dropout_rate=dropout_rate,
                               hidden_activation=hidden_activation,
                               name="ObjBranch_" + suffix)

        # Fusion head.
        self.dense1 = tfk.layers.Dense(hidden_units, activation=hidden_activation)
        self.drop1 = tfk.layers.Dropout(rate=dropout_rate)
        self.bn1 = tfk.layers.BatchNormalization()
        self.dense2 = tfk.layers.Dense(hidden_units, activation=hidden_activation)
        self.classif = tfk.layers.Dense(nb_classes, activation='softmax')

        # Auxiliary per-branch heads.
        self.classifPix = tfk.layers.Dense(nb_classes, activation='softmax')
        self.classifObj = tfk.layers.Dense(nb_classes, activation='softmax')

    @tf.function
    def call(self, inputs, training=False):
        """Run both branches and return the fused and per-branch predictions.

        Args:
            inputs: Pair ``(pixel_inputs, obj_inputs)``, each a tensor accepted
                by ``CNN1D`` (presumably ``(batch, steps, channels)`` — confirm).
            training: Forwarded to the branches, ``bn1`` and ``drop1``.

        Returns:
            Tuple ``(classif, classifP, classifO)``: softmax outputs (one entry
            per class) of the fused head, the pixel-only head and the
            object-only head.
        """
        pixel_inputs, obj_inputs = inputs
        branchP = self.PixelBranch(pixel_inputs, training=training)
        branchO = self.ObjBranch(obj_inputs, training=training)

        classifP = self.classifPix(branchP)
        classifO = self.classifObj(branchO)

        # Fuse by concatenating the two pooled feature vectors on the feature axis.
        feat = tf.concat([branchP, branchO], axis=1)
        output = self.dense1(feat)
        output = self.bn1(output, training=training)
        output = self.drop1(output, training=training)
        output = self.dense2(output)
        classif = self.classif(output)
        return classif, classifP, classifO