def translate(U, theta, out_height, out_width):
    """Scatter each image of a batch into a larger canvas at a clipped offset.

    The values of ``U`` are flattened and written with
    ``tf.sparse_to_dense`` into a tensor of shape
    ``[num_batch, out_height, out_width, num_ch]``.  Every element of a
    sample is shifted by that sample's (row, column) offset taken from
    ``theta``; positions that receive no value keep
    ``tf.sparse_to_dense``'s default fill.

    Args:
        U: 4-D tensor ``[num_batch, height, width, num_ch]``.  The last
            three dimensions must be statically known because their
            ``.value`` is used in Python arithmetic below.
        theta: tensor that is split along axis 1 into two columns; the
            first is used as the row offset and the second as the column
            offset.  Expected to hold one row per sample -- TODO confirm
            against callers.
        out_height: height of the output canvas (presumably a Python int;
            it is used in plain arithmetic and in ``set_shape``).
        out_width: width of the output canvas (same assumption).

    Returns:
        The scattered tensor, with static shape
        ``[None, out_height, out_width, num_ch]``.

    NOTE(review): the call signatures (``tf.pack``,
    ``tf.split(axis, num, value)``, ``tf.concat(axis, values)``) follow the
    legacy TensorFlow API this file appears to target.  ``repeat`` is a
    helper defined elsewhere; the index layout below assumes
    ``repeat(t, n)`` repeats each element ``n`` times consecutively --
    confirm.
    """
    batch = tf.shape(U)[0]
    rows, cols, chans = U.get_shape()[1:]
    rows = rows.value
    cols = cols.value
    chans = chans.value
    per_image = rows * cols * chans

    # Per-pixel row / column coordinates of a single image.
    row_grid = repeat(tf.range(rows), cols)
    col_grid = tf.tile(tf.range(cols), tf.pack([rows]))

    # One index column per output axis, each of shape [batch * per_image, 1].
    batch_col = tf.expand_dims(repeat(tf.range(batch), per_image), 1)
    row_col = tf.tile(
        tf.expand_dims(repeat(row_grid, chans), 1), tf.pack([batch, 1]))
    col_col = tf.tile(
        tf.expand_dims(repeat(col_grid, chans), 1), tf.pack([batch, 1]))
    chan_col = tf.tile(
        tf.expand_dims(tf.range(chans), 1),
        tf.pack([batch * rows * cols, 1]))

    def _per_element_offset(offset, upper):
        # Clip so the shifted image stays inside the canvas, convert to
        # integer indices, then give every element of a sample its own
        # copy of that sample's offset.
        clipped = tf.cast(tf.clip_by_value(offset, 0, upper), 'int32')
        return tf.reshape(tf.tile(clipped, tf.pack([1, per_image])), [-1, 1])

    row_shift, col_shift = tf.split(1, 2, theta)
    row_col = row_col + _per_element_offset(row_shift, out_height - rows)
    col_col = col_col + _per_element_offset(col_shift, out_width - cols)

    # Rows of `indices` are (sample, row, column, channel) targets, paired
    # positionally with the flattened values of U.
    indices = tf.concat(1, [batch_col, row_col, col_col, chan_col])
    values = tf.reshape(U, [-1])
    canvas_shape = tf.pack([batch, out_height, out_width, chans])
    T = tf.sparse_to_dense(indices, canvas_shape, values)

    # The batch size is only known at run time; pin the remaining dims.
    T.set_shape([None, out_height, out_width, chans])
    return T
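# Comment list        (leftover web-page navigation text, not code)
# Table of contents   (leftover web-page navigation text, not code)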