LSTMEncDecAttn.py source code


Project: mlpnlp-nmt · Author: mlpnlp · Language: Python
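This excerpt implements the attention step of the decoder: it scores every
encoder position against the current decoder hidden state h1 (with either a
bilinear or an MLP score function), normalizes the scores with a softmax,
forms a context vector as the attention-weighted sum of the encoder states,
and merges it with h1 through a tanh output layer. In the full file, chaFunc
is the alias under which chainer.functions is imported.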
def calcAttention(self, h1, hList, aList, encLen, cMBSize, args):
        # If attention is not used, simply return the hidden state h1
        if self.attn_mode == 0:
            return h1
        # 1. Prepare for the attention computation
        target1 = self.model.attnIn_L1(h1)  # first apply a linear transformation
        # (cMBSize, self.hDim) => (cMBSize, 1, self.hDim)
        target2 = chaFunc.expand_dims(target1, axis=1)
        # (cMBSize, 1, self.hDim) => (cMBSize, encLen, self.hDim)
        target3 = chaFunc.broadcast_to(target2, (cMBSize, encLen, self.hDim))
        # target3 = chaFunc.broadcast_to(chaFunc.reshape(
        #    target1, (cMBSize, 1, self.hDim)), (cMBSize, encLen, self.hDim))
        # 2. Compute the attention scores, depending on the attention type
        if self.attn_mode == 1:  # bilinear
            # for bilinear attention, hList1 == hList2 (the same tensor)
            # shape: (cMBSize, encLen)
            aval = chaFunc.sum(target3 * aList, axis=2)
        elif self.attn_mode == 2:  # MLP
            # reshape so that the result can be passed through attnSum
            t1 = chaFunc.reshape(target3, (cMBSize * encLen, self.hDim))
            # (cMBSize*encLen, self.hDim) => (cMBSize*encLen, 1)
            t2 = self.model.attnSum(chaFunc.tanh(t1 + aList))
            # shape: (cMBSize, encLen)
            aval = chaFunc.reshape(t2, (cMBSize, encLen))
            # aval = chaFunc.reshape(self.model.attnSum(
            #    chaFunc.tanh(t1 + aList)), (cMBSize, encLen))
        else:
            assert 0, "ERROR"
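        # For reference (notation assumed here, not from the original file),
        # the two score functions above correspond to:
        #   bilinear: score(h_t, s) = (W h_t) . a_s, with a_s the
        #             encoder-side vector stored in aList
        #   MLP:      score(h_t, s) = v^T tanh(W h_t + a_s), where v is the
        #             weight of the 1-unit output layer attnSum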
        # 3. Normalize the scores with a softmax
        cAttn1 = chaFunc.softmax(aval)   # (cMBSize, encLen)
        # 4. Build the context vector from the attention weights
        # (cMBSize, encLen) => (cMBSize, 1, encLen)
        cAttn2 = chaFunc.expand_dims(cAttn1, axis=1)
        # multiply (1, encLen) x (encLen, hDim) (a matmul), repeated for each
        # of the cMBSize batch elements => (cMBSize, 1, hDim)
        cAttn3 = chaFunc.batch_matmul(cAttn2, hList)
        # cAttn3 = chaFunc.batch_matmul(chaFunc.reshape(
        #    cAttn1, (cMBSize, 1, encLen)), hList)
        # drop the singleton dimension at axis=1
        context = chaFunc.reshape(cAttn3, (cMBSize, self.hDim))
        # 4 (alternative). The same context vector can also be built from the
        # attention weights as follows:
        # (cMBSize, encLen) => (cMBSize, encLen, 1)
        # cAttn2 = chaFunc.reshape(cAttn1, (cMBSize, encLen, 1))
        # (cMBSize, encLen, 1) => (cMBSize, encLen, self.hDim)
        # cAttn3 = chaFunc.broadcast_to(cAttn2, (cMBSize, encLen, self.hDim))
        # weighted sum over encoder positions: (cMBSize, encLen, hDim)
        #     => (cMBSize, hDim)  # axis=1 is summed out
        # context = chaFunc.sum(aList * cAttn3, axis=1)
        # 5. Compute the final hidden state when attention is used:
        #    tanh applied to a linear transform of [h1; context]
        c1 = chaFunc.concat((h1, context))
        c2 = self.model.attnOut_L2(c1)
        finalH = chaFunc.tanh(c2)
        # finalH = chaFunc.tanh(self.model.attnOut_L2(
        #    chaFunc.concat((h1, context))))
        return finalH  # context
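
The shape bookkeeping above is easier to follow outside of Chainer. Below is
a minimal NumPy sketch of the bilinear path on toy tensors; the sizes, the
random stand-in W for attnIn_L1, and all variable names here are
illustrative assumptions, not part of the original file. It also checks that
the batched-matmul route and the commented-out broadcast-and-sum route yield
the same context vector.

import numpy as np

cMBSize, encLen, hDim = 2, 5, 4                 # toy sizes
h1 = np.random.randn(cMBSize, hDim)             # decoder hidden state
hList = np.random.randn(cMBSize, encLen, hDim)  # encoder hidden states
W = np.random.randn(hDim, hDim)                 # stands in for attnIn_L1
aList = hList                                   # bilinear case: score against encoder states

# score: transform h1, broadcast it over encoder positions, dot per position
target1 = h1 @ W                                                          # (cMBSize, hDim)
target3 = np.broadcast_to(target1[:, None, :], (cMBSize, encLen, hDim))  # (cMBSize, encLen, hDim)
aval = (target3 * aList).sum(axis=2)                                     # (cMBSize, encLen)

# softmax over encoder positions
e = np.exp(aval - aval.max(axis=1, keepdims=True))
cAttn1 = e / e.sum(axis=1, keepdims=True)       # (cMBSize, encLen)

# context vector, route 1: batched matmul
# (cMBSize, 1, encLen) @ (cMBSize, encLen, hDim) => (cMBSize, 1, hDim)
context_mm = (cAttn1[:, None, :] @ hList).reshape(cMBSize, hDim)

# context vector, route 2: broadcast the weights and sum over axis=1
context_sum = (cAttn1[:, :, None] * hList).sum(axis=1)

assert np.allclose(context_mm, context_sum)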
