def __call__(self, inputs, state, scope=None):
    """LSTM cell with layer normalization and recurrent dropout.

    Builds one step of the cell:
      new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * dropout(act(j))
      new_h = act(new_c) * sigmoid(o)
    with optional normalization of each gate pre-activation and of new_c.

    Args:
      inputs: input tensor for this step. It is concatenated with the
        previous hidden output along axis 1 (presumably shape
        (batch, input_size) -- TODO confirm against callers).
      state: pair ``(c, h)`` holding the previous cell state and hidden
        output.
      scope: optional variable-scope name; defaults to the class name.

    Returns:
      A tuple ``(new_h, LSTMStateTuple(new_c, new_h))``.
    """
    # The scope object returned by the context manager is deliberately not
    # bound: nothing below uses it, and binding it to `scope` would only
    # shadow the argument.
    with vs.variable_scope(scope or type(self).__name__):
        c, h = state
        # One projection of [inputs, h] yields all four gate
        # pre-activations, which are then split evenly along axis 1.
        # NOTE(review): concat/split are called with the legacy axis-first
        # argument order; keep in sync with the TF version this file targets.
        args = array_ops.concat(1, [inputs, h])
        concat = self._linear(args)
        i, j, f, o = array_ops.split(1, 4, concat)
        if self._layer_norm:
            # Each gate is normalized under its own name ("input",
            # "transform", "forget", "output"); presumably _norm creates
            # per-name parameters -- confirm in _norm.
            i = self._norm(i, "input")
            j = self._norm(j, "transform")
            f = self._norm(f, "forget")
            o = self._norm(o, "output")
        g = self._activation(j)
        # Dropout touches only the candidate update g, never c or h.
        # It is skipped only when keep_prob is a Python float >= 1; any
        # non-float keep_prob (e.g. a tensor) always builds the dropout op.
        if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
            g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
        new_c = (c * math_ops.sigmoid(f + self._forget_bias)
                 + math_ops.sigmoid(i) * g)
        if self._layer_norm:
            # The normalized cell state both feeds new_h and is the state
            # carried to the next step.
            new_c = self._norm(new_c, "state")
        new_h = self._activation(new_c) * math_ops.sigmoid(o)
        new_state = rnn_cell.LSTMStateTuple(new_c, new_h)
        return new_h, new_state
评论列表
文章目录