model_sepEmbSepTags.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:seq_tagger 作者: OSU-slatelab 项目源码 文件源码
def fwdUttEnc(module, rnnOut):
    """Run the utterance encoder over RNN outputs to get one vector per utterance.

    The encoding strategy is chosen by ``SharedModel.args.utt_enc_type``:

    * ``0`` -- sum over dim 0 of ``rnnOut``.
    * ``1`` -- mean over dim 0 of ``rnnOut``.
    * anything else -- apply every convolution in ``module.uttEncoder`` to
      ``rnnOut.permute(1, 2, 0)``, pool each result over its full last
      dimension (average pooling for ``2``, max pooling otherwise),
      concatenate the pooled outputs along dim 1 and apply ``tanh``.

    Optionally adds Gaussian noise (``utt_enc_noise``) and applies
    ``module.uttBn`` (``utt_enc_bn``) afterwards.

    Args:
        module: Object providing ``uttEncoder`` (iterable of callables) and,
            when the corresponding flags are on, ``uttEncNoise`` and ``uttBn``.
        rnnOut: RNN output tensor. Assumed to be (seq_len, batch, hidden),
            as implied by the ``permute(1, 2, 0)`` feeding 1-D pooling --
            TODO confirm against the caller.

    Returns:
        The utterance encoding tensor.
    """
    args = SharedModel.args
    encType = args.utt_enc_type

    if encType == 0 or encType == 1:
        # Encoding by summation (0) or by mean (1) over dim 0.
        reduced = rnnOut.sum(0) if encType == 0 else rnnOut.mean(0)
        # Legacy PyTorch keeps the reduced dimension while newer versions drop
        # it.  Only squeeze when it was kept; an unconditional squeeze(0) would
        # wrongly collapse the next dimension whenever its size is 1.
        if reduced.dim() == rnnOut.dim():
            reduced = reduced.squeeze(0)
        uttEncOut = reduced
    else:
        # Encoding by CNN.  The permuted input is identical for every
        # convolution, so it is computed once outside the loop.
        convInput = rnnOut.permute(1, 2, 0)
        # Average pooling for type 2, max pooling for every other value.
        poolFn = F.avg_pool1d if encType == 2 else F.max_pool1d
        pooledOuts = []
        for curConv in module.uttEncoder:
            curConvOut = curConv(convInput)
            # Kernel spans the whole last dimension, i.e. global pooling.
            pooledOuts.append(poolFn(curConvOut, curConvOut.size(2)))
        uttEncOut = torch.cat(pooledOuts, 1).squeeze(2).tanh()

    # The flag comparisons are kept as equality tests (not truthiness) so that
    # non-bool flag values behave exactly as before.
    if args.utt_enc_noise == True:  # noqa: E712
        # Refill the noise buffer with N(0, 0.1) samples shaped like the encoding.
        module.uttEncNoise.data.resize_(uttEncOut.size()).normal_(0, 0.1)
        # Out-of-place add: tanh's output is needed by autograd for the
        # backward pass, so mutating it in place would invalidate the graph.
        # NOTE(review): noise is added whenever the flag is set; this function
        # does not distinguish training from evaluation -- confirm intended.
        uttEncOut = uttEncOut + module.uttEncNoise

    if args.utt_enc_bn == True:  # noqa: E712
        uttEncOut = module.uttBn(uttEncOut)
    return uttEncOut
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号