pose_model.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:Face-Pose-Net 作者: fengju514 项目源码 文件源码
def _transform(self, theta, input_dim, out_size, channel_input):
    """Warp ``input_dim`` with per-sample affine parameters ``theta``.

    Builds a grid of homogeneous target coordinates (x_t, y_t, 1) for the
    output, maps it through each sample's 2x3 matrix to obtain source
    coordinates (x_s, y_s), and samples ``input_dim`` at those coordinates
    via ``self._interpolate``.

    Args:
      theta: affine parameters, reshaped to (-1, 2, 3) and cast to float32;
        the batched matmul against a [num_batch, 3, k] grid means it must
        hold 6 values per sample.
      input_dim: input feature map handed to ``self._interpolate``;
        presumably (batch, height, width, channels) -- TODO confirm.
      out_size: sequence whose first two entries are the output height and
        width.
      channel_input: channel count used as the last dimension of the
        returned tensor.

    Returns:
      The interpolated values reshaped to
      [self.hps.batch_size, out_size[0], out_size[1], channel_input].
    """
    with tf.variable_scope('_transform'):
      # Debug trace of the static shapes. A single pre-formatted string prints
      # the same line as the old Python 2 print statement and is also valid
      # syntax under Python 3.
      print('%s %s %s %s' % (input_dim.get_shape(), theta.get_shape(),
                             out_size[0], out_size[1]))
      # Static batch size from the hyper-parameters instead of the dynamic
      # tf.shape(input_dim)[0]; every tile/reshape below assumes exactly this
      # many samples.
      num_batch = self.hps.batch_size
      theta = tf.reshape(theta, (-1, 2, 3))
      theta = tf.cast(theta, 'float32')

      # Grid of (x_t, y_t, 1), eq (1) in ref [1]. Presumably shaped
      # (3, out_height * out_width) -- the reshape to [num_batch, 3, -1]
      # below requires 3 * k elements.
      out_height = out_size[0]
      out_width = out_size[1]
      grid = self._meshgrid(out_height, out_width)
      # Replicate the grid once per sample: flatten to rank 1, tile along
      # that single axis, then split back into [num_batch, 3, k].
      grid = tf.reshape(grid, [-1])
      grid = tf.tile(grid, tf.pack([num_batch]))
      grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))

      # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s): rows 0 and 1 of the
      # batched product are the source x and y coordinates.
      # NOTE(review): tf.pack / tf.batch_matmul are pre-TF-1.0 APIs (renamed
      # tf.stack / folded into tf.matmul in TF 1.0); kept to match the TF
      # version this code targets -- update together if TF is upgraded.
      T_g = tf.batch_matmul(theta, grid)
      x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
      y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
      x_s_flat = tf.reshape(x_s, [-1])
      y_s_flat = tf.reshape(y_s, [-1])
      input_transformed = self._interpolate(
          input_dim, x_s_flat, y_s_flat, out_size, channel_input)

      output = tf.reshape(
          input_transformed,
          tf.pack([num_batch, out_height, out_width, channel_input]))
      return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号