def forward(self, img, txt_feat):
    """Fuse image features with spatially-tiled text features and classify them.

    Args:
        img: Input to ``self.encoder``. The encoded result must be a 4-D map
            ``(N, C, H, W)``: dims 2 and 3 are read, and it is concatenated on
            dim 1 with a 4-D text tensor.
        txt_feat: Input to ``self.compression``. The compressed result must be
            2-D ``(N, D)``, so two trailing unsqueezes make it ``(N, D, 1, 1)``.
            Its leading dim must match the image features' leading dim.

    Returns:
        The output of ``self.classifier`` with singleton dimensions removed,
        except that dim 0 is always kept. Dim 0 is assumed to be the batch
        dimension (TODO confirm). A batch of one therefore yields shape ``(1,)``
        rather than a 0-d scalar.
    """
    img_feat = self.encoder(img)
    # Residual connection followed by leaky ReLU with negative slope 0.2.
    img_feat = F.leaky_relu(img_feat + self.residual_branch(img_feat), 0.2)

    txt_feat = self.compression(txt_feat)
    # (N, D) -> (N, D, 1, 1) -> (N, D, H, W). expand() returns a broadcast
    # view without copying; torch.cat below materializes it exactly once.
    txt_feat = txt_feat[:, :, None, None].expand(
        -1, -1, img_feat.size(2), img_feat.size(3)
    )

    fusion = torch.cat((img_feat, txt_feat), dim=1)
    output = self.classifier(fusion)

    # Drop trailing singleton dims but never dim 0: a bare squeeze() would
    # also remove a size-1 batch dim and return a 0-d tensor.
    for dim in range(output.dim() - 1, 0, -1):
        if output.size(dim) == 1:
            output = output.squeeze(dim)
    return output
评论列表
文章目录