ranking.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:SentEval 作者: facebookresearch 项目源码 文件源码
def forward(self, img, sent, imgc, sentc):
        # imgc : (bsize, ncontrast, imgdim)
        # sentc : (bsize, ncontrast, sentdim)
        # img : (bsize, imgdim)
        # sent : (bsize, sentdim)
        img = img.unsqueeze(1).expand_as(imgc).contiguous()
        img = img.view(-1, self.imgdim)
        imgc = imgc.view(-1, self.imgdim)
        sent = sent.unsqueeze(1).expand_as(sentc).contiguous()
        sent = sent.view(-1, self.sentdim)
        sentc = sentc.view(-1, self.sentdim)

        imgproj = self.imgproj(img)
        imgproj = imgproj / torch.sqrt(torch.pow(imgproj, 2).sum(1, keepdim=True)).expand_as(imgproj)
        imgcproj = self.imgproj(imgc)
        imgcproj = imgcproj / torch.sqrt(torch.pow(imgcproj, 2).sum(1, keepdim=True)).expand_as(imgcproj)
        sentproj = self.sentproj(sent)
        sentproj = sentproj / torch.sqrt(torch.pow(sentproj, 2).sum(1, keepdim=True)).expand_as(sentproj)
        sentcproj = self.sentproj(sentc)
        sentcproj = sentcproj / torch.sqrt(torch.pow(sentcproj, 2).sum(1, keepdim=True)).expand_as(sentcproj)
        # (bsize*ncontrast, projdim)

        anchor1 = torch.sum((imgproj*sentproj), 1)
        anchor2 = torch.sum((sentproj*imgproj), 1)
        img_sentc = torch.sum((imgproj*sentcproj), 1)
        sent_imgc = torch.sum((sentproj*imgcproj), 1)

        # (bsize*ncontrast)
        return anchor1, anchor2, img_sentc, sent_imgc
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号