ssn_ops.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:action-detection 作者: yjxiong 项目源码 文件源码
def forward(self, pred, labels, targets):
        indexer = labels.data - 1
        prep = pred[:, indexer, :]
        class_pred = torch.cat((torch.diag(prep[:, :,  0]).view(-1, 1),
                                torch.diag(prep[:, :, 1]).view(-1, 1)),
                               dim=1)
        loss = self.smooth_l1_loss(class_pred.view(-1), targets.view(-1)) * 2
        return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号