def __call__(self, inputs, state):
    """Run one step of a GRU with layer normalization on the gate pre-activations.

    Computes:
        r, u  = sigmoid(LN(W_g . [inputs, state] + b_g))   # reset / update gates
        c     = activation(W_c . [inputs, r * state] + b_c)  # candidate state
        new_h = u * state + (1 - u) * c

    Args:
        inputs: Input tensor for this time step.
        state: Previous hidden state tensor. Its dtype (well, the dtype of
            `inputs`, the first of the pair) is used for the default gate-bias
            initializer when none was configured.

    Returns:
        A `(output, new_state)` pair; both elements are the new hidden
        state `new_h` (standard GRU convention).
    """
    with vs.variable_scope("gates"):  # Reset gate and update gate.
        # Start with a bias of 1.0 so the cell initially neither resets
        # nor updates (gates begin near "pass previous state through").
        bias_ones = self._bias_initializer
        if self._bias_initializer is None:
            dtype = [a.dtype for a in [inputs, state]][0]
            bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
        value = rnn_cell_impl._linear(
            [inputs, state], 2 * self._num_units, True, bias_ones,
            self._kernel_initializer)
        r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
        # Layer-normalize each gate's pre-activation under its own sub-scope
        # so r and u get independent scale/shift parameters.
        r = layer_normalization(r, scope="r/")
        u = layer_normalization(u, scope="u/")
        r, u = math_ops.sigmoid(r), math_ops.sigmoid(u)
    with vs.variable_scope("candidate"):
        # Candidate uses the configured bias initializer (may be None),
        # not the ones-bias used for the gates.
        c = self._activation(
            rnn_cell_impl._linear([inputs, r * state], self._num_units, True,
                                  self._bias_initializer,
                                  self._kernel_initializer))
    # NOTE(review): indentation was lost in the source paste; per the upstream
    # TF GRUCell, the state blend sits outside both variable scopes.
    new_h = u * state + (1 - u) * c
    return new_h, new_h
# (web-scrape residue removed: "评论列表" / "文章目录" — blog "comment list" /
# "article table of contents" navigation labels, not part of the code)