def __call__(self, inputs, state, scope=None):
"""Long short-term memory cell (LSTM)."""
with vs.variable_scope(scope or type(self).__name__):
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(1, 2, state)
concat = _linear([inputs, h], 4 * self._num_units, True)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(1, 4, concat)
new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
self._activation(j))
eps = 1e-13
temp = math_ops.div(math_ops.reduce_sum(math_ops.mul(c, new_c),1),math_ops.reduce_sum(math_ops.mul(c,c),1) + eps)
dummy = array_ops.transpose(c)
t1 = math_ops.mul(dummy, temp)
t1 = array_ops.transpose(t1)
distract_c = new_c - t1
new_h = self._activation(distract_c) * sigmoid(o)
if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat(1, [new_c, new_h])
return new_h, new_state
rnn_cell.py 文件源码
python
阅读 31
收藏 0
点赞 0
评论 0
评论列表
文章目录