LSTMEncDecAttn.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:mlpnlp-nmt 作者: mlpnlp 项目源码 文件源码
def encodeSentenceFWD(self, train_mode, sentence, args, dropout_rate):
        """Run the forward/backward LSTM encoder over one minibatch.

        Args:
            train_mode: Numeric mode flag; any value greater than 0 makes the
                encoder LSTMs run with their training flag set.
            sentence: Input minibatch. ``len(sentence)`` is used as the
                encoder length and ``len(sentence[0])`` as the minibatch
                size.
            args: Option object. ``gpu_enc`` and ``gpu_dec`` are read here and
                the object is forwarded to the embedding and LSTM calls.
            dropout_rate: Forwarded to the per-layer LSTM calls. It is only
                used when ``self.flag_merge_encfwbw == 1``.

        Returns:
            The result of ``self.encInfoObject(...)`` built from the summed
            forward/backward outputs (axes 0 and 1 swapped), the stacked
            per-layer cell/hidden states, the encoder length and the
            minibatch size.

        Raises:
            AssertionError: If ``self.flag_merge_encfwbw`` or
                ``self.flag_enc_boseos`` is neither 0 nor 1.
        """
        # Validate both mode flags before doing any work.  Explicit raises
        # are used instead of ``assert`` because assert statements are
        # removed under ``python -O``, which would let an unsupported value
        # fall through to an unrelated unbound-variable error further down.
        if self.flag_merge_encfwbw not in (0, 1):
            raise AssertionError(
                "ERROR: unsupported flag_merge_encfwbw=%r (expected 0 or 1)"
                % (self.flag_merge_encfwbw,))
        if self.flag_enc_boseos not in (0, 1):
            raise AssertionError(
                "ERROR: unsupported flag_enc_boseos=%r (expected 0 or 1)"
                % (self.flag_enc_boseos,))

        # Switch to the encoder device when encoder and decoder are
        # configured with different GPU ids.
        if args.gpu_enc != args.gpu_dec:
            chainer.cuda.get_device(args.gpu_enc).use()
        encLen = len(sentence)  # encoder length
        cMBSize = len(sentence[0])  # minibatch size

        # Embeddings for the whole input, fetched in a single call.
        encEmbList = self.getEncoderInputEmbeddings(sentence, args)

        flag_train = (train_mode > 0)
        # Two slots per layer: index 2*z holds the summed cell states and
        # index 2*z+1 the summed hidden states of layer z.
        lstmVars = [0] * self.n_layers * 2
        if self.flag_merge_encfwbw == 0:
            # Forward and backward encoders each run over all layers in one
            # call; their states are only summed afterwards.
            hyf, cyf, fwHout = self.model.encLSTM_f(
                None, None, encEmbList, flag_train, args)  # forward
            hyb, cyb, bkHout = self.model.encLSTM_b(
                None, None, encEmbList[::-1], flag_train, args)  # reversed
            for z in six.moves.range(self.n_layers):
                lstmVars[2 * z] = cyf[z] + cyb[z]
                lstmVars[2 * z + 1] = hyf[z] + hyb[z]
        else:
            # flag_merge_encfwbw == 1 (validated above): the forward and
            # backward outputs are summed after every layer and the sum is
            # fed to the next layer.
            sp = (cMBSize, self.hDim)
            for z in six.moves.range(self.n_layers):
                if z == 0:
                    # First layer consumes the embeddings.
                    biH = encEmbList
                else:
                    # Later layers consume the previous layer's output.
                    # bkHout is reversed again so that it lines up with
                    # fwHout before the addition.
                    biH = fwHout + bkHout[::-1]
                # Layer z, forward direction.
                hyf, cyf, fwHout = self.model.encLSTM_f(
                    z, biH, flag_train, dropout_rate, args)
                # Layer z, backward direction (input reversed).
                hyb, cyb, bkHout = self.model.encLSTM_b(
                    z, biH[::-1], flag_train, dropout_rate, args)
                # Keep this layer's summed cell/hidden states, reshaped to
                # (minibatch size, hDim).
                lstmVars[2 * z] = chaFunc.reshape(cyf + cyb, sp)
                lstmVars[2 * z + 1] = chaFunc.reshape(hyf + hyb, sp)

        # Combine the outputs of the last processed layer.
        if self.flag_enc_boseos == 0:  # default
            # NOTE(review): the original (garbled) comment appears to warn
            # that the ``[:, ]`` indexing on fwHout is required; it is kept
            # exactly as it was.
            biHiddenStack = fwHout[:, ] + bkHout[::-1]
        else:
            # flag_enc_boseos == 1 (validated above): drop the first and the
            # last position before summing.
            bkHout2 = bkHout[::-1]  # undo the reversal
            biHiddenStack = fwHout[1:encLen - 1, ] + bkHout2[1:encLen - 1, ]
            # The two dropped positions are presumably BOS/EOS.
            # TODO (carried over from the original, partly garbled comment):
            # inputs that end up with length 0 here may fail -- unverified.
            encLen -= 2
        # Swap axes 0 and 1; presumably (enc length, minibatch, hidden)
        #    => (minibatch, enc length, hidden) -- TODO confirm.
        biHiddenStackSW01 = chaFunc.swapaxes(biHiddenStack, 0, 1)
        # Stack the per-layer states into a single variable.
        lstmVars = chaFunc.stack(lstmVars)
        # Bundle everything into an encInfoObject and return it.
        retO = self.encInfoObject(biHiddenStackSW01, lstmVars, encLen, cMBSize)
        return retO
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号