def apply(self, is_train, inputs, mask=None):
    """Run a bidirectional LSTM over ``inputs`` and concatenate both directions.

    Args:
        is_train: Not used by this method; kept for interface compatibility
            (NOTE(review): presumably a training-mode flag shared with other
            layers' ``apply`` -- confirm against callers).
        inputs: Batch-major input tensor, assumed ``(batch, time, features)``
            -- TODO confirm. It is transposed to time-major for the fused cell.
        mask: Optional per-example sequence lengths. Despite the name, it is
            passed as ``sequence_length`` / ``seq_lengths``, not a boolean
            mask. When ``None``, every sequence is treated as full length.

    Returns:
        Batch-major tensor whose last axis is the concatenation of the
        forward and backward outputs (``2 * self.n_units`` features).
    """
    time_major = tf.transpose(inputs, [1, 0, 2])  # to time first

    def _reverse_time(x):
        # Reverse each sequence along the time axis (axis 0).
        # tf.reverse_sequence needs explicit lengths; without them the whole
        # axis is reversed, which is equivalent when all sequences are full.
        if mask is None:
            return tf.reverse(x, axis=[0])
        return tf.reverse_sequence(x, mask, seq_axis=0, batch_axis=1)

    with tf.variable_scope("forward"):
        fw_cell = LSTMBlockFusedCell(self.n_units, use_peephole=self.use_peepholes)
        # Element 0 of the fused cell's result holds the per-step outputs.
        fw = fw_cell(time_major, dtype=tf.float32, sequence_length=mask)[0]

    with tf.variable_scope("backward"):
        bw_cell = LSTMBlockFusedCell(self.n_units, use_peephole=self.use_peepholes)
        # Reverse only the valid prefix of each sequence, leaving padding at
        # the end. Run the cell over the result, then reverse its outputs
        # back so they line up step-for-step with the forward outputs.
        reversed_inputs = _reverse_time(time_major)
        bw = bw_cell(reversed_inputs, dtype=tf.float32, sequence_length=mask)[0]
        bw = _reverse_time(bw)

    out = tf.concat([fw, bw], axis=2)
    return tf.transpose(out, [1, 0, 2])  # back to batch first
评论列表
文章目录