misc.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:age 作者: ly015 项目源码 文件源码
def forward(self, fc, seq_len):
        '''
        fc: [bsz, max_len, fc_size], has already passed through sigmoid layer
        seq_len: [bsz]
        '''

        loss = Variable(torch.zeros(1), requires_grad = True)

        if fc.is_cuda:
            self.age_dis_trans.cuda()
            loss = loss.cuda()


        bsz, max_len = fc.size()[0:2]
        fc = fc.view(bsz * max_len, -1)        

        log_prob = F.log_softmax(self.age_dis_trans(fc)).view(bsz, max_len, -1)
        prob = log_prob.detach().exp()

        seq_len = seq_len.long()

        for i in range(bsz):
            l = seq_len.data[i]-1
            loss = loss + F.kl_div(log_prob[i,0:l], prob[i,1:(l+1)], False)/l
        loss = loss / bsz
        return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号