def fwdUttEnc(module, rnnOut):
    """Run the utterance encoder over RNN outputs to get one utterance vector.

    The encoding strategy is selected by ``SharedModel.args.utt_enc_type``:

    * ``0`` -- sum of ``rnnOut`` over dimension 0.
    * ``1`` -- mean of ``rnnOut`` over dimension 0.
    * ``2`` -- every convolution in ``module.uttEncoder`` is applied to
      ``rnnOut.permute(1, 2, 0)``, average-pooled over its full output
      length, and the pooled results are concatenated along dimension 1
      and passed through ``tanh``.
    * any other value -- same as ``2`` but with max pooling.

    Afterwards, if ``SharedModel.args.utt_enc_noise`` is set, noise drawn
    from a normal distribution (mean 0, std 0.1) is added, and if
    ``SharedModel.args.utt_enc_bn`` is set the result goes through
    ``module.uttBn``.

    Args:
        module: Object providing ``uttEncoder`` (iterable of callables, used
            only by the CNN modes), ``uttEncNoise`` (noise buffer, resized
            and refilled in place on every call) and ``uttBn``.
        rnnOut: 3-D tensor of RNN outputs. Presumably laid out as
            (seq_len, batch, hidden) -- TODO confirm against the caller.

    Returns:
        The utterance encoding with dimension 0 of ``rnnOut`` reduced away.
    """
    encType = SharedModel.args.utt_enc_type

    if encType == 0:  # Encoding by summation
        # keepdim=True guarantees squeeze(0) removes the reduced dimension and
        # never a size-1 dimension that follows it (e.g. a batch of one).
        uttEncOut = rnnOut.sum(0, keepdim=True).squeeze(0)
    elif encType == 1:  # Encoding by mean
        uttEncOut = rnnOut.mean(0, keepdim=True).squeeze(0)
    else:  # Encoding by CNN
        # Both the pooling choice and the permuted input are loop-invariant.
        poolFn = F.avg_pool1d if encType == 2 else F.max_pool1d
        convInput = rnnOut.permute(1, 2, 0)

        pooledOuts = []
        for curConv in module.uttEncoder:
            convOut = curConv(convInput)
            # Kernel spans the whole last dimension, leaving length 1.
            pooledOuts.append(poolFn(convOut, convOut.size(2)))

        uttEncOut = torch.cat(pooledOuts, 1).squeeze(2)
        uttEncOut = torch.tanh(uttEncOut)

    # Explicit comparisons are kept so that non-bool config values behave
    # exactly as they did before.
    if SharedModel.args.utt_enc_noise == True:
        # Add white noise to the utterance encoding.
        module.uttEncNoise.data.resize_(uttEncOut.size()).normal_(0, 0.1)
        # Out-of-place add: tanh keeps its output for the backward pass, so
        # mutating that tensor in place would invalidate gradient computation.
        uttEncOut = uttEncOut + module.uttEncNoise

    if SharedModel.args.utt_enc_bn == True:
        uttEncOut = module.uttBn(uttEncOut)

    return uttEncOut
# (web-page residue, not code) Comment list
# (web-page residue, not code) Article table of contents