import torch
import torch.nn.functional as F


def getAttnOutput(input, attnScorer, winSize=0):
    """Compute an attention-augmented output following [Liu and Lane, Interspeech 2016].

    input: seqLen X batchSize X dim tensor.
    attnScorer: module mapping (N, 2*dim) concatenated vectors to (N, 1) scores.
    winSize: if 0, all time steps are used for the weighted averaging;
             otherwise only steps within winSize of the current step are used.
    Returns a seqLen X batchSize X 2*dim tensor: the input concatenated with
    the attention-weighted context at every time step.
    """
    attnSeq = []
    for i in range(input.size(0)):
        # Select the context window around step i (the whole sequence if winSize == 0).
        begin, end = 0, input.size(0)
        if winSize > 0:
            begin = max(0, i - winSize)
            end = min(i + winSize + 1, input.size(0))
        curSeq = input[begin:end]  # curSeqLen X batchSize X dim
        # Pair the current step with every step in its window and score each pair.
        cur = input[i:i + 1].expand_as(curSeq)
        attnScores = attnScorer(torch.cat([cur, curSeq], 2).view(-1, 2 * input.size(2)))  # get attention scores
        transAttnScores = attnScores.view(curSeq.size(0), input.size(1)).transpose(0, 1)  # batchSize X curSeqLen
        # Normalize the scores over the window, then take the weighted average of the window.
        smOut = F.softmax(transAttnScores, dim=1).transpose(0, 1)
        smOutSeq = smOut.unsqueeze(2).expand_as(curSeq)
        weightedAvgSeq = (curSeq * smOutSeq).sum(0, keepdim=True)  # 1 X batchSize X dim
        attnSeq.append(weightedAvgSeq)
    attnSeq = torch.cat(attnSeq, 0)  # seqLen X batchSize X dim
    return torch.cat([input, attnSeq], 2)
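

# A minimal usage sketch (the names below are illustrative, not from the original
# post): attnScorer is assumed to be any module that maps a concatenated
# [current; context] vector of size 2*dim to a single score, e.g. a Linear layer.
if __name__ == "__main__":
    import torch.nn as nn

    seqLen, batchSize, dim = 5, 3, 8
    inputs = torch.randn(seqLen, batchSize, dim)
    scorer = nn.Linear(2 * dim, 1)  # assumed scorer: [h_i; h_j] -> scalar
    out = getAttnOutput(inputs, scorer, winSize=2)
    print(out.size())  # torch.Size([5, 3, 16]): input concatenated with its attention context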