def __call__(self, x, states_prev, scope=None):
    """Run one step of the block LSTM cell.

    Creates (or reuses) the cell's variables inside the variable scope and
    forwards them, together with the input and the previous state, to
    `_lstm_block_cell`.

    Args:
      x: Input for this step. Its static shape must be compatible with
        rank 2 and its second dimension must be statically known.
      states_prev: Pair `(cs_prev, h_prev)` carried over from the previous
        step.
      scope: Optional variable scope; defaults to the class name.

    Returns:
      A pair `(h, (cs, h))`: the step output followed by the state pair to
      feed into the next call.

    Raises:
      ValueError: If the second dimension of `x` is not set, or if
        `states_prev` does not contain exactly two elements.
    """
    with vs.variable_scope(scope or type(self).__name__):
        x_shape = x.get_shape().with_rank(2)
        input_size = x_shape[1]
        # Test the underlying size rather than the dimension object itself:
        # a wrapper object that defines no truthiness is always truthy, so
        # `not x_shape[1]` would never detect an unknown size. Falling back
        # to the object covers shapes that yield plain int/None entries.
        # NOTE(review): confirm against the TensorFlow version in use.
        if not getattr(input_size, "value", input_size):
            raise ValueError(
                "Expecting x_shape[1] to be set: %s" % str(x_shape))
        if len(states_prev) != 2:
            raise ValueError(
                "Expecting states_prev to be a tuple with length 2.")

        num_units = self._num_units
        # Single weight matrix with input_size + num_units rows and
        # 4 * num_units columns.
        w = vs.get_variable("W", [input_size + num_units, num_units * 4])
        # Bias matches the column count of W and starts at zero.
        b = vs.get_variable("b", [w.get_shape().with_rank(2)[1]],
                            initializer=init_ops.constant_initializer(0.0))
        # The three peephole vectors are created unconditionally, i.e. even
        # when self._use_peephole is false, so the variable set is stable.
        wci = vs.get_variable("wci", [num_units])
        wco = vs.get_variable("wco", [num_units])
        wcf = vs.get_variable("wcf", [num_units])

        (cs_prev, h_prev) = states_prev
        # Only the 2nd and 7th of the seven results are used here
        # (presumably the new cell state and the new output — verify
        # against _lstm_block_cell).
        (_, cs, _, _, _, _, h) = _lstm_block_cell(
            x,
            cs_prev,
            h_prev,
            w,
            b,
            wci=wci,
            wco=wco,
            wcf=wcf,
            forget_bias=self._forget_bias,
            use_peephole=self._use_peephole)
        return (h, (cs, h))
评论列表
文章目录