def test_net_forward(self):
    """Check Net's channel wiring and that feature splitting keeps it runnable.

    Builds a ``Net``, verifies that the in/out channel counts of
    ``conv1``..``conv4`` line up, and runs a forward pass. It then calls
    ``split_feature`` on ``conv1`` and ``conv2``, re-verifies the wiring and the
    forward pass, and prints the mean squared difference between the outputs
    before and after the split.
    """

    def assert_channels_consistent(net):
        # The same four checks run before and after the split. conv1 and conv2
        # must both produce exactly conv3's input width (they presumably feed
        # conv3 side by side -- TODO confirm against Net), and conv3 must feed
        # conv4.
        self.assertEqual(net.conv1.out_channels, net.conv2.out_channels)
        self.assertEqual(net.conv1.out_channels, net.conv3.in_channels)
        self.assertEqual(net.conv2.out_channels, net.conv3.in_channels)
        self.assertEqual(net.conv3.out_channels, net.conv4.in_channels)

    model = Net()
    print(model)
    assert_channels_consistent(model)

    # Simple forward pass on uniform noise in [-1, 1).
    # NOTE(review): the (1, 1, 4) input and (1, 2, 4) output shapes suggest
    # (batch, channels, length) for 1-D convolutions -- confirm against Net.
    x = Variable(torch.rand(1, 1, 4) * 2 - 1)
    output = model(x)
    self.assertEqual(output.size(), (1, 2, 4))

    # Feature split. This presumably splits the given output feature and grows
    # out_channels -- TODO confirm split_feature semantics.
    model.conv1.split_feature(feature_i=1)
    model.conv2.split_feature(feature_i=3)
    print(model)
    assert_channels_consistent(model)

    output2 = model(x)
    diff = (output - output2).view(-1)
    # Sum of squared differences. It should be close to 0 if the split
    # preserves the network's function.
    sse = torch.dot(diff, diff)
    # Reading the scalar via view(-1)[0] works whether torch.dot returned a
    # one-element tensor (old PyTorch) or a 0-dim tensor (PyTorch >= 0.4).
    # Plain `[0]` on a 0-dim tensor raises IndexError.
    sse_value = float(sse.data.view(-1)[0])
    # TODO(review): the functional-equivalence check is disabled. It was
    # self.assertTrue(np.isclose(<sum of squared diffs>, 0., atol=1e-2)).
    # Re-enable it once split_feature is confirmed to preserve outputs.
    print("mse: ", sse_value / diff.size(0))
评论列表
文章目录