def forward(self, img, qst):
    """Run the relation network on a batch of images and questions.

    The convolutional feature map is split into ``d * d`` cells
    ("objects"). Every ordered pair of objects is concatenated with the
    question vector, passed through the shared ``g`` layers, summed over
    all pairs and finally passed through the ``f`` layers.

    Args:
        img: Input handed to ``self.conv``. The result must be a 4-D
            tensor ``(mb, n_channels, d, d)`` with a square spatial map.
        qst: 2-D question tensor ``(mb, question_dim)``.

    Returns:
        The output of ``self.fcout``. The batch dimension is always
        kept, including when ``mb == 1``.
    """
    x = self.conv(img)  # (mb, n_channels, d, d), e.g. 64 x 24 x 5 x 5

    # ---- g: relation function applied to every pair of objects ----
    mb = x.size(0)
    n_channels = x.size(1)
    d = x.size(2)
    # Derive the object count from the feature map instead of assuming 25.
    n_objects = d * d
    n_pairs = n_objects * n_objects

    # One row per spatial cell: (mb, n_objects, n_channels).
    x_flat = x.view(mb, n_channels, n_objects).permute(0, 2, 1)

    # Add coordinates. torch.cat needs matching batch sizes, so slice the
    # pre-built coordinate tensor; this lets a smaller final batch through.
    coords = self.coord_tensor[:mb]
    x_flat = torch.cat([x_flat, coords], 2)

    # Add the question everywhere: (mb, n_objects, 1, question_dim).
    qst = torch.unsqueeze(qst, 1)
    qst = qst.repeat(1, n_objects, 1)
    qst = torch.unsqueeze(qst, 2)

    # Cast all pairs against each other.
    x_i = torch.unsqueeze(x_flat, 1)          # (mb, 1, n_objects, feat)
    x_i = x_i.repeat(1, n_objects, 1, 1)      # (mb, n_objects, n_objects, feat)
    x_j = torch.unsqueeze(x_flat, 2)          # (mb, n_objects, 1, feat)
    x_j = torch.cat([x_j, qst], 3)            # question rides along with x_j
    x_j = x_j.repeat(1, 1, n_objects, 1)      # (mb, n_objects, n_objects, feat + q)

    # Concatenate all together: (mb, n_objects, n_objects, 2 * feat + q).
    x_full = torch.cat([x_i, x_j], 3)

    # Flatten the pairs into rows; the width is read from the tensor
    # rather than hard-coded to 63.
    x_ = x_full.view(mb * n_pairs, x_full.size(3))
    x_ = self.g_fc1(x_)
    x_ = F.relu(x_)
    x_ = self.g_fc2(x_)
    x_ = F.relu(x_)
    x_ = self.g_fc3(x_)
    x_ = F.relu(x_)
    x_ = self.g_fc4(x_)
    x_ = F.relu(x_)

    # Restore the pair axis and sum over it. An explicit view replaces the
    # bare squeeze(), which would also drop the batch axis when mb == 1.
    g_width = x_.size(1)
    x_g = x_.view(mb, n_pairs, g_width)
    x_g = x_g.sum(1).view(mb, g_width)

    # ---- f: reasoning over the aggregated relation vector ----
    x_f = self.f_fc1(x_g)
    x_f = F.relu(x_f)
    return self.fcout(x_f)
# (web page residue) Comment list
# (web page residue) Table of contents