def __call__(self, inputs, initial_state=None, dtype=tf.float32, sequence_length=None, scope=None):
    """Compute gate inputs with one dense projection, then run ``self._rnn``.

    The projection has ``num_gates * self._num_units`` outputs, where
    ``num_gates`` is 3 when ``self._with_residual`` is set and 2 otherwise.
    Its bias is initialized to the constant ``self._constant_bias``. The
    output is split into equal chunks along axis 2:

    * chunk 0 -- candidate values ``x_tilde``;
    * chunk 1 -- forget-gate logits, ``f = sigmoid(chunk 1)``;
    * chunk 2 -- residual-gate logits, ``r = sigmoid(chunk 2)``
      (residual mode only).

    The tensor fed to ``self._rnn`` is the axis-2 concatenation of
    ``[(1 - f) * x_tilde, f]``, or, in residual mode,
    ``[(1 - r) * inputs, (1 - f) * x_tilde, f, r]``. NOTE(review): this
    looks like an SRU-style cell whose recurrence lives in ``self._rnn`` --
    confirm against the ``_rnn`` implementation.

    Args:
      inputs: Input tensor. Splitting/concatenating on axis 2 implies rank 3
        with the feature dimension on axis 2; which of axes 0/1 is batch vs.
        time is not visible here -- TODO confirm.
      initial_state: Forwarded positionally to ``self._rnn``.
      dtype: Forwarded positionally to ``self._rnn``.
      sequence_length: Forwarded positionally to ``self._rnn``.
      scope: Forwarded positionally to ``self._rnn``. NOTE(review): it is not
        applied to the dense projection below, so the projection variables
        are presumably created in the caller's current variable scope.

    Returns:
      Whatever ``self._rnn`` returns.
    """
    num_gates = 3 if self._with_residual else 2
    # NOTE(review): no `name`/`reuse` is passed, so each call presumably
    # builds a fresh set of projection weights -- confirm that is intended.
    # Left as-is: adding a name would change variable names and likely break
    # existing checkpoints.
    transformed = tf.layers.dense(
        inputs, num_gates * self._num_units,
        bias_initializer=tf.constant_initializer(self._constant_bias))
    gates = tf.split(transformed, num_gates, axis=2)
    forget_gate = tf.sigmoid(gates[1])
    transformed_inputs = (1.0 - forget_gate) * gates[0]
    if self._with_residual:
        residual_gate = tf.sigmoid(gates[2])
        # Out-of-place multiply on purpose: `inputs *= ...` dispatches to the
        # argument's own in-place operator when it defines one (e.g. a numpy
        # array or a TF variable), which mutates the caller's object or
        # raises instead of producing a new tensor.
        scaled_inputs = inputs * (1.0 - residual_gate)
        new_inputs = tf.concat(
            [scaled_inputs, transformed_inputs, forget_gate, residual_gate],
            axis=2)
    else:
        new_inputs = tf.concat([transformed_inputs, forget_gate], axis=2)
    return self._rnn(new_inputs, initial_state, dtype, sequence_length, scope)
评论列表
文章目录