yellowfin.py source code

python

Project: dawn-bench-models  Author: stanford-futuredata
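def grad_sparsity(self):
  """Update the running average of the fraction of non-zero gradient entries.

  Despite the name, the tracked quantity is the non-zero fraction
  (density) of the gradients, not the fraction of zeros.
  """
  global_state = self._global_state
  # Initialise the moving average on the first iteration.
  if self._iter == 0:
    global_state["sparsity_avg"] = 0.0
  non_zero_cnt = 0.0
  all_entry_cnt = 0.0
  for group in self._optimizer.param_groups:
    for p in group['params']:
      # Parameters that received no gradient are skipped.
      if p.grad is None:
        continue
      grad = p.grad.data
      # nonzero() returns one row per non-zero entry.
      grad_non_zero = grad.nonzero()
      if grad_non_zero.dim() > 0:
        non_zero_cnt += grad_non_zero.size()[0]
      all_entry_cnt += torch.numel(grad)
  beta = self._beta
  # Exponential moving average of the non-zero fraction. This assumes at
  # least one parameter has a gradient; otherwise all_entry_cnt is zero.
  global_state["sparsity_avg"] = (
    beta * global_state["sparsity_avg"]
    + (1 - beta) * non_zero_cnt / float(all_entry_cnt))
  # Correct the bias introduced by starting the average at zero.
  self._sparsity_avg = (
    global_state["sparsity_avg"] / self.zero_debias_factor())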
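
The bookkeeping in the method above can be sketched outside the optimizer class. This is a minimal illustration, not the project's code: the function names `nonzero_fraction` and `update_sparsity_avg` are hypothetical, plain Python lists stand in for flattened gradient tensors, and the debias factor is assumed to be `1 - beta ** (step + 1)`, since the body of `zero_debias_factor()` is not shown in this snippet.

```python
def nonzero_fraction(grads):
  """Fraction of non-zero entries across several flattened gradients.

  `grads` is a list of lists of numbers; each inner list stands in for
  one flattened gradient tensor (an assumption made for illustration).
  """
  non_zero_cnt = 0
  all_entry_cnt = 0
  for grad in grads:
    non_zero_cnt += sum(1 for g in grad if g != 0)
    all_entry_cnt += len(grad)
  # Return 0.0 rather than dividing by zero when there are no entries.
  return non_zero_cnt / all_entry_cnt if all_entry_cnt else 0.0


def update_sparsity_avg(prev_avg, fraction, beta, step):
  """One moving-average update followed by zero-debiasing.

  `step` counts from 0. The debias factor 1 - beta ** (step + 1) is an
  assumed form; the original zero_debias_factor() is not shown here.
  """
  avg = beta * prev_avg + (1 - beta) * fraction
  debiased = avg / (1 - beta ** (step + 1))
  return avg, debiased


if __name__ == "__main__":
  # Two toy "gradients": 2 non-zero entries out of 6 in total.
  grads = [[0.0, 1.5, 0.0, -2.0], [0.0, 0.0]]
  fraction = nonzero_fraction(grads)
  avg, debiased = update_sparsity_avg(0.0, fraction, beta=0.999, step=0)
  print(fraction, avg, debiased)
```

Under the assumed debias factor, the first update from a zero-initialised average gives a debiased value equal (up to rounding) to the observed fraction, which is the reason for dividing by the factor: without it, the early averages would be pulled strongly toward zero.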