model_sepEmbSepTags.py file source code

Language: python

Project: seq_tagger — Author: OSU-slatelab — project source code
import torch
import torch.nn.functional as F


def getAttnOutput(input, attnScorer, winSize=0):
    """Get attention output following [Liu and Lane, Interspeech 2016].

    `input` has shape seqLen x batchSize x dim. If winSize is 0, all time
    steps are used for the weighted averaging; otherwise only steps within
    winSize of the current position are considered.
    """
    attnSeq = []
    for i in range(input.size(0)):
        curSeq = []
        if i > 0:  # left context
            leftBegin = 0
            if winSize > 0:
                leftBegin = max(0, i - winSize)
            curSeq.append(input[leftBegin:i])
        # current step and right context
        rightEnd = input.size(0)
        if winSize > 0:
            rightEnd = min(i + winSize + 1, input.size(0))
        curSeq.append(input[i:rightEnd])
        curSeq = torch.cat(curSeq, 0)
        cur = input[i:i + 1].expand_as(curSeq)

        # score each (current step, context step) pair
        attnScores = attnScorer(torch.cat([cur, curSeq], 2).view(-1, 2 * input.size(2)))
        transAttnScores = attnScores.view(curSeq.size(0), input.size(1)).transpose(0, 1)  # batchSize x curSeqLen
        smOut = F.softmax(transAttnScores, dim=1).transpose(0, 1)  # curSeqLen x batchSize
        smOutSeq = smOut.unsqueeze(2).expand_as(curSeq)
        weightedAvgSeq = (curSeq * smOutSeq).sum(0, keepdim=True)  # 1 x batchSize x dim
        attnSeq.append(weightedAvgSeq)
    attnSeq = torch.cat(attnSeq, 0)  # seqLen x batchSize x dim
    return torch.cat([input, attnSeq], 2)  # append each step's context vector to it
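To illustrate the shapes involved, here is a minimal self-contained sketch. The scorer is a hypothetical single `nn.Linear` layer mapping a concatenated (current, context) pair to one score, not the project's actual scorer, and the function body is repeated in compact form so the snippet runs on its own (assuming PyTorch >= 1.0):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def getAttnOutput(input, attnScorer, winSize=0):
    # compact copy of the windowed-attention function above
    attnSeq = []
    for i in range(input.size(0)):
        curSeq = []
        if i > 0:  # left context
            leftBegin = max(0, i - winSize) if winSize > 0 else 0
            curSeq.append(input[leftBegin:i])
        # current step and right context
        rightEnd = min(i + winSize + 1, input.size(0)) if winSize > 0 else input.size(0)
        curSeq.append(input[i:rightEnd])
        curSeq = torch.cat(curSeq, 0)
        cur = input[i:i + 1].expand_as(curSeq)
        attnScores = attnScorer(torch.cat([cur, curSeq], 2).view(-1, 2 * input.size(2)))
        transAttnScores = attnScores.view(curSeq.size(0), input.size(1)).transpose(0, 1)
        smOut = F.softmax(transAttnScores, dim=1).transpose(0, 1)
        weightedAvgSeq = (curSeq * smOut.unsqueeze(2).expand_as(curSeq)).sum(0, keepdim=True)
        attnSeq.append(weightedAvgSeq)
    return torch.cat([input, torch.cat(attnSeq, 0)], 2)


# hypothetical scorer: one linear layer from 2*dim features to a single score
seqLen, batchSize, dim = 5, 2, 4
attnScorer = nn.Linear(2 * dim, 1)

x = torch.randn(seqLen, batchSize, dim)
out = getAttnOutput(x, attnScorer, winSize=2)
print(out.shape)  # torch.Size([5, 2, 8]): original dim plus the attention context
```

The output doubles the feature dimension because each time step's weighted context vector is concatenated onto its original representation.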