def __call__(self, input, state, scope=None):
    """Run one step of a tanh recurrent cell that exposes an LSTM-style state.

    Computes ``new_h = tanh(_linear(h_prev, self._num_units, True) + input)``
    and returns it both as the cell output and inside the new state tuple.

    Args:
        input: Input tensor for this step. It is added element-wise to the
            recurrent projection with no input projection of its own. It
            therefore presumably must already have width ``self._num_units``
            (or be broadcast-compatible) -- TODO confirm against callers.
        state: A 2-element ``(c, h)`` pair (e.g. an ``LSTMStateTuple``).
            Only ``h`` is read; ``c`` is ignored.
        scope: Optional variable-scope name; defaults to the class name.

    Returns:
        ``(new_h, LSTMStateTuple(new_h, new_h))``: the cell output and the new
        state. The ``c`` slot simply mirrors ``new_h``, so the state keeps the
        same ``(c, h)`` structure as an LSTM cell's.
    """
    with tf.variable_scope(scope or type(self).__name__):
        # Only the hidden half feeds the recurrence. The pair is still
        # unpacked so a state that is not a (c, h) pair fails loudly here.
        _, h_prev = state  # TODO: test
        with tf.variable_scope("new_h"):
            # The third positional arg to _linear presumably enables a bias term -- verify.
            rec_input = _linear(h_prev, self._num_units, True)
            new_h = tf.nn.tanh(rec_input + input)
        # There is no separate memory cell: c carries a copy of new_h so the
        # returned state matches the LSTMStateTuple shape.
        new_c = new_h
        new_state = LSTMStateTuple(new_c, new_h)
    return new_h, new_state
评论列表
文章目录