def tensor_swirl(image, center=None, strength=1, radius=100, rotation=0, cval=0.0, **kwargs):
    """Apply a swirl warp to ``image`` as a single tensor contraction.

    A dense weight table ``mapping[src_y, src_x, dst_y, dst_x]`` is filled in
    NumPy, one destination pixel at a time, from the source coordinates
    returned by ``swirl_mapping`` and the ``(y, x, weight)`` triples yielded
    by ``tensor_linear_interpolation``. ``image`` is then contracted with
    that table over its axes 1 and 2.

    Args:
        image: Tensor whose axes 1 and 2 are treated as the spatial
            (rows, columns) axes. Their static sizes (``K.int_shape``) are
            used to size NumPy arrays.
        center: Swirl centre forwarded to ``swirl_mapping``. Defaults to
            half of each spatial extent, in (axis 1, axis 2) order.
        strength: Forwarded unchanged to ``swirl_mapping``.
        radius: Forwarded unchanged to ``swirl_mapping``.
        rotation: Forwarded unchanged to ``swirl_mapping``.
        cval: Scalar fill value. It is expanded to a 1-D tensor with one
            entry per element of the leading axis of ``image`` and forwarded
            to ``tensor_linear_interpolation``.
        **kwargs: Unsupported options; accepted and ignored.

    Returns:
        ``tf.tensordot(image, mapping, [[1, 2], [0, 1]])``. The two
        destination axes are appended after whatever image axes were not
        contracted.
        NOTE(review): for a 4-D (batch, rows, cols, channels) input this
        yields (batch, channels, rows, cols) rather than the input layout --
        confirm that callers expect this.
    """
    cval = tf.fill(K.shape(image)[0:1], cval)
    shape = K.int_shape(image)[1:3]
    if center is None:
        center = np.array(shape) / 2

    # Row-major enumeration of every pixel, kept as (N, 1) column vectors
    # because that is the form handed to swirl_mapping.
    ys = np.expand_dims(np.repeat(np.arange(shape[0]), shape[1]), -1)
    xs = np.expand_dims(np.tile(np.arange(shape[1]), shape[0]), -1)
    map_xs, map_ys = swirl_mapping(xs, ys, center, rotation, strength, radius)

    # Dense table with (rows * cols) ** 2 entries, indexed
    # [src_y, src_x, dst_y, dst_x] so that contracting the image's spatial
    # axes with the first two axes leaves the destination axes.
    mapping = np.zeros((*shape, *shape))
    for map_x, map_y, x, y in zip(map_xs, map_ys, xs, ys):
        # x and y are length-1 rows of the (N, 1) grids. Index the single
        # element explicitly: int() on a 1-element ndarray is deprecated in
        # NumPy >= 1.25. Converting once here also avoids repeating the
        # conversion for every interpolation tap.
        dst_y = int(y[0])
        dst_x = int(x[0])
        taps = tensor_linear_interpolation(image, map_x, map_y, cval)
        for src_y, src_x, weight in taps:
            mapping[int(src_y), int(src_x), dst_y, dst_x] = weight

    results = tf.tensordot(image, K.variable(mapping), [[1, 2], [0, 1]])
    return results
评论列表
文章目录