def train_stats(m, trainloader, param_list=None):
    """Train ``m`` and return parameter-count and timing statistics.

    Args:
        m: the model; passed unchanged to ``filtered_params`` and ``train``.
        trainloader: the data source; passed unchanged to ``train``.
        param_list: optional parameter selection forwarded to both
            ``filtered_params`` and ``train`` (``None`` by default).

    Returns:
        dict with keys:
            'variables_optimized': number of entries returned by
                ``filtered_params`` (0 when there are none).
            'params_optimized': total element count across those entries,
                i.e. the sum over entries of the product of ``p[1].size()``.
            'training_time': seconds elapsed (``time.time()`` delta) while
                ``train`` ran.
            'training_loss': last element of the losses returned by
                ``train``, or NaN when that sequence is empty.
            'training_losses': the losses returned by ``train``, unchanged.
    """
    import operator
    from functools import reduce

    # NOTE(review): each entry is assumed to be a (name, tensor)-like pair;
    # only p[1].size() is used here -- confirm against filtered_params.
    params = list(filtered_params(m, param_list))

    stats = {}
    # Count entries directly so an empty selection reports 0, not 1.
    stats['variables_optimized'] = len(params)
    # Initializer 1 makes a 0-dimensional (scalar) size count as one element
    # instead of making reduce() raise on an empty sequence.
    stats['params_optimized'] = sum(
        reduce(operator.mul, p[1].size(), 1) for p in params
    )

    before = time.time()
    losses = train(m, trainloader, param_list=param_list)
    stats['training_time'] = time.time() - before

    # len() rather than truthiness: losses may be an array-like whose bool()
    # is ambiguous.
    stats['training_loss'] = losses[-1] if len(losses) else float('nan')
    stats['training_losses'] = losses
    return stats
train.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录