def _build_ops(self):
    """Build a ``tf.while_loop`` that repeatedly applies ``self.shared_cell``.

    The loop carries three values: an int32 step counter starting at 0,
    the running inputs (seeded from ``self.inputs``) and the full state
    (seeded from ``self.initial_state``).  It keeps iterating while the
    counter is less than ``self.max_steps``.  On every step the state
    slot ``step % self.num_cores`` is read, passed to ``self.shared_cell``
    together with the running inputs, and overwritten with the second
    value the cell returns; the first returned value becomes the new
    running inputs.

    The final loop values are only bound to local names here; the visible
    code neither returns nor stores them.
    """
    start = tf.constant(0, dtype=tf.int32)

    def keep_going(step, _carried, _states):
        # Only the counter decides termination; the other loop variables
        # are accepted because tf.while_loop passes every loop var.
        return tf.less(step, self.max_steps)

    def run_step(step, carried, states):
        # Cycle through the state slots in round-robin order.
        slot = step % self.num_cores
        previous = states[slot]
        carried, updated = self.shared_cell(carried, previous)
        # NOTE(review): this is in-place item assignment with an index
        # derived from the loop counter.  tf.Tensor objects do not support
        # item assignment, and a symbolic tensor cannot index a Python
        # list, so this presumably relies on `states` being a mutable
        # Python sequence and on eager execution -- confirm.
        states[slot] = updated
        return step + 1, carried, states

    _, inputs, full_state = tf.while_loop(
        keep_going,
        run_step,
        loop_vars=[start, self.inputs, self.initial_state],
    )
# [web-page residue, not code] Comment list
# [web-page residue, not code] Article table of contents