def _transform(self, theta, input_dim, out_size, channel_input):
    """Resample ``input_dim`` on an affine-transformed sampling grid.

    A regular target grid of size ``out_size`` is built with
    ``self._meshgrid``, replicated once per batch element, mapped through
    the 2x3 affine matrices in ``theta`` and the resulting source
    coordinates are handed to ``self._interpolate``.

    Args:
        theta: Affine parameters. Reshaped here to ``(-1, 2, 3)`` and cast
            to float32; the leading dimension has to match
            ``self.hps.batch_size`` for the batched matmul below.
        input_dim: Input tensor, forwarded unchanged to
            ``self._interpolate`` (presumably a 4-D image batch -- confirm
            against ``_interpolate``).
        out_size: Indexable pair ``(out_height, out_width)``.
        channel_input: Size of the last axis of the returned tensor; also
            forwarded to ``self._interpolate``.

    Returns:
        The output of ``self._interpolate`` reshaped to
        ``[batch_size, out_height, out_width, channel_input]``.
    """
    with tf.variable_scope('_transform'):
        # NOTE: the batch size comes from the hyper-parameters, not from the
        # runtime shape of `input_dim`, so the graph is tied to that size.
        num_batch = self.hps.batch_size
        out_height = out_size[0]
        out_width = out_size[1]

        theta = tf.reshape(theta, (-1, 2, 3))
        theta = tf.cast(theta, 'float32')

        # Grid of (x_t, y_t, 1), eq (1) in ref [1]. Flatten it, repeat it
        # num_batch times and fold it back into one 3-row block per sample.
        grid = self._meshgrid(out_height, out_width)
        grid = tf.reshape(grid, [-1])
        grid = tf.tile(grid, tf.pack([num_batch]))
        grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))

        # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s): row 0 of the product
        # holds the x source coordinates, row 1 the y source coordinates.
        T_g = tf.batch_matmul(theta, grid)
        x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
        y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
        x_s_flat = tf.reshape(x_s, [-1])
        y_s_flat = tf.reshape(y_s, [-1])

        input_transformed = self._interpolate(
            input_dim, x_s_flat, y_s_flat, out_size, channel_input)

        output = tf.reshape(
            input_transformed,
            tf.pack([num_batch, out_height, out_width, channel_input]))
        return output
# Comment list
# Table of contents