def call(self, inputs, state):
    """Gated recurrent unit (GRU) with num_units cells."""
    with tf.variable_scope("gates"):
        # Project the inputs and the previous state to the reset and update
        # gates (2 * num_units outputs, split below).
        input_to_gates = tf.layers.dense(
            inputs, 2 * self._num_units, name="input_proj",
            kernel_initializer=tf.glorot_normal_initializer(),
            use_bias=self.use_input_bias)

        # NOTE: Nematus probably does the orthogonal initialization differently.
        state_to_gates = tf.layers.dense(
            state, 2 * self._num_units,
            use_bias=self.use_state_bias,
            kernel_initializer=orthogonal_initializer(),
            name="state_proj")

        gates_input = state_to_gates + input_to_gates
        reset, update = tf.split(
            tf.sigmoid(gates_input), num_or_size_splits=2, axis=1)

    with tf.variable_scope("candidate"):
        # Candidate state: the previous state is gated by `reset` before
        # being combined with the input projection.
        input_to_candidate = tf.layers.dense(
            inputs, self._num_units, use_bias=self.use_input_bias,
            kernel_initializer=tf.glorot_normal_initializer(),
            name="input_proj")
        state_to_candidate = tf.layers.dense(
            state, self._num_units, use_bias=self.use_state_bias,
            kernel_initializer=orthogonal_initializer(),
            name="state_proj")
        candidate = self._activation(
            state_to_candidate * reset + input_to_candidate)

    # Interpolate between the previous state and the candidate using the
    # update gate; the new state also serves as the cell output.
    new_state = update * state + (1 - update) * candidate
    return new_state, new_state
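# --- Not part of the original cell: a minimal sketch of the
# `orthogonal_initializer` helper that the state projections above rely on
# but that is not shown here. It assumes the helper should behave like the
# standard orthogonal initialization (random normal matrix orthogonalized via
# SVD, using the orthogonal factor as the initial weights). In TF 1.x,
# tf.orthogonal_initializer() offers an equivalent built-in; how Nematus
# itself initializes these weights may differ, as the comment above notes.

import numpy as np
import tensorflow as tf


def orthogonal_initializer():
    """Return an initializer producing (semi-)orthogonal weight matrices."""

    def _initializer(shape, dtype=tf.float32, partition_info=None):
        # Flatten all but the first dimension so a 2-D matrix can be
        # orthogonalized, then reshape back to the requested shape.
        flat_shape = (shape[0], int(np.prod(shape[1:])))
        a = np.random.standard_normal(flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        # Pick the factor whose shape matches the flattened weight matrix.
        q = u if u.shape == flat_shape else v
        return tf.constant(q.reshape(shape), dtype=dtype)

    return _initializer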