def f(o, params, stats, mode):
    """Run a small BN -> conv -> MLP network using PyTorch's functional API.

    The network is: batch_norm -> conv2d ('conv1') -> relu -> flatten ->
    linear ('linear2') -> relu -> linear ('linear3').

    Args:
        o: Input tensor. It is presumably (N, C, H, W), since it feeds
            ``F.conv2d`` directly after batch norm (TODO confirm with caller).
        params: Mapping of learnable tensors. It must contain 'bn.weight',
            'bn.bias', 'conv1.weight', 'conv1.bias', 'linear2.weight',
            'linear2.bias', 'linear3.weight' and 'linear3.bias'.
        stats: Mapping holding the batch-norm buffers 'bn.running_mean' and
            'bn.running_var'.
        mode: Forwarded as ``training`` to ``F.batch_norm``. When truthy,
            batch statistics are used and the tensors in ``stats`` are
            updated in place. Otherwise the running statistics are used.

    Returns:
        The output of the final linear layer, shape (N, out_features of
        'linear3'). It is presumably logits; verify against caller.
    """
    o = F.batch_norm(
        o,
        running_mean=stats['bn.running_mean'],
        running_var=stats['bn.running_var'],
        weight=params['bn.weight'],
        bias=params['bn.bias'],
        training=mode,
    )
    o = F.conv2d(o, params['conv1.weight'], params['conv1.bias'])
    o = F.relu(o)
    # flatten(1) keeps the batch dim and merges the rest. Unlike
    # view(o.size(0), -1), it also works for non-contiguous activations
    # (e.g. channels_last inputs) and for an empty batch, where -1 would
    # be ambiguous.
    o = o.flatten(1)
    o = F.linear(o, params['linear2.weight'], params['linear2.bias'])
    o = F.relu(o)
    o = F.linear(o, params['linear3.weight'], params['linear3.bias'])
    return o
评论列表
文章目录