def __call__(self, inputs, state, scope=None):
    """Run this multi-layer cell on ``inputs``, starting from ``state``.

    The cells in ``self._cells`` are stacked: cell 0 consumes ``inputs`` and
    every later cell consumes the output of the cell below it.  Each cell
    additionally receives a skip connection from the raw ``inputs`` of the
    current time step, concatenated along axis 1.  Cell 0 gets a zero tensor
    in place of that skip connection, because its main input already is
    ``inputs``.

    Args:
        inputs: Input tensor for the current time step.  It is concatenated
            along axis 1, so presumably shaped ``[batch, input_size]`` --
            TODO confirm against callers.
        state: The states of all cells concatenated along axis 1.  Cell ``i``
            reads the ``cell.state_size`` columns that follow the columns of
            cells ``0 .. i-1``.
        scope: Optional variable-scope name.  Defaults to
            ``"MultiRNNCellWithConn"`` when not given.

    Returns:
        A pair ``(output, new_state)``: the output of the last cell, and the
        new states of all cells concatenated along axis 1 in cell order.
    """
    # NOTE(review): ``tf.concat(1, values)`` is the pre-TF-1.0 argument order
    # (axis first); TF >= 1.0 expects ``tf.concat(values, 1)``.
    with tf.variable_scope(scope or "MultiRNNCellWithConn"):
        cur_state_pos = 0
        cur_inp = inputs
        new_states = []
        for i, cell in enumerate(self._cells):
            with tf.variable_scope("Cell%d" % i):
                # Each cell owns a contiguous column range of the flat state.
                cur_state = tf.slice(
                    state, [0, cur_state_pos], [-1, cell.state_size])
                cur_state_pos += cell.state_size
                # Skip connection from the input of the current time step t.
                # Cell 0 already sees ``inputs`` as its main input, so it gets
                # zeros here.  The skip tensor is built per layer so that the
                # zeros never replace ``inputs`` for the deeper layers.
                if i == 0:
                    skip = tf.zeros_like(inputs)
                else:
                    skip = inputs
                # Feed the previous layer's output (not the raw ``inputs``)
                # so the cells are actually stacked.
                cur_inp, new_state = cell(
                    tf.concat(1, [cur_inp, skip]), cur_state)
                new_states.append(new_state)
    return cur_inp, tf.concat(1, new_states)
# Web-page residue (not code): "评论列表" ("Comment list"), "文章目录" ("Table of contents").