def apply(self, is_train, inputs, mask=None):
    """Run a bidirectional RNN over `inputs` and combine the two directions.

    The forward cell comes from `self.fw`; the backward cell from `self.bw`,
    falling back to the forward spec when `self.bw` is None. When
    `self.merge` is None the forward/backward outputs are concatenated along
    axis 2 (presumably the feature axis of batch-major output — TODO confirm);
    otherwise `self.merge.apply` combines them.

    Args:
        is_train: training-mode flag passed through to the cell specs and merger.
        inputs: sequence tensor fed to `bidirectional_dynamic_rnn`.
        mask: optional sequence-length mask forwarded to the RNN.

    Returns:
        A single tensor: either the axis-2 concatenation of the two
        directions' outputs, or whatever `self.merge.apply` produces.
    """
    forward_cell = self.fw(is_train)
    # Reuse the forward spec for the backward direction unless one was given.
    backward_spec = self.bw if self.bw is not None else self.fw
    backward_cell = backward_spec(is_train)

    # [0] selects the (fw, bw) output pair; the final states are discarded.
    outputs = bidirectional_dynamic_rnn(
        forward_cell, backward_cell, inputs, mask,
        swap_memory=self.swap_memory, dtype=tf.float32)[0]

    if self.merge is None:
        return tf.concat(outputs, 2)

    fw_out, bw_out = outputs
    return self.merge.apply(is_train, fw_out, bw_out)  # TODO this should be in a different scope
# NOTE(review): the two lines below are web-scrape residue from the page this
# file was copied from, not Python code ("评论列表" = "comment list",
# "文章目录" = "article table of contents") — they should be removed entirely.
# 评论列表
# 文章目录