def f(params, inputs, mode):
    """Run a small conv-net forward pass using weights looked up in ``params``.

    Architecture: conv0 (stride 2) -> ReLU -> conv1 (stride 2) -> ReLU ->
    flatten -> linear2 -> ReLU -> linear3. No activation is applied after
    ``linear3``; its raw output is returned.

    Args:
        params: mapping from parameter name to tensor. Must provide
            ``'conv0.weight'``, ``'conv0.bias'``, ``'conv1.weight'``,
            ``'conv1.bias'``, ``'linear2.weight'``, ``'linear2.bias'``,
            ``'linear3.weight'`` and ``'linear3.bias'``. The in-features of
            ``linear2.weight`` must match the flattened conv1 output size.
        inputs: tensor holding 784 values per sample along dim 0; it is laid
            out as ``(inputs.size(0), 1, 28, 28)`` (presumably flattened
            28x28 single-channel images -- confirm against callers).
        mode: accepted for interface compatibility but unused by this network
            (presumably a train/eval flag shared with other models -- TODO
            confirm).

    Returns:
        Tensor of shape ``(inputs.size(0), out_features of linear3)``.
    """
    batch = inputs.size(0)
    # reshape (rather than view) accepts inputs whose memory layout cannot be
    # viewed directly; it still returns a view whenever that is possible.
    o = inputs.reshape(batch, 1, 28, 28)
    o = F.conv2d(o, params['conv0.weight'], params['conv0.bias'], stride=2)
    o = F.relu(o)
    o = F.conv2d(o, params['conv1.weight'], params['conv1.bias'], stride=2)
    o = F.relu(o)
    # flatten(1) keeps the batch dim. Unlike view(o.size(0), -1) it works for
    # an empty batch (where -1 would be ambiguous) and for non-contiguous
    # conv outputs (it copies only when a view is impossible).
    o = o.flatten(1)
    o = F.linear(o, params['linear2.weight'], params['linear2.bias'])
    o = F.relu(o)
    o = F.linear(o, params['linear3.weight'], params['linear3.bias'])
    return o
评论列表
文章目录