def forward(self, x):
    """Run the two-channel (static + non-static embedding) CNN classifier.

    Both embedding layers look up the same tokens. Their outputs are stacked
    as two input channels, passed through every convolution in
    ``self.convs1`` (optionally followed by ``self.convs1_bn``), ReLU'd,
    max-pooled over the sequence axis, concatenated and fed through the
    two-layer head ``fc1 -> ReLU -> fc2``.

    Args:
        x: Token-index tensor accepted by both embedding layers; presumably
            (N, W) LongTensor -- TODO confirm against the data loader.

    Returns:
        The output of ``self.fc2`` (logits); presumably (N, num_classes).
    """
    x_no_static = self.embed_no_static(x)
    # NOTE(review): this channel is only "static" if embed_static's weights
    # are frozen (requires_grad=False) where the layer is built -- confirm.
    x_static = self.embed_static(x)
    # Stack on dim 1 so the two embeddings become two conv input channels:
    # (N, 2, W, D) when each embedding output is (N, W, D).
    x = torch.stack([x_static, x_no_static], 1)
    x = self.dropout(x)

    # Only the literal value True enables batch normalization.
    use_bn = self.args.batch_normalizations is True
    pooled = []
    for conv in self.convs1:
        feat = conv(x)
        if use_bn:
            # NOTE(review): a single BN module is shared by every conv branch.
            feat = self.convs1_bn(feat)
        # Drop the size-1 last axis (assumes each kernel spans the full
        # embedding dim), giving (N, Co, W'), then max-pool over the whole
        # remaining sequence axis to get one value per filter: (N, Co).
        feat = F.relu(feat).squeeze(3)
        pooled.append(F.max_pool1d(feat, feat.size(2)).squeeze(2))
    x = torch.cat(pooled, 1)  # (N, len(convs1) * Co)

    x = self.dropout(x)
    # The classifier head is the same with or without batch normalization.
    logit = self.fc2(F.relu(self.fc1(x)))
    return logit
model_CNN_MUI.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录