def spatial_transformer(U, theta, out_height, out_width):
    """Warp ``U`` with per-sample transforms ``theta`` and resample it.

    Builds a target grid with ``meshgrid``, maps every grid point through the
    transform described by ``theta``, and passes the flattened source
    coordinates to ``transform`` for sampling.

    Args:
        U: input tensor whose static shape must have exactly four dimensions;
            the trailing three are unpacked as (height, width, num_channels).
        theta: per-sample transform parameters, one leading entry per batch
            element (presumably [num_batch, 3] or [num_batch, 6] -- confirm
            against the localization net).
            * If ``theta.get_shape()[1] == 3`` the columns are read as
              ``(s, t_x, t_y)``: ``s`` scales both grid axes and
              ``(t_x, t_y)`` translates them.
            * Otherwise ``theta`` is reshaped to [-1, 2, 3] affine matrices
              applied to homogeneous grid coordinates ``(x, y, 1)``.
        out_height: target grid height, forwarded to ``meshgrid`` and
            ``transform``.
        out_width: target grid width, forwarded to ``meshgrid`` and
            ``transform``.

    Returns:
        Whatever ``transform`` returns for the computed source coordinates
        (presumably a [num_batch, out_height, out_width, num_channels]
        tensor -- TODO confirm against ``transform``).
    """
    num_batch = tf.shape(U)[0]
    # Unpacking into three names also enforces a rank-4 static shape on U;
    # only the channel count is needed further down.
    _, _, num_channels = U.get_shape()[1:]

    # Target-grid coordinates. The tf.tile multiples below assume meshgrid
    # returns 1-D vectors, which become [1, N] after expand_dims.
    x_t, y_t = meshgrid(out_height, out_width)
    x_t = tf.expand_dims(x_t, 0)
    y_t = tf.expand_dims(y_t, 0)

    # NOTE(review): tf.split(axis, num, value), tf.concat(axis, values) and
    # tf.batch_matmul are pre-1.0 TensorFlow signatures. On TF >= 1.0 these
    # become tf.split(theta, 3, axis=1), tf.concat([...], 0) and tf.matmul.
    if theta.get_shape()[1] == 3:
        # Scale + translation: s, t_x, t_y are each [num_batch, 1] and
        # broadcast across the tiled [num_batch, N] grid.
        s, t_x, t_y = tf.split(1, 3, theta)
        x_s = tf.reshape(s * tf.tile(x_t, [num_batch, 1]) + t_x, [-1])
        y_s = tf.reshape(s * tf.tile(y_t, [num_batch, 1]) + t_y, [-1])
    else:
        # Full affine: homogeneous grid [1, 3, N] tiled to [num_batch, 3, N]
        # and multiplied by [num_batch, 2, 3] matrices -> [num_batch, 2, N].
        grid = tf.expand_dims(tf.concat(0, [x_t, y_t, tf.ones_like(x_t)]), 0)
        grid = tf.tile(grid, [num_batch, 1, 1])
        grid_t = tf.batch_matmul(tf.reshape(theta, [-1, 2, 3]), grid)
        # Row 0 of the product gives the source x coordinates, row 1 the
        # source y coordinates; both are flattened to num_batch * N values.
        x_s = tf.reshape(tf.slice(grid_t, [0, 0, 0], [-1, 1, -1]), [-1])
        y_s = tf.reshape(tf.slice(grid_t, [0, 1, 0], [-1, 1, -1]), [-1])

    return transform(U, x_s, y_s, num_batch, out_height, out_width,
                     num_channels)
# last layer of localization net
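# Comment list (评论列表) -- web-page label captured with this snippet
# Table of contents (文章目录) -- web-page label captured with this snippet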