ops_test.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:hdrnet_legacy 作者: mgharbi 项目源码 文件源码
def test_interpolate(self):
    """Check that bilateral slicing recovers grid values along the depth axis.

    A grid of shape ``[batch_size, h, w, d, 1]`` is filled so that every cell
    at depth index ``k`` holds the value ``k``. For each ``val`` in
    ``range(d)`` a constant guide image of value ``(val + 0.5) / d`` is sliced
    through the grid via ``self.run_bilateral_slice``, and the output is
    expected to have shape ``[batch_size, 5, 9, 1]`` and to equal ``val``
    everywhere (max abs error < 5e-4).

    NOTE(review): ``(val + 0.5) / d`` presumably addresses the centre of depth
    cell ``val`` under ``run_bilateral_slice``'s guide convention — confirm
    against that helper.
    """
    for dev in ['/gpu:0']:
      batch_size = 3
      h = 3
      w = 4
      d = 3
      grid_shape = [batch_size, h, w, d, 1]
      # Depth slice k holds the constant value k. Broadcasting np.arange over
      # the depth axis (rather than hard-coding slices for k=1 and k=2) keeps
      # the grid consistent with the expected target for any value of d.
      grid_data = np.zeros(grid_shape, dtype=np.float32)
      grid_data[...] = np.arange(d, dtype=np.float32).reshape(1, 1, 1, d, 1)

      guide_shape = [batch_size, 5, 9]
      target_shape = [batch_size, 5, 9, 1]

      for val in range(d):
        target_data = np.full(target_shape, val, dtype=np.float32)
        guide_data = np.full(guide_shape, (val + 0.5) / d, dtype=np.float32)

        output_data = self.run_bilateral_slice(dev, grid_data, guide_data)

        # Check the shape before comparing values: a shape mismatch would
        # otherwise surface as a confusing broadcasting error (or a
        # misleading diff) in the subtraction below.
        self.assertEqual(target_shape, list(output_data.shape))

        diff = np.amax(np.abs(target_data - output_data))
        self.assertLess(diff, 5e-4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号