def _age_gradient(self, feat_age):
'''
compute age branch gradient direction in age_embed layer
input:
feat_age: output of age_embed layer (before relu)
'''
cls = self.age_cls
feat = feat_age.detach()
feat.requires_grad = True
feat.volatile = False
feat = feat.clone()
feat.retain_grad()
age_fc_out = cls.cls(cls.relu(feat))
if self.opts.cls_type == 'dex':
# deep expectation
age_scale = np.arange(self.opts.min_age, self.opts.max_age + 1, 1.0)
age_scale = Variable(age_fc_out.data.new(age_scale)).unsqueeze(1)
age_out = torch.matmul(F.softmax(age_fc_out), age_scale).view(-1)
elif self.opts.cls_type == 'oh':
# ordinal hyperplane
age_fc_out = F.sigmoid(age_fc_out)
age_out = age_fc_out.sum(dim = 1) + self.opts.min_age
elif self.opts.cls_type == 'reg':
# regression
age_out = self.age_fc_out.view(-1) + self.opts.min_age
age_out.sum().backward()
age_grad = feat.grad
# normalization
age_grad = age_grad / age_grad.norm(p = 2, dim = 1, keepdim = True)
age_grad.detach_()
age_grad.volatile = False
age_grad.requires_grad = False
cls.cls.zero_grad()
return age_grad
# (scrape residue removed from source: "评论列表" / "文章目录" — blog-page
#  navigation labels ["comment list" / "article table of contents"], not code)