baselines.py 文件源码

python
阅读 48 收藏 0 点赞 0 评论 0

项目:clevr-iep 作者: facebookresearch 项目源码 文件源码
def forward(self, v, u):
    """
    Input:
    - v: N x D x H x W
    - u: N x D

    Returns:
    - next_u: N x D
    """
    N, K = v.size(0), self.hidden_dim
    D, H, W = v.size(1), v.size(2), v.size(3)
    v_proj = self.Wv(v) # N x K x H x W
    u_proj = self.Wu(u) # N x K
    u_proj_expand = u_proj.view(N, K, 1, 1).expand(N, K, H, W)
    h = F.tanh(v_proj + u_proj_expand)
    p = F.softmax(self.Wp(h).view(N, H * W)).view(N, 1, H, W)
    self.attention_maps = p.data.clone()

    v_tilde = (p.expand_as(v) * v).sum(2).sum(3).view(N, D)
    next_u = u + v_tilde
    return next_u
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号