def test_pytorch_backward(num_classes):
    """Check ``PyTorchModel.backward`` against a hand-computed gradient.

    The toy network's logit for class ``c`` is the spatial mean of input
    channel ``c``. So ``d logit_c / d x[c, i, j] == 1 / (H * W)``, and the
    gradient back-propagated from ``test_grad_pre`` equals that vector
    divided by ``H * W``, broadcast over every pixel of the matching channel.

    Parameters
    ----------
    num_classes : int
        Number of output classes. It is also used as the number of input
        channels, because each channel mean becomes one logit.
    """
    bounds = (0, 255)
    # Each channel is averaged into exactly one logit, so channels == classes.
    channels = num_classes
    height = width = 5

    class Net(nn.Module):

        def __init__(self):
            super(Net, self).__init__()

        def forward(self, x):
            # NOTE(review): assumes x arrives as (batch, channels, H, W);
            # PyTorchModel presumably adds the batch axis. Confirm.
            # Average over dim 3 and then dim 2. keepdim=True keeps the rank
            # fixed so the explicit squeezes hit valid dimensions. Without it,
            # mean() already drops the dim (PyTorch >= 0.2), and
            # squeeze(dim=3) on the resulting rank-3 tensor raises IndexError.
            x = torch.mean(x, 3, keepdim=True)
            x = torch.squeeze(x, dim=3)
            x = torch.mean(x, 2, keepdim=True)
            x = torch.squeeze(x, dim=2)
            logits = x
            return logits

    model = Net()
    model = PyTorchModel(
        model,
        bounds=bounds,
        num_classes=num_classes,
        cuda=False)

    test_image = np.random.rand(channels, height, width).astype(np.float32)
    test_grad_pre = np.random.rand(num_classes).astype(np.float32)

    test_grad = model.backward(test_grad_pre, test_image)
    assert test_grad.shape == test_image.shape

    # d(mean over H*W pixels)/d(pixel) == 1 / (H * W): scale the incoming
    # gradient per channel, then tile it across both spatial axes.
    manual_grad = np.repeat(np.repeat(
        (test_grad_pre / float(height * width)).reshape((-1, 1, 1)),
        height, axis=1), width, axis=2)

    np.testing.assert_almost_equal(
        test_grad,
        manual_grad)
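# 评论列表 ("comment list") - stray web-page navigation label, not part of the test.
# 文章目录 ("table of contents") - stray web-page navigation label, not part of the test.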