def test_sample(self):
    """Check ``sample`` against a NumPy reference implementation.

    A random image batch and a random coordinate map are built; the
    coordinates are rounded to the nearest integer (``floor(x + 0.5)``)
    and passed both to ``sample`` and to a pure-NumPy reference.  The two
    results must be element-wise identical.
    """
    import numpy as np

    h, w = 3, 4

    def np_sample(img, coords):
        """Gather pixels of ``img`` at the integer positions in ``coords``.

        ``img`` is indexed as (batch, row, col, channel).  ``coords`` is
        indexed as (batch, out_row, out_col, 2), where the last axis holds
        (row, col).  Coordinates outside the image are clamped to the
        nearest border pixel.  Returns a float32 array indexed as
        (batch, out_row, out_col, channel).
        """
        # Clamp to the valid index range [0, size - 1] on each axis.
        coords = np.maximum(coords, 0)
        coords = np.minimum(
            coords, np.array([img.shape[1] - 1, img.shape[2] - 1]))

        # Flatten the output grid so each batch entry can be gathered with
        # a single pair of integer index vectors.
        xs = coords[:, :, :, 1].reshape((img.shape[0], -1))
        ys = coords[:, :, :, 0].reshape((img.shape[0], -1))

        ret = np.zeros(
            (img.shape[0], coords.shape[1], coords.shape[2], img.shape[3]),
            dtype='float32')
        for k in range(img.shape[0]):
            xss, yss = xs[k], ys[k]
            # Use the real channel count rather than a hard-coded 3 so the
            # reference stays valid for any number of channels.
            ret[k, :, :, :] = img[k, yss, xss, :].reshape(
                (coords.shape[1], coords.shape[2], img.shape[3]))
        return ret

    bimg = np.random.rand(2, h, w, 3).astype('float32')

    # Fixed mapping that can be swapped in when debugging (shape 2x2x2x2):
    # mat = np.array([
    #     [[[1, 1], [1.2, 1.2]], [[-1, -1], [2.5, 2.5]]],
    #     [[[1, 1], [1.2, 1.2]], [[-1, -1], [2.5, 2.5]]]
    # ], dtype='float32')

    # Shifting by -0.2 and scaling by (size + 3) yields coordinates both
    # below 0 and beyond the last valid index, exercising the clamping.
    mat = (np.random.rand(2, 5, 5, 2) - 0.2) * np.array([h + 3, w + 3])
    true_res = np_sample(bimg, np.floor(mat + 0.5).astype('int32'))

    inp, mapping = self.make_variable(bimg, mat)
    output = sample(inp, tf.cast(tf.floor(mapping + 0.5), tf.int32))
    res = self.run_variable(output)

    # assert_array_equal also fails on a shape mismatch and reports which
    # elements differ, unlike assertTrue((res == true_res).all()).
    np.testing.assert_array_equal(res, true_res)
# Comment list
# Table of contents