def test_container_copy(self):
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = nn.Linear(4, 5)
def forward(self, input):
return self.linear(input)
input = Variable(torch.randn(2, 4))
model = Model()
model_cp = deepcopy(model)
self.assertEqual(model(input).data, model_cp(input).data)
model_cp.linear.weight.data[:] = 2
self.assertNotEqual(model(input).data, model_cp(input).data)
评论列表
文章目录