def forward(self, v, u):
    """Run one attention step: pool ``v`` under weights guided by ``u``.

    Every spatial position of ``v`` is scored against the query ``u``. The
    scores are softmax-normalised over all H * W positions, the weights pool
    ``v`` into a single D-vector per example, and that vector is added to
    ``u`` as a residual update.

    Input:
    - v: N x D x H x W
    - u: N x D
    Returns:
    - next_u: N x D

    Side effect: stores the attention weights (N x 1 x H x W, detached from
    the autograd graph) on ``self.attention_maps`` for later inspection.
    """
    N, K = v.size(0), self.hidden_dim
    H, W = v.size(2), v.size(3)
    v_proj = self.Wv(v)  # N x K x H x W
    u_proj = self.Wu(u)  # N x K
    # Broadcast the query projection across every spatial position.
    h = (v_proj + u_proj.view(N, K, 1, 1)).tanh()
    # NOTE(review): assumes self.Wp emits a single channel, so its output
    # flattens to one score per position (N x H*W) -- confirm in __init__.
    scores = self.Wp(h).view(N, H * W)
    # Pass dim explicitly; the implicit-dim form of softmax is deprecated.
    p = F.softmax(scores, dim=1).view(N, 1, H, W)
    self.attention_maps = p.detach().clone()
    # Reduce H and W together. Chaining .sum(2).sum(3) fails because sum()
    # drops dim 2 by default, leaving a 3-D tensor with no dim 3.
    v_tilde = (p * v).sum(dim=(2, 3))  # N x D
    next_u = u + v_tilde
    return next_u
# Stray page-navigation text from a web scrape ("Comment list", "Table of contents"); not part of the code.