model.py 文件源码

python
阅读 14 收藏 0 点赞 0 评论 0

项目:DCNMT 作者: SwordYork 项目源码 文件源码
def __init__(self, vocab_size, embedding_dim, igru_state_dim, igru_depth, trg_dgru_depth, emitter,
                 feedback_brick, merge=None, merge_prototype=None, post_merge=None, **kwargs):
        """Construct the interpolator and register its sub-bricks as children.

        The following sub-bricks are created here:

        * a ``Merge`` brick, unless one is supplied;
        * a post-merge ``Bias``, unless one is supplied;
        * the IGRU: a bare ``IGRU`` when ``igru_depth == 1``, otherwise a
          ``RecurrentStack`` of one ``IGRU`` (named ``'igru'``) followed by
          ``igru_depth - 1`` ``UpperIGRU`` layers (named ``'upper_igruK'``)
          with ``skip_connections=True``;
        * an embedding ``LookupTable`` named ``'embeddings'``;
        * a ``Linear`` projection from ``igru_state_dim`` to ``vocab_size``;
        * a ``Fork`` named ``'gru_fork'`` that feeds the IGRU's sequence inputs.

        Parameters
        ----------
        vocab_size : int
            Output dimension of the ``gru_to_softmax`` projection.
        embedding_dim
            Only stored as ``self.embedding_dim`` here. The lookup table is
            created without dimensions, presumably configured elsewhere
            (verify).
        igru_state_dim : int
            State size of every IGRU layer. It is also used as the merged
            dimension and as the input size of ``gru_to_softmax``.
        igru_depth : int
            Number of IGRU layers.
        trg_dgru_depth
            Only stored as ``self.trg_dgru_depth`` here.
        emitter, feedback_brick
            Stored on the instance and registered as children.
        merge : optional
            Merge brick. If falsy, ``Merge(input_names=kwargs['source_names'],
            prototype=merge_prototype)`` is built instead.
        merge_prototype : optional
            Prototype for the default ``Merge``. It is ignored when ``merge``
            is given.
        post_merge : optional
            Post-merge brick. If falsy, ``Bias(dim=igru_state_dim)`` is used.
        **kwargs
            Passed on to the parent constructor. ``'source_names'`` is
            required when ``merge`` is falsy and is left in ``kwargs``. An
            existing ``'children'`` list is extended in place; otherwise a new
            one is created.

        Raises
        ------
        KeyError
            If ``merge`` is falsy and ``kwargs`` has no ``'source_names'``.
        """
        merged_dim = igru_state_dim

        # Default merge / post-merge bricks. ``x or default`` applies the same
        # falsiness test as ``if not x`` and only builds the default when needed.
        merge = merge or Merge(input_names=kwargs['source_names'],
                               prototype=merge_prototype)
        post_merge = post_merge or Bias(dim=merged_dim)

        # With a single layer, a bare IGRU is used instead of a one-element
        # RecurrentStack. The original author marked this as being for
        # compatibility, presumably with the parameter layout of earlier
        # single-layer models (unverified).
        if igru_depth == 1:
            igru = IGRU(dim=igru_state_dim)
        else:
            stacked_layers = [IGRU(dim=igru_state_dim, name='igru')]
            for layer_no in range(1, igru_depth):
                stacked_layers.append(UpperIGRU(dim=igru_state_dim, activation=Tanh(),
                                                name='upper_igru' + str(layer_no)))
            igru = RecurrentStack(stacked_layers, skip_connections=True)
        self.igru = igru

        # Hyper-parameters and caller-supplied (or defaulted) bricks.
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.igru_state_dim = igru_state_dim
        self.igru_depth = igru_depth
        self.trg_dgru_depth = trg_dgru_depth
        self.merged_dim = merged_dim
        self.emitter = emitter
        self.feedback_brick = feedback_brick
        self.merge = merge
        self.post_merge = post_merge

        # Sub-bricks created by the interpolator itself.
        self.lookup = LookupTable(name='embeddings')
        self.gru_to_softmax = Linear(input_dim=igru_state_dim, output_dim=vocab_size)
        # The fork targets every sequence input of the IGRU's ``apply`` except
        # 'mask' and 'input_states'. Those two are presumably supplied directly
        # by the caller (verify).
        skipped_inputs = ('mask', 'input_states')
        fork_outputs = [seq_name for seq_name in igru.apply.sequences
                        if seq_name not in skipped_inputs]
        self.gru_fork = Fork(fork_outputs, prototype=Linear(), name='gru_fork')

        # Register everything (in this fixed order) as children, then hand
        # the remaining keyword arguments to the parent constructor.
        owned_bricks = [self.emitter, self.feedback_brick, self.merge, self.post_merge,
                        self.igru, self.lookup, self.gru_to_softmax, self.gru_fork]
        child_list = kwargs.setdefault('children', [])
        child_list.extend(owned_bricks)
        super(Interpolator, self).__init__(**kwargs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号