spatial_transformer.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目: tf_practice 作者: juho-lee 项目源码 文件源码
def spatial_transformer(U, theta, out_height, out_width):
    """Warp a batch of feature maps with a spatial transformer.

    Builds a regular sampling grid of size out_height x out_width (via the
    ``meshgrid`` helper defined elsewhere in this project). It maps that grid
    through the per-example transform described by ``theta``. The flattened
    source coordinates then go to ``transform`` (also defined elsewhere),
    which does the actual sampling.

    Args:
        U: input tensor whose static shape has exactly 4 dimensions. Its last
            static dimension is forwarded to ``transform`` as the channel
            count. Presumably laid out as (batch, height, width, channels);
            TODO confirm against callers.
        theta: transform parameters, one row per example. If its static
            second dimension is 3, each row is split into (s, t_x, t_y) and
            the grid is mapped as x' = s*x + t_x, y' = s*y + t_y (one shared
            scale plus a translation). Otherwise ``theta`` is reshaped to
            (-1, 2, 3) and applied as a 2x3 affine matrix to the homogeneous
            grid coordinates (x, y, 1).
        out_height: height of the output sampling grid.
        out_width: width of the output sampling grid.

    Returns:
        Whatever ``transform`` produces from the flattened source x/y
        coordinates.

    NOTE: written against the pre-1.0 TensorFlow API. ``tf.split`` and
    ``tf.concat`` take the axis as their first argument here, and
    ``tf.batch_matmul`` no longer exists in TF >= 1.0.
    """
    batch_size = tf.shape(U)[0]
    # The three-way unpack also enforces a rank-4 static shape; only the
    # channel count is used below.
    _, _, channels = U.get_shape()[1:]

    grid_x, grid_y = meshgrid(out_height, out_width)
    # Add a leading axis so the grid can be tiled once per batch example.
    # NOTE(review): the 2-element tile multiples below imply meshgrid returns
    # rank-1 tensors; confirm.
    grid_x = tf.expand_dims(grid_x, 0)
    grid_y = tf.expand_dims(grid_y, 0)

    # Keep this as an equality test. In old TF an unknown Dimension compares
    # as None (falsy) under both == and !=, so unknown widths must fall into
    # the affine branch.
    if theta.get_shape()[1] == 3:
        # Scale + translation: one (s, t_x, t_y) triple per example.
        scale, shift_x, shift_y = tf.split(1, 3, theta)
        warped_x = scale * tf.tile(grid_x, [batch_size, 1]) + shift_x
        src_x = tf.reshape(warped_x, [-1])
        warped_y = scale * tf.tile(grid_y, [batch_size, 1]) + shift_y
        src_y = tf.reshape(warped_y, [-1])
    else:
        # General affine: stack (x, y, 1) rows into homogeneous coordinates,
        # replicate per example, then multiply by each 2x3 matrix.
        homogeneous = tf.concat(0, [grid_x, grid_y, tf.ones_like(grid_x)])
        homogeneous = tf.tile(tf.expand_dims(homogeneous, 0),
                              [batch_size, 1, 1])
        affine = tf.reshape(theta, [-1, 2, 3])
        mapped = tf.batch_matmul(affine, homogeneous)
        # Row 0 of each product is the source x coordinate; row 1 is source y.
        row_x = tf.slice(mapped, [0, 0, 0], [-1, 1, -1])
        src_x = tf.reshape(row_x, [-1])
        row_y = tf.slice(mapped, [0, 1, 0], [-1, 1, -1])
        src_y = tf.reshape(row_y, [-1])

    return transform(U, src_x, src_y, batch_size, out_height, out_width,
                     channels)

# last layer of localization net
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号