def compute_predictions_scan(self):
    """Unroll the recurrent network over time using two ``tf.scan`` passes.

    ``self.x`` has its first two axes swapped (perm ``[1, 0, 2]``) because
    ``tf.scan`` iterates over axis 0 of its ``elems``. This presumably turns
    a ``(batch, time, input)`` tensor into time-major order; TODO confirm
    the layout of ``self.x`` against the code that builds it.

    Both scans use ``parallel_iterations=1``:

    1. ``self.rnn_step_scan(acc_state, x_t)`` is folded over the swapped
       input, starting from ``self.init_state``. ``tf.scan`` stacks every
       intermediate accumulator along a new leading axis.
    2. ``self.output_step_scan(acc_output, state_t)`` is folded over those
       stacked states, starting from ``tf.zeros([self.N_batch, self.N_out])``
       (default dtype ``float32``).

    Returns:
        tuple: ``(outputs, states)``. ``outputs`` is the stacked result of
        the second scan with its first two axes swapped back (perm
        ``[1, 0, 2]``). ``states`` is a Python list produced by
        ``tf.unstack`` along axis 0, with one state tensor per scan step.
    """
    start_state = self.init_state
    swap_leading_axes = [1, 0, 2]

    # tf.scan walks elems along axis 0, so move the axis to iterate to the front.
    state_history = tf.scan(
        self.rnn_step_scan,
        tf.transpose(self.x, swap_leading_axes),
        initializer=start_state,
        parallel_iterations=1,
    )

    # Second pass over the stored states; its accumulator is seeded with zeros.
    zero_output = tf.zeros([self.N_batch, self.N_out])
    output_history = tf.scan(
        self.output_step_scan,
        state_history,
        initializer=zero_output,
        parallel_iterations=1,
    )

    outputs = tf.transpose(output_history, swap_leading_axes)
    per_step_states = tf.unstack(state_history)
    return outputs, per_step_states
# fix spectral radius of recurrent matrix
评论列表
文章目录