def __call__(self, inputs, state, mask, scope=None):
    """Run one step of an LSTM cell whose update is gated by ``mask``.

    Where ``mask`` is 1 the cell applies the usual LSTM update. Where it is 0
    the previous cell state and hidden output are carried through unchanged.
    NOTE(review): this reading assumes ``mask`` holds only 0/1 values; confirm.

    Args:
      inputs: input for this step. It is concatenated with the previous
        hidden output before the fused gate projection.
      state: previous state. It is split into two equal halves along
        dimension 1, first ``c`` and then ``h``, and rebuilt in that same
        order on return.
      mask: per-example update mask. It is expanded on dimension 1 so it
        broadcasts against the state halves (presumably shape (batch,);
        TODO confirm against callers).
      scope: optional variable scope name. Defaults to the class name.

    Returns:
      A pair ``(new_h, new_state)``. ``new_state`` is ``[new_c, new_h]``
      concatenated along dimension 1.
    """
    with vs.variable_scope(scope or type(self).__name__):
        # Legacy argument order: split(split_dim, num_split, value).
        prev_c, prev_h = array_ops.split(1, 2, state)

        # One fused projection yields all four gate pre-activations. The
        # trailing True presumably enables a bias term in `linear`.
        gates = linear([inputs, prev_h], 4 * self._num_units, True)
        # Gate order: input gate, candidate input, forget gate, output gate.
        in_gate, candidate, forget_gate, out_gate = array_ops.split(1, 4, gates)

        cell = (prev_c * sigmoid(forget_gate + self._forget_bias)
                + sigmoid(in_gate) * tanh(candidate))

        keep = array_ops.expand_dims(mask, 1)

        def blend(fresh, stale):
            # Fresh value where keep is 1, stale (previous) value where keep is 0.
            return keep * fresh + (1. - keep) * stale

        cell = blend(cell, prev_c)
        hidden = blend(tanh(cell) * sigmoid(out_gate), prev_h)
        return hidden, array_ops.concat(1, [cell, hidden])
评论列表
文章目录