LSTMEncDecAttn.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:mlpnlp-nmt 作者: mlpnlp 项目源码 文件源码
def prepareDecoder(self, encInfo):
        self.model.decLSTM.reset_state()
        if self.attn_mode == 0:
            aList = None
        elif self.attn_mode == 1:
            aList = encInfo.attnList
        elif self.attn_mode == 2:
            aList = self.model.attnM(
                chaFunc.reshape(encInfo.attnList,
                                (encInfo.cMBSize * encInfo.encLen, self.hDim)))
            # TODO: ???????encoder???????
        else:
            assert 0, "ERROR"
        xp = cuda.get_array_module(encInfo.lstmVars[0].data)
        finalHS = chainer.Variable(
            xp.zeros(
                encInfo.lstmVars[0].data.shape,
                dtype=xp.float32))  # ???input_feed?0????
        return aList, finalHS

    ############################
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号