def forward(self, img, sent, imgc, sentc):
# imgc : (bsize, ncontrast, imgdim)
# sentc : (bsize, ncontrast, sentdim)
# img : (bsize, imgdim)
# sent : (bsize, sentdim)
img = img.unsqueeze(1).expand_as(imgc).contiguous()
img = img.view(-1, self.imgdim)
imgc = imgc.view(-1, self.imgdim)
sent = sent.unsqueeze(1).expand_as(sentc).contiguous()
sent = sent.view(-1, self.sentdim)
sentc = sentc.view(-1, self.sentdim)
imgproj = self.imgproj(img)
imgproj = imgproj / torch.sqrt(torch.pow(imgproj, 2).sum(1, keepdim=True)).expand_as(imgproj)
imgcproj = self.imgproj(imgc)
imgcproj = imgcproj / torch.sqrt(torch.pow(imgcproj, 2).sum(1, keepdim=True)).expand_as(imgcproj)
sentproj = self.sentproj(sent)
sentproj = sentproj / torch.sqrt(torch.pow(sentproj, 2).sum(1, keepdim=True)).expand_as(sentproj)
sentcproj = self.sentproj(sentc)
sentcproj = sentcproj / torch.sqrt(torch.pow(sentcproj, 2).sum(1, keepdim=True)).expand_as(sentcproj)
# (bsize*ncontrast, projdim)
anchor1 = torch.sum((imgproj*sentproj), 1)
anchor2 = torch.sum((sentproj*imgproj), 1)
img_sentc = torch.sum((imgproj*sentcproj), 1)
sent_imgc = torch.sum((sentproj*imgcproj), 1)
# (bsize*ncontrast)
return anchor1, anchor2, img_sentc, sent_imgc
评论列表
文章目录