def forward(self, x):
    """Run a CNN-over-text encoder bank and return class log-probabilities.

    Embeds the token ids, applies each convolutional encoder followed by
    ReLU and max-over-time pooling, concatenates the pooled features,
    applies dropout, and classifies with a linear layer + log-softmax.

    Args:
        x: integer tensor of token ids; assumed shape ``(batch, seq_len)``
           — TODO confirm against the caller.

    Returns:
        Tensor of log-probabilities, shape ``(batch, n_classes)``.
    """
    # Dimension indices for the 4-D conv tensors, laid out (N, C, H, W).
    c_idx = 1
    h_idx = 2
    w_idx = 3

    x = self.lookup_table(x)   # token ids -> embeddings: (N, seq, embed)
    x = x.unsqueeze(c_idx)     # add a channel axis: (N, 1, seq, embed)

    enc_outs = []
    for encoder in self.encoders:
        enc_ = F.relu(encoder(x))
        # Max-over-time pooling: pool over the entire remaining height so
        # each feature map collapses to a single value per channel.
        k_h = enc_.size()[h_idx]
        enc_ = F.max_pool2d(enc_, kernel_size=(k_h, 1))
        enc_ = enc_.squeeze(w_idx)
        enc_ = enc_.squeeze(h_idx)  # -> (N, out_channels)
        enc_outs.append(enc_)

    # Concatenate the per-encoder feature vectors along the feature axis.
    encoding = self.dropout(torch.cat(enc_outs, 1))
    # FIX: log_softmax without an explicit `dim` is deprecated and its
    # implicit choice is ambiguous; normalize over the class dimension.
    return F.log_softmax(self.logistic(encoding), dim=1)
# NOTE(review): trailing blog-page navigation text, not code —
# "评论列表" = comment list, "文章目录" = article table of contents.