def forward(self, X, S1, S2, config):
    """Run the iterative conv/max recurrence and read out logits at (S1, S2).

    NOTE(review): the structure matches a Value Iteration Network (VIN)
    forward pass (r -> q -> max -> repeated conv over [r, v]); the
    reward/value/Q reading is inferred from attribute names and should be
    confirmed against the model definition.

    Args:
        X: Input batch fed to ``self.h``. A 4-D NCHW tensor is assumed, since
            the result flows into ``F.conv2d`` and is gathered on dims 2/3.
        S1: Per-sample indices along dim 2 (rows) of the final ``q`` map.
            Cast to ``long`` before use, so float tensors are accepted.
        S2: Per-sample indices along dim 3 (columns) of the final ``q`` map.
        config: Object exposing ``k``, the number of max/conv rounds. Only
            ``k`` is read; the gather sizes are taken from ``q`` itself, so
            ``config.imsize`` / ``config.l_q`` are no longer required (and
            non-square maps are supported).

    Returns:
        Tuple ``(logits, self.sm(logits))`` where ``logits = self.fc(q_out)``
        and ``q_out[b] == q[b, :, S1[b], S2[b]]``.
    """
    r = self.r(self.h(X))

    # The stacked kernel does not change between rounds; build it once.
    weight = torch.cat([self.q.weight, self.w], 1)

    def _step(v):
        # One round: convolve [r, v] with [q.weight, w].
        # NOTE(review): padding=1 preserves spatial size only for a 3x3
        # kernel; confirm against how self.q / self.w are constructed.
        # NOTE(review): no bias is applied here, unlike self.q(r) below; this
        # is consistent only if self.q has bias=False (TODO confirm).
        return F.conv2d(torch.cat([r, v], 1), weight, stride=1, padding=1)

    q = self.q(r)
    v = torch.max(q, dim=1, keepdim=True)[0]
    for _ in range(config.k - 1):
        q = _step(v)
        v = torch.max(q, dim=1, keepdim=True)[0]
    q = _step(v)

    # Use q's real shape rather than config fields so the index tensors always
    # match the tensor being gathered from.
    batch, n_q, _, width = q.shape

    # Pick row S1[b] for each sample: index shape (batch, n_q, 1, width).
    rows = S1.long().expand(width, 1, n_q, batch).permute(3, 2, 1, 0)
    q_out = q.gather(2, rows).squeeze(2)  # (batch, n_q, width)

    # Pick column S2[b] for each sample: index shape (batch, n_q, 1).
    cols = S2.long().expand(1, n_q, batch).permute(2, 1, 0)
    q_out = q_out.gather(2, cols).squeeze(2)  # (batch, n_q)

    logits = self.fc(q_out)
    return logits, self.sm(logits)
model.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录