class_nll_criterion.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:PyFunt 作者: dnlcrl 项目源码 文件源码
def __init__(self, weights=None, size_average=None):
        super(ClassNLLCriterion, self).__init__()
        if size_average:
            self.size_average = size_average
        else:
            self.size_average = True

        if weights:
            # assert(weights:dim() == 1, "weights input should be 1-D Tensor")
            self.weights = weights
        self.output_tensor = np.zeros(1)
        self.total_weight_tensor = np.ones(1)
        self.target = np.zeros(1)  # , dtype=np.long)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号