def test_interpolate(self):
    """Check that slicing a depth-ramp grid returns the expected depth value.

    The grid holds the constant value ``k`` in every cell of depth slice ``k``
    (for ``k`` in ``range(d)``). For each ``val`` in ``range(d)`` the guide is
    the constant ``(val + 0.5) / d``. The sliced output must have shape
    ``[batch_size, 5, 9, 1]`` and equal ``val`` everywhere to within 5e-4.

    NOTE(review): the expected value ``val`` relies on ``run_bilateral_slice``
    mapping a guide value of ``(val + 0.5) / d`` to the center of depth cell
    ``val`` — confirm against that helper, which is defined elsewhere.
    """
    for dev in ['/gpu:0']:
        batch_size = 3
        h = 3
        w = 4
        d = 3
        grid_shape = [batch_size, h, w, d, 1]
        # Depth slice k holds the value k. Derived from `d` (instead of one
        # hard-coded assignment per slice) so the ramp stays consistent with
        # the expected value `val` below if `d` is changed.
        grid_data = np.zeros(grid_shape, dtype=np.float32)
        grid_data[...] = np.arange(d, dtype=np.float32).reshape(d, 1)

        # The guide/output resolution deliberately differs from the grid's h x w.
        guide_h, guide_w = 5, 9
        guide_shape = [batch_size, guide_h, guide_w]
        target_shape = [batch_size, guide_h, guide_w, 1]
        for val in range(d):
            target_data = np.full(target_shape, val, dtype=np.float32)
            # Constant guide placed at the center of depth cell `val`.
            guide_data = np.full(guide_shape, (val + 0.5) / d, dtype=np.float32)
            output_data = self.run_bilateral_slice(dev, grid_data, guide_data)
            # Check the shape before comparing values: a mismatched shape would
            # otherwise raise a broadcasting error or silently broadcast.
            self.assertEqual(target_shape, list(output_data.shape))
            diff = np.amax(np.abs(target_data - output_data))
            self.assertLess(diff, 5e-4)
评论列表
文章目录