def __call__(self, inputs, state, scope=None):
"""Run this multi-layer cell on inputs, starting from state."""
with vs.variable_scope(scope or "MultiRNNC2DCell"):
cur_state_pos = 0
cur_inp = inputs
new_states = []
for i, cell in enumerate(self._cells):
with vs.variable_scope("Cell%d" % i):
if not nest.is_sequence(state):
raise ValueError(
"Expected state to be a tuple of length %d, but received: %s"
% (len(self.state_size), state))
cur_state = state[i]
cur_inp, new_state = cell(cur_inp, cur_state)
new_states.append(new_state)
new_states = (tuple(new_states) if self._state_is_tuple
else tf.concat(1, new_states))
return cur_inp, new_states
评论列表
文章目录