def test_load_parameter_dict(self):
    """Check that ``load_parameter_dict`` installs the supplied objects as-is.

    The network reaches a single ``nn.Linear`` under two names
    (``linear1`` and ``linear2``), nests a convolution one level deeper
    under ``block``, and also registers a ``None`` child named ``empty``.
    After loading a dict keyed by dotted parameter paths, each targeted
    attribute must be bound to the very object that was passed in
    (identity, checked with ``assertIs``), not merely an equal copy.
    """
    shared_linear = nn.Linear(5, 5)
    # NOTE(review): the conv is built without a bias, yet a bias is loaded
    # below -- presumably this checks that an empty (None) parameter slot
    # can be filled by load_parameter_dict; TODO confirm.
    conv_block = nn.Container(conv=nn.Conv2d(3, 3, 3, bias=False))
    net = nn.Container(
        linear1=shared_linear,
        linear2=shared_linear,  # same module object under a second name
        block=conv_block,
        empty=None,  # a None child; presumably must be skipped during loading
    )

    loaded = dict()
    loaded['linear1.weight'] = Variable(torch.ones(5, 5))
    # torch.range is end-inclusive: three values, one per conv output channel.
    loaded['block.conv.bias'] = Variable(torch.range(1, 3))

    net.load_parameter_dict(loaded)

    # The modules must now hold the exact objects from the dict.
    self.assertIs(net.linear1.weight, loaded['linear1.weight'])
    self.assertIs(net.block.conv.bias, loaded['block.conv.bias'])
评论列表
文章目录