encoder.py source code

python

Project: sgnmt  Author: ucam-smt
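from theano import tensor
from toolz import merge


def apply(self, source_sentence, source_sentence_mask):
    """Produce source annotations, either non-recurrently (the embeddings
    themselves) or with a stack of bidirectional RNN layers.

    Inputs are batch-major (batch, time); the returned annotations and
    mask are time-major.
    """
    # Time as first dimension, since the recurrent bricks scan over axis 0
    source_sentence = source_sentence.T
    source_sentence_mask = source_sentence_mask.T

    # (time, batch, embedding_dim)
    embeddings = self.lookup.apply(source_sentence)

    if self.n_layers >= 1:
        # First layer: the forks project the embeddings into the inputs of
        # the forward and backward RNNs; both directions receive the mask.
        representation = self.bidir.apply(
            merge(self.fwd_fork.apply(embeddings, as_dict=True),
                  {'mask': source_sentence_mask}),
            merge(self.back_fork.apply(embeddings, as_dict=True),
                  {'mask': source_sentence_mask})
        )
        # Remaining layers read the previous layer's output, optionally
        # concatenated with the embeddings along the feature axis (skip
        # connections). self.bidir and the mid forks are applied again here,
        # so in Blocks these layers share their parameters.
        for _ in range(self.n_layers - 1):
            if self.skip_connections:
                inp = tensor.concatenate([representation, embeddings],
                                         axis=2)
            else:
                inp = representation
            representation = self.bidir.apply(
                merge(self.mid_fwd_fork.apply(inp, as_dict=True),
                      {'mask': source_sentence_mask}),
                merge(self.mid_back_fork.apply(inp, as_dict=True),
                      {'mask': source_sentence_mask})
            )
    else:
        # No recurrent layers: the annotations are the embeddings
        representation = embeddings
    return representation, source_sentence_mask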
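The method above builds a stack of bidirectional RNN layers over the word embeddings. The NumPy sketch below reproduces the same data flow so that the shapes and the masking can be run and inspected outside Theano/Blocks. It is an illustration under stated assumptions, not the sgnmt implementation: a plain tanh RNN stands in for whatever transition the original bricks use, each fork is modeled as a single linear projection, padded steps (mask 0) are assumed to carry the previous state forward, and the names simple_rnn, fork, bidir, encode, init_params and the params keys are all hypothetical. As in the original, the forward and backward recurrent matrices are reused by every layer, and the mid forks are reused by every layer after the first.

```python
import numpy as np

# Hypothetical NumPy stand-in for the Blocks encoder above; not sgnmt code.


def simple_rnn(inputs, mask, U, reverse=False):
    """Tanh RNN over time-major inputs (time, batch, hid), assumed already
    projected by a fork. Where mask == 0 the previous state is kept."""
    steps = range(inputs.shape[0])
    if reverse:
        steps = reversed(steps)  # backward direction: scan from the end
    h = np.zeros((inputs.shape[1], U.shape[0]))
    outputs = np.zeros_like(inputs)
    for t in steps:
        new_h = np.tanh(inputs[t] + h @ U)
        m = mask[t][:, None]              # (batch, 1) broadcasts over hid
        h = m * new_h + (1.0 - m) * h     # padded steps carry the state over
        outputs[t] = h                    # stored at t, so no re-reversal needed
    return outputs


def fork(x, W, b):
    """Linear projection (time, batch, in_dim) -> (time, batch, hid)."""
    return x @ W + b


def bidir(fwd_inputs, bwd_inputs, mask, params):
    """Run both directions and concatenate along the feature axis."""
    fwd = simple_rnn(fwd_inputs, mask, params["U_fwd"])
    bwd = simple_rnn(bwd_inputs, mask, params["U_bwd"], reverse=True)
    return np.concatenate([fwd, bwd], axis=2)


def encode(source_sentence, source_sentence_mask, params, n_layers,
           skip_connections):
    # Time as first dimension, as in the original method
    source_sentence = source_sentence.T
    mask = source_sentence_mask.T
    embeddings = params["E"][source_sentence]  # (time, batch, emb)
    if n_layers < 1:
        return embeddings, mask
    representation = bidir(fork(embeddings, *params["fwd_fork"]),
                           fork(embeddings, *params["back_fork"]),
                           mask, params)
    for _ in range(n_layers - 1):
        if skip_connections:
            inp = np.concatenate([representation, embeddings], axis=2)
        else:
            inp = representation
        representation = bidir(fork(inp, *params["mid_fwd_fork"]),
                               fork(inp, *params["mid_back_fork"]),
                               mask, params)
    return representation, mask


def init_params(vocab, emb, hid, skip_connections, seed=0):
    rng = np.random.default_rng(seed)
    # Mid layers see 2*hid features, plus emb more with skip connections
    mid_in = 2 * hid + (emb if skip_connections else 0)

    def lin(n_in):
        return (rng.normal(scale=0.1, size=(n_in, hid)), np.zeros(hid))

    return {
        "E": rng.normal(scale=0.1, size=(vocab, emb)),
        "fwd_fork": lin(emb), "back_fork": lin(emb),
        "mid_fwd_fork": lin(mid_in), "mid_back_fork": lin(mid_in),
        "U_fwd": rng.normal(scale=0.1, size=(hid, hid)),
        "U_bwd": rng.normal(scale=0.1, size=(hid, hid)),
    }


# Usage: batch of 2 sentences, 4 time steps, id 0 used as padding
ids = np.array([[3, 5, 7, 0], [2, 4, 0, 0]])
mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=float)
params = init_params(vocab=10, emb=4, hid=3, skip_connections=True)
rep, m = encode(ids, mask, params, n_layers=2, skip_connections=True)
print(rep.shape)  # (4, 2, 6): time, batch, 2 * hid
```

With skip_connections=True the mid forks must accept 2 * hid + emb input features, which is why init_params sizes them from both; without skip connections they take 2 * hid, matching the two branches of the loop in the original method.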