def calcAttention(self, h1, hList, aList, encLen, cMBSize, args):
    """Compute the attention-merged decoder state for one time step.

    Args:
        h1: decoder hidden state, shape (cMBSize, hDim).
        hList: encoder hidden states, shape (cMBSize, encLen, hDim).
        aList: precomputed attention keys; shape depends on attn_mode
            (bilinear: (cMBSize, encLen, hDim); MLP: (cMBSize*encLen, hDim)
            -- assumed from the broadcasting below, confirm against caller).
        encLen: source-side sequence length.
        cMBSize: current minibatch size.
        args: unused here; kept for interface compatibility.

    Returns:
        The attention-combined state of shape (cMBSize, hDim), or h1
        unchanged when attention is disabled (attn_mode == 0).
    """
    if self.attn_mode == 0:
        # Attention disabled: pass the decoder state straight through.
        return h1

    # 1) Project the decoder state and tile it over source positions:
    #    (cMBSize, hDim) -> (cMBSize, 1, hDim) -> (cMBSize, encLen, hDim)
    projected = self.model.attnIn_L1(h1)
    tiled = chaFunc.broadcast_to(
        chaFunc.expand_dims(projected, axis=1),
        (cMBSize, encLen, self.hDim))

    # 2) Unnormalized attention scores, shape (cMBSize, encLen).
    if self.attn_mode == 1:
        # Bilinear: dot product along the hidden dimension.
        scores = chaFunc.sum(tiled * aList, axis=2)
    elif self.attn_mode == 2:
        # MLP: tanh(query + key) fed through a scalar-output linear layer.
        flat = chaFunc.reshape(tiled, (cMBSize * encLen, self.hDim))
        scores = chaFunc.reshape(
            self.model.attnSum(chaFunc.tanh(flat + aList)),
            (cMBSize, encLen))
    else:
        assert 0, "ERROR"

    # 3) Normalize the scores into attention weights.
    attnProb = chaFunc.softmax(scores)  # (cMBSize, encLen)

    # 4) Weighted sum over encoder states via a batched matmul:
    #    per example (1, encLen) x (encLen, hDim) -> (1, hDim),
    #    then squeeze to (cMBSize, hDim).
    context = chaFunc.reshape(
        chaFunc.batch_matmul(chaFunc.expand_dims(attnProb, axis=1), hList),
        (cMBSize, self.hDim))

    # 5) Merge the context vector with the decoder state and squash.
    merged = self.model.attnOut_L2(chaFunc.concat((h1, context)))
    return chaFunc.tanh(merged)
# (end of calcAttention)