def __call__(self, inputs, state, scope=None):
  """Gated recurrent unit (GRU) with num_units cells."""
  with tf.variable_scope(scope or type(self).__name__):
    with tf.variable_scope("gates"):  # Reset gate and update gate.
      # We start with a bias of 1.0 to not reset and not update.
      concat = rnn_ops.linear([inputs, state], 2 * self._num_units, True, bias_start=1.0)
      r, u = tf.split(value=concat, num_or_size_splits=2, axis=1)
      if self._layer_norm:
        r = rnn_ops.layer_norm(r, name="r")
        u = rnn_ops.layer_norm(u, name="u")
      # Apply the non-linearity after layer normalization.
      r = tf.sigmoid(r)
      u = tf.sigmoid(u)
    with tf.variable_scope("candidate"):
      c = self._activation(rnn_ops.linear([inputs, r * state], self._num_units, True))
    new_h = u * state + (1 - u) * c
  return new_h, new_h
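To make the update itself concrete, here is a minimal NumPy sketch of the same GRU step. It assumes tanh for self._activation and skips the optional layer-normalization branch; the names (gru_step, W_gates, W_cand, and so on) are illustrative and not part of the original module.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(inputs, state, W_gates, b_gates, W_cand, b_cand):
    """One GRU update; returns the new hidden state."""
    # Gates: a single matmul over [inputs, state], then split into r and u,
    # mirroring rnn_ops.linear([inputs, state], 2 * num_units, ...) + tf.split.
    xh = np.concatenate([inputs, state], axis=1)
    gates = xh @ W_gates + b_gates          # shape: (batch, 2 * num_units)
    r, u = np.split(gates, 2, axis=1)
    r, u = sigmoid(r), sigmoid(u)           # reset gate, update gate
    # Candidate: the reset gate scales the previous state before the matmul.
    xc = np.concatenate([inputs, r * state], axis=1)
    c = np.tanh(xc @ W_cand + b_cand)       # assumes tanh for self._activation
    # Interpolate between the old state and the candidate with the update gate.
    return u * state + (1.0 - u) * c

batch, input_dim, num_units = 2, 3, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, input_dim))
h = np.zeros((batch, num_units))
W_g = rng.standard_normal((input_dim + num_units, 2 * num_units)) * 0.1
W_c = rng.standard_normal((input_dim + num_units, num_units)) * 0.1
b_g = np.ones(2 * num_units)  # bias_start=1.0, as in the TF code above
b_c = np.zeros(num_units)
h_new = gru_step(x, h, W_g, b_g, W_c, b_c)
print(h_new.shape)  # (2, 4)

The gate bias of 1.0 (bias_start=1.0 above) means both gates start near sigmoid(1.0) ≈ 0.73, so the cell initially keeps most of its previous state and applies little reset, which tends to help gradients flow through long sequences early in training.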
Source file: basic_rnn_cells.py