def forward(self, title, content):
    """Run the title/content model forward pass.

    Both inputs are encoded with the shared ``self.encoder``. Each token's
    encoding is then projected by ``self.pre1`` (title) or ``self.pre2``
    (content) and mean-pooled over dim 1 (the sequence axis). The two pooled
    vectors are concatenated and fed to ``self.fc``.

    Args:
        title: input accepted by ``self.encoder``; presumably a
            (batch, title_len) tensor of token ids -- TODO confirm.
        content: same as ``title``, for the content text.

    Returns:
        Output of ``self.fc`` applied to the
        (batch, title_features + content_features) concatenation.
    """
    title_em = self.encoder(title)
    content_em = self.encoder(content)

    def per_token(layer, em):
        # Fold (batch, seq) into a single axis so ``layer`` sees 2-D
        # (tokens, features) input, then restore (batch, seq, -1).
        # NOTE(review): presumably done because pre1/pre2 contain layers
        # (e.g. BatchNorm1d) that need 2-D input -- confirm.
        # The feature width comes from the tensor itself rather than a
        # hard-coded 256, so other encoder widths are not silently
        # regrouped across tokens.
        batch, seq = em.size(0), em.size(1)
        flat = em.flatten(2).flatten(0, 1)
        return layer(flat).view(batch, seq, -1)

    # The mean over dim=1 already yields (batch, features). Do NOT squeeze:
    # with batch size 1 that collapses to 1-D and t.cat(..., 1) fails.
    title_ = t.mean(per_token(self.pre1, title_em), dim=1)
    content_ = t.mean(per_token(self.pre2, content_em), dim=1)
    inputs = t.cat((title_, content_), 1)
    return self.fc(inputs)
评论列表
文章目录