FastText2.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:PyTorchText 作者: chenyuntc 项目源码 文件源码
def forward(self,title,content):
        title_em = self.encoder(title)
        content_em = self.encoder(content)
        title_size = title_em.size()
        content_size = content_em.size()

        title_2 = t.nn.functional.relu(self.bn(self.pre_fc(title_em.view(-1,256)).view(title_em.size(0),title_em.size(1),-1).transpose(1,2).contiguous()))
        content_2 = t.nn.functional.relu(self.bn2(self.pre_fc2(content_em.view(-1,256)).view(content_em.size(0),content_em.size(1),-1).transpose(1,2)).contiguous())

        # title_2 = self.pre(title_em.contiguous().view(-1,256)).view(title_size)
        # content_2 = self.pre(content_em.contiguous().view(-1,256)).view(content_size)


        title_ = t.mean(title_2,dim=2)
        content_ = t.mean(content_2,dim=2)
        inputs=t.cat((title_.squeeze(),content_.squeeze()),1)
        out=self.fc(inputs.view(inputs.size(0),-1))
        # content_out=self.content_fc(content.view(content.size(0),-1))
        # out=torch.cat((title_out,content_out),1)
        # out=self.fc(out)
        return out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号