def setUttEncoder(module):
    """Attach the configured utterance-encoder pieces to ``module``.

    Everything is driven by ``SharedModel.args``; up to three attributes are
    set on ``module``:

    * ``uttEncNoise`` -- when ``utt_enc_noise`` is on: an empty
      ``FloatTensor`` wrapped in a volatile ``Variable``, moved to CUDA
      unless ``no_cuda`` is set.
    * ``uttEncoder`` -- when ``utt_enc_type >= 2``: an ``nn.ModuleList``
      holding one ``nn.Conv1d`` per kernel width listed in the
      underscore-separated ``conv_filters`` string.
    * ``uttBn`` -- when ``utt_enc_bn`` is on: an ``nn.BatchNorm1d`` whose
      feature count depends on ``utt_enc_type`` and ``attn``.

    Args:
        module: Object that receives the attributes above.
    """
    args = SharedModel.args

    # Flags are compared against True/False explicitly (not by truthiness).
    if args.utt_enc_noise == True:
        noise = Variable(torch.FloatTensor(), volatile=True)
        # NOTE(review): the CUDA move is taken to belong under the noise
        # switch, since it only touches the noise buffer -- confirm.
        if args.no_cuda == False:
            noise = noise.cuda()
        module.uttEncNoise = noise

    if args.utt_enc_type >= 2:
        module.uttEncoder = nn.ModuleList()
        kernel_widths = [int(token) for token in args.conv_filters.split('_')]
        # Input channels are 2 * hid_dim, doubled again when attn == 2.
        in_channels = 2 * args.hid_dim * (2 if args.attn == 2 else 1)
        for width in kernel_widths:
            padding = int(math.ceil((width - 1) / 2))
            # Positional Conv1d args: in, out, kernel_size, stride=1, padding.
            conv = nn.Conv1d(in_channels, args.conv_out_dim, width, 1, padding)
            module.uttEncoder.append(conv)

    if args.utt_enc_bn == True:
        # The convolutional encoder takes precedence over the attn == 2 case.
        if args.utt_enc_type >= 2:
            bn_features = 3 * args.conv_out_dim
        elif args.attn == 2:
            bn_features = 4 * args.hid_dim
        else:
            bn_features = 2 * args.hid_dim
        module.uttBn = nn.BatchNorm1d(bn_features)
# (Web-page residue, not part of the code: "Comment list" / "Table of contents".)