# GUnit.py
import tensorflow as tf
from tensorflow.nn.rnn_cell import RNNCell
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest
from tensorflow.python.ops import nn_ops
from functions import CNN

# String constants for bias / kernel names of three gates and three
# fully connected layers (presumably intended as TensorFlow variable names).
# NOTE(review): none of these constants is referenced by the code visible in
# this file -- confirm whether another module relies on them before removing.

# Gate biases.
_BIAS_G3 = "bias_gate_3"
_BIAS_G2 = "bias_gate_2"
_BIAS_G1 = "bias_gate_1"

# Gate 1 kernels; the "_h" suffix presumably marks the hidden-state kernel
# (TODO confirm).
_WEIGHTS_G1 = "kernel_gate_1"
_WEIGHTS_G1H = "kernel_gate_1_h"

# Gate 2 kernels.
_WEIGHTS_G2 = "kernel_gate_2"
_WEIGHTS_G2H = "kernel_gate_2_h"

# Gate 3 kernels.
_WEIGHTS_G3 = "kernel_gate_3"
_WEIGHTS_G3H = "kernel_gate_3_h"


# Fully connected layer 0.
_BIAS_FC0 = "bias_fc0"
_WEIGHTS_FC0 = "kernel_fc0"

# Fully connected layer 1.
_BIAS_FC1 = "bias_fc1"
_WEIGHTS_FC1 = "kernel_fc1"

# Fully connected layer 2.
_BIAS_FC2 = "bias_fc2"
_WEIGHTS_FC2 = "kernel_fc2"


def getW(name, dim1, dim2, init, dtype):
    """Return a 2-D variable named `name` with shape ``[dim1, dim2]``.

    Delegates to ``variable_scope.get_variable``, so the variable lives in
    whatever variable scope is active at the call site.

    Args:
        name: Variable name passed to ``get_variable``.
        dim1: Size of the first dimension.
        dim2: Size of the second dimension.
        init: Initializer forwarded as ``initializer``.
        dtype: Data type forwarded as ``dtype``.
    """
    shape = [dim1, dim2]
    return vs.get_variable(name, shape=shape, initializer=init, dtype=dtype)

def getB(name, dim, init, dtype):
    """Return a 1-D variable named `name` with shape ``[dim]``.

    Delegates to ``variable_scope.get_variable``, so the variable lives in
    whatever variable scope is active at the call site.

    Args:
        name: Variable name passed to ``get_variable``.
        dim: Length of the vector.
        init: Initializer forwarded as ``initializer``.
        dtype: Data type forwarded as ``dtype``.
    """
    shape = [dim]
    return vs.get_variable(name, shape=shape, initializer=init, dtype=dtype)


class GUnit(RNNCell):
    """RNNCell whose step applies the project's `CNN` helper to its input.

    NOTE(review): an earlier docstring described this as a Gated Recurrent
    Unit, but `call` contains no gating logic and never reads the incoming
    `state`; it only regroups `inputs` and passes them to `CNN`.

    Args:
        num_units: Reported as both `state_size` and `output_size`, and
            forwarded to `CNN`.
        drop: Forwarded to `CNN` (presumably a dropout setting -- TODO
            confirm against `functions.CNN`).
        is_training: Forwarded to `CNN`.
        n_timestamps: Passed to `tf.split` as the split specification for
            axis 1 of the input.
        activation: Stored (defaults to tanh); not used by `call`.
        reuse: Forwarded to the `RNNCell` constructor as `_reuse`.
        kernel_initializer: Stored; not used by `call`.
        bias_initializer: Stored; not used by `call`.
    """

    def __init__(self,
                 num_units,
                 drop,
                 is_training,
                 n_timestamps,
                 activation=None,
                 reuse=None,
                 kernel_initializer=None,
                 bias_initializer=None):
        super(GUnit, self).__init__(_reuse=reuse)
        # Settings consumed by `call`.
        self._num_units = num_units
        self._n_timestamps = n_timestamps
        self._drop = drop
        self._is_training = is_training
        # Kept on the instance but not read by any method of this class.
        self._kernel_initializer = kernel_initializer
        self._bias_initializer = bias_initializer
        self._activation = math_ops.tanh if not activation else activation

    @property
    def state_size(self):
        """Size of the cell state: `num_units`."""
        return self._num_units

    @property
    def output_size(self):
        """Size of the cell output: `num_units`."""
        return self._num_units

    def call(self, inputs, state):
        """Run one step: regroup `inputs` along axis 1 and apply `CNN`.

        Args:
            inputs: Input tensor; split along axis 1 according to
                `n_timestamps` (assumes a 2-D (batch, features) tensor and
                an integer `n_timestamps` -- TODO confirm).
            state: Previous state; ignored.

        Returns:
            A pair `(output, new_state)` where both elements are the same
            `CNN` result.
        """
        # Cut axis 1 into pieces, then stack the pieces on a new axis 1.
        pieces = tf.split(inputs, self._n_timestamps, axis=1)
        stacked = tf.stack(pieces, axis=1)
        output = CNN(stacked, self._num_units, self._drop, self._is_training)
        return output, output