def calc(self, xs):
    """Apply the input MLP per step, the stacked RNN, then the output MLP per step.

    Args:
        xs: rank-3 tensor (the ``[1, 0, 2]`` transposes fix the rank). Axis 0
            is treated as batch and axis 1 as time, matching the batch-major
            layout the original fed to ``dynamic_rnn`` -- TODO confirm with
            callers.

    Returns:
        Rank-3 tensor in the same batch-major layout as ``xs``: the result of
        ``self.out_mlp`` applied to the RNN output at each of the first
        ``self.lstm_steps_count`` time steps.
    """
    # `tf.pack` was renamed to `tf.stack` (and later removed); prefer the new
    # name so this runs on both old and newer TF releases.
    stack = getattr(tf, "stack", None) or tf.pack

    # Time-major view, so that gather index i selects time step i.
    xs = tf.transpose(xs, [1, 0, 2])
    # Debug output. The parenthesised single-argument form prints the same
    # text on Python 2 and, unlike the bare statement, is valid Python 3.
    print("xs: " + str(xs))

    # Input MLP applied independently to each of the first
    # `lstm_steps_count` steps. NOTE(review): assumes the time dimension of
    # `xs` is at least `lstm_steps_count`; extra steps are ignored.
    mlp_out = stack([self.mlp(tf.gather(xs, i))
                     for i in range(self.lstm_steps_count)])

    # `mlp_out` is already time-major; telling dynamic_rnn so avoids the
    # transpose-in / transpose-out pair the batch-major default requires.
    # The final RNN state is not needed.
    # NOTE(review): a new MultiRNNCell is built on every call; calling this
    # twice in one graph may need variable reuse -- confirm with callers.
    cell = tf.nn.rnn_cell.MultiRNNCell(self.layers, state_is_tuple=True)
    val, _ = tf.nn.dynamic_rnn(cell, mlp_out, dtype=tf.float32,
                               time_major=True)

    # Output MLP applied independently to each time step of the RNN output.
    results = stack([self.out_mlp(tf.gather(val, i))
                     for i in range(self.lstm_steps_count)])
    # Back to batch-major to match the input layout.
    return tf.transpose(results, [1, 0, 2])
评论列表
文章目录