def forward(self, x):
    """Run the multi-kernel text CNN on a batch of inputs.

    Pipeline: ``self.embed`` -> ``self.dropout_embed`` -> one convolution
    per module in ``self.convs1`` -> activation -> max-pooling over the
    whole of dim 2 -> concatenation -> ``self.dropout`` -> output head.

    When ``self.args.batch_normalizations`` is exactly ``True``:

    * each convolution output goes through ``torch.tanh`` and then the
      single shared ``self.convs1_bn`` layer;
    * the head is ``fc1 -> fc1_bn -> tanh -> fc2 -> fc2_bn``.

    Otherwise:

    * the convolutions use ``F.relu``;
    * the head is the single layer ``self.fc``.

    Args:
        x: Input accepted by ``self.embed``. The shape comments below
            assume a ``(N, W)`` batch of token indices (N = batch,
            W = sequence length). TODO confirm against the caller.

    Returns:
        The output of ``self.fc2_bn`` (batch-norm variant) or ``self.fc``.
    """
    # Deliberately `is True` rather than truthiness: only a real boolean
    # True enables the batch-norm variant (unchanged from the original).
    use_bn = self.args.batch_normalizations is True

    x = self.embed(x)  # (N, W, D)
    x = self.dropout_embed(x)
    x = x.unsqueeze(1)  # (N, Ci=1, W, D): size-1 channel dim for the convs

    pooled = []
    for conv in self.convs1:
        # squeeze(3) drops the trailing conv-output dim. Per the original
        # shape comments, the result is (N, Co, W').
        if use_bn:
            feats = self.convs1_bn(torch.tanh(conv(x))).squeeze(3)
        else:
            feats = F.relu(conv(x)).squeeze(3)
        # The pooling window spans all of dim 2, leaving one value per
        # channel: (N, Co).
        pooled.append(F.max_pool1d(feats, feats.size(2)).squeeze(2))

    x = torch.cat(pooled, 1)  # (N, len(Ks) * Co)
    x = self.dropout(x)

    if use_bn:
        x = self.fc1_bn(self.fc1(x))
        return self.fc2_bn(self.fc2(torch.tanh(x)))
    return self.fc(x)
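# --- Page metadata from the site this snippet was copied from (not code) ---
# model_CNN.py 文件源码  (model_CNN.py source file)
# python
# 阅读 26  (views: 26)
# 收藏 0  (favorites: 0)
# 点赞 0  (likes: 0)
# 评论 0  (comments: 0)
# 评论列表  (comment list)
# 文章目录  (table of contents)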