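import torch
import torch.legacy.nn  # imports assumed; torch.legacy shipped only with pre-1.0 PyTorch
from torch import legacy

l = []  # gradient norms, appended in traversal order

def getM(mods):
    """Recursively append the L2 norm of each conv/linear layer's gradWeight to l."""
    for m in mods:
        if isinstance(m, legacy.nn.SpatialConvolution):
            # NaN != NaN, so this mask selects the NaN entries; zero them before taking the norm
            m.gradWeight[m.gradWeight.ne(m.gradWeight)] = 0
            l.append(torch.norm(m.gradWeight))
        elif isinstance(m, legacy.nn.Linear):
            l.append(torch.norm(m.gradWeight))
        elif isinstance(m, (legacy.nn.Concat, legacy.nn.Sequential)):
            # Container modules: recurse into their children
            getM(m.modules)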
Source file: compare-pytorch-and-torch-grads.py (Python)
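`torch.legacy` was removed around the PyTorch 1.0 release, so `getM` will not run on current versions. Below is a minimal sketch of a rough modern equivalent. It assumes `nn.Conv2d` and `nn.Linear` stand in for `SpatialConvolution` and `Linear`. The name `grad_norms` and the toy model are hypothetical and do not come from the original file. Like the original, it zeroes NaN gradient entries only for convolution layers.

```python
import torch
import torch.nn as nn


def grad_norms(model):
    """Collect the L2 norm of each Conv2d / Linear weight gradient (hypothetical helper).

    model.modules() already walks nested containers depth-first, so no
    explicit recursion over Sequential-style containers is needed.
    """
    norms = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d) and m.weight.grad is not None:
            g = m.weight.grad
            # torch.isnan gives the same mask as the g.ne(g) self-comparison trick
            g[torch.isnan(g)] = 0
            norms.append(torch.norm(g))
        elif isinstance(m, nn.Linear) and m.weight.grad is not None:
            norms.append(torch.norm(m.weight.grad))
    return norms


# Usage: a small conv -> linear model after one backward pass
model = nn.Sequential(
    nn.Conv2d(1, 2, kernel_size=3),  # 4x4 input -> 2 channels of 2x2
    nn.Flatten(),
    nn.Linear(2 * 2 * 2, 3),
)
x = torch.randn(4, 1, 4, 4)
model(x).sum().backward()
print([float(n) for n in grad_norms(model)])
```

The sketch returns a fresh list on each call instead of appending to a module-level `l`. This way, repeated calls do not accumulate norms from earlier passes.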