recurrent_layers.py source code

python

Project: document-qa  Author: allenai
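import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.contrib.rnn import LSTMBlockFusedCell


def apply(self, is_train, inputs, mask=None):
    # `mask` is used as per-example sequence lengths (it is passed as
    # `sequence_length` and to tf.reverse_sequence), so it must be supplied.
    inputs = tf.transpose(inputs, [1, 0, 2])  # to time first
    with tf.variable_scope("forward"):
        cell = LSTMBlockFusedCell(self.n_units, use_peephole=self.use_peepholes)
        fw = cell(inputs, dtype=tf.float32, sequence_length=mask)[0]  # outputs only
    with tf.variable_scope("backward"):
        cell = LSTMBlockFusedCell(self.n_units, use_peephole=self.use_peepholes)
        # Reverse only the valid prefix of each sequence so padding stays at the end.
        inputs = tf.reverse_sequence(inputs, mask, seq_axis=0, batch_axis=1)
        bw = cell(inputs, dtype=tf.float32, sequence_length=mask)[0]
        # Undo the reversal so backward outputs align with forward time steps.
        bw = tf.reverse_sequence(bw, mask, seq_axis=0, batch_axis=1)
    out = tf.concat([fw, bw], axis=2)  # concatenate along the feature axis
    out = tf.transpose(out, [1, 0, 2])  # back to batch first
    return out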
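
The backward pass above depends on reversing only the valid prefix of each padded sequence. The following is a minimal pure-Python sketch of that length-aware reversal on time-major nested lists. It is not TensorFlow's implementation; the helper name `reverse_sequence_time_major` and the toy batch are hypothetical and exist only for illustration.

```python
def reverse_sequence_time_major(inputs, lengths):
    """Reverse the first lengths[b] time steps of each batch column.

    inputs is a nested list indexed as inputs[t][b] (time-major).
    Padding positions (t >= lengths[b]) are left where they are.
    """
    out = [list(row) for row in inputs]  # copy so the input is not modified
    for b, n in enumerate(lengths):
        for t in range(n):
            out[t][b] = inputs[n - 1 - t][b]
    return out


# Two sequences padded to 4 time steps; 0 marks padding.
# Column 0 has length 3 (1, 2, 3); column 1 has length 2 (5, 6).
batch = [[1, 5], [2, 6], [3, 0], [0, 0]]
lengths = [3, 2]

reversed_batch = reverse_sequence_time_major(batch, lengths)
# Applying the same reversal again restores the original order, which is
# why the layer calls the reversal once before and once after the cell.
restored = reverse_sequence_time_major(reversed_batch, lengths)
```

Because padding stays at the end after the reversal, the backward cell can use the same sequence lengths as the forward cell, and the second reversal realigns its outputs with the forward outputs step by step.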