def batch_norm(x, params, stats, base, mode):
    """Run ``F.batch_norm`` on ``x`` with tensors selected by a key prefix.

    The affine parameters and running statistics are not passed directly;
    they are fetched from two mappings using keys built from ``base``.

    Args:
        x: Input forwarded unchanged as the first argument of
            ``F.batch_norm``.
        params: Mapping indexed with ``base + '.weight'`` and
            ``base + '.bias'``.
        stats: Mapping indexed with ``base + '.running_mean'`` and
            ``base + '.running_var'``.
        base: Key prefix; it is concatenated with the suffixes above.
        mode: Forwarded as the ``training`` argument of ``F.batch_norm``.

    Returns:
        Whatever ``F.batch_norm`` returns for these arguments.

    Note:
        ``F`` is not defined in this block; presumably it is
        ``torch.nn.functional`` imported elsewhere in the file -- confirm.
    """
    # Lookups stay in the same order as the call's keyword arguments so a
    # missing key surfaces on the same entry as before.
    scale = params[base + '.weight']
    shift = params[base + '.bias']
    mean_buf = stats[base + '.running_mean']
    var_buf = stats[base + '.running_var']

    return F.batch_norm(
        x,
        weight=scale,
        bias=shift,
        running_mean=mean_buf,
        running_var=var_buf,
        training=mode,
    )
评论列表
文章目录