joint_model.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:age 作者: ly015 项目源码 文件源码
def _age_gradient(self, feat_age):
        '''
        compute age branch gradient direction in age_embed layer

        The age head (self.age_cls.relu followed by self.age_cls.cls) is
        re-run on a detached copy of feat_age, the predicted ages of the batch
        are summed, and the gradient of that sum w.r.t. the copy is returned
        with every row scaled to unit L2 norm.

        input:
            feat_age: output of age_embed layer (before relu). Rows are
                normalized along dim 1, so presumably (batch, feat_dim)
                -- TODO confirm against callers.

        output:
            age_grad: tensor shaped like feat_age, detached from the graph,
                each row of unit L2 norm. Rows whose gradient is all zero
                (e.g. every feature killed by the relu) stay zero instead
                of becoming NaN.

        raises:
            ValueError: if self.opts.cls_type is not 'dex', 'oh' or 'reg'.

        NOTE(review): the gradient is computed with autograd, so on
        PyTorch >= 0.4 this must not be called inside torch.no_grad().
        '''

        age_cls = self.age_cls
        cls_type = self.opts.cls_type

        # Fail fast; otherwise age_out below would be unbound.
        if cls_type not in ('dex', 'oh', 'reg'):
            raise ValueError('unknown cls_type: %r' % (cls_type,))

        # Fresh leaf, cut off from the age_embed graph, so that no gradient
        # flows back into the layers that produced feat_age.
        feat = feat_age.detach()
        feat.requires_grad = True
        feat.volatile = False  # legacy (PyTorch <= 0.3) volatile eval inputs

        # Feed a clone to the head: autograd rejects an in-place relu applied
        # directly to a leaf that requires grad. The gradient is still taken
        # w.r.t. the leaf, i.e. w.r.t. the pre-relu features.
        age_fc_out = age_cls.cls(age_cls.relu(feat.clone()))

        if cls_type == 'dex':
            # deep expectation: expected age under the softmax distribution
            # over the integer ages [min_age, max_age]
            age_scale = np.arange(self.opts.min_age, self.opts.max_age + 1, 1.0)
            age_scale = Variable(age_fc_out.data.new(age_scale)).unsqueeze(1)
            age_prob = F.softmax(age_fc_out, dim=1)
            age_out = torch.matmul(age_prob, age_scale).view(-1)

        elif cls_type == 'oh':
            # ordinal hyperplane: age = number of thresholds passed + min_age
            age_out = torch.sigmoid(age_fc_out).sum(dim=1) + self.opts.min_age

        else:
            # regression: use the head output computed above
            # (previously read self.age_fc_out, which is not this tensor)
            age_out = age_fc_out.view(-1) + self.opts.min_age

        # autograd.grad returns d(sum of ages)/d(feat) without accumulating
        # into the .grad of age_cls.cls parameters. Nothing needs zeroing
        # afterwards, and gradients already accumulated there are preserved.
        age_grad = torch.autograd.grad(age_out.sum(), feat)[0]

        # row-wise L2 normalization; F.normalize floors the norm at a small
        # eps so an all-zero gradient row yields zeros rather than NaN
        age_grad = F.normalize(age_grad, p=2, dim=1).detach()
        age_grad.volatile = False  # legacy (PyTorch <= 0.3) compatibility

        return age_grad
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号