models.py source code

python

Project: crossnet  Author: viibridges
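import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.variable_scope, collections)

# `nets` and `ops` are modules of the crossnet project; their import lines
# are not part of this excerpt.


def pixelnet_convs(inputs, num_class, is_training=True, reuse=False):
  """Build per-pixel hypercolumn features from a VGG-like backbone.

  Returns a tensor of shape [num_batch, h, w, hyperchannels].
  Note: `num_class` is not used inside this function.
  """
  # Dynamic (run-time) batch size and spatial size of the input images.
  num_batch = tf.shape(inputs)[0]
  height = tf.shape(inputs)[1]
  width = tf.shape(inputs)[2]

  with tf.variable_scope('vgg_16', reuse=reuse):
    # `net` is the last conv output; `hyperfeats` is the list of feature maps
    # that are sampled below.
    net, hyperfeats = nets.vgg_like(inputs)
    tf.add_to_collection('last_conv', net)

  with tf.name_scope('hyper_columns'):
    if is_training:
      # sample pixels corresponding to the last feature elements
      # (static shape, so h and w are Python ints here)
      h, w = net.get_shape().as_list()[1:3]
      trace_locations = ops.trace_locations_backward
    else:
      # sample pixels corresponding to the whole image
      # (dynamic shape, so h and w are scalar tensors here)
      h, w = [height, width]
      trace_locations = ops.trace_locations_forward

    # Flattened (x, y) coordinates of every sampled pixel, repeated per batch item.
    X, Y = tf.meshgrid(tf.range(w), tf.range(h), indexing='xy')
    loc_x = tf.tile(tf.reshape(X, [1, -1]), [num_batch, 1])
    loc_y = tf.tile(tf.reshape(Y, [1, -1]), [num_batch, 1])

    # Map the sampled coordinates onto each feature map's own resolution.
    locations = [
        trace_locations(loc_x, loc_y, [h, w],
                        [tf.shape(feat)[1], tf.shape(feat)[2]])
        for feat in hyperfeats]
    net = ops.extract_values(hyperfeats, locations)
    hyperchannels = net.get_shape().as_list()[-1]

    # Restore the spatial layout: one hypercolumn vector per sampled pixel.
    net = tf.reshape(net, [num_batch, h, w, hyperchannels])
    tf.add_to_collection('hyper_column', net)

    return net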
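
The coordinate grid and per-pixel gathering used above can be sketched in plain NumPy. This is only an illustration under stated assumptions: `make_locations`, `gather_nearest` and `hypercolumns` are hypothetical names, and the nearest-neighbour mapping is a stand-in chosen for the example; the actual behaviour of `ops.trace_locations_backward`, `ops.trace_locations_forward` and `ops.extract_values` in crossnet is not shown in this excerpt and may differ.

```python
import numpy as np


def make_locations(num_batch, h, w):
    """Flattened (x, y) pixel coordinates, tiled once per batch item."""
    # With indexing='xy', X varies along columns and Y along rows; both are (h, w).
    X, Y = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
    loc_x = np.tile(X.reshape(1, -1), (num_batch, 1))
    loc_y = np.tile(Y.reshape(1, -1), (num_batch, 1))
    return loc_x, loc_y


def gather_nearest(feat, loc_x, loc_y, src_hw):
    """Hypothetical stand-in for the project's location tracing and value
    extraction: scale each coordinate to the feature map's resolution and
    pick the nearest (floored) feature vector."""
    n, fh, fw, _ = feat.shape
    h, w = src_hw
    fx = np.clip(loc_x * fw // w, 0, fw - 1)
    fy = np.clip(loc_y * fh // h, 0, fh - 1)
    batch_idx = np.arange(n)[:, None]   # broadcasts against (n, h*w)
    return feat[batch_idx, fy, fx]      # shape (n, h*w, channels)


def hypercolumns(feats, h, w):
    """Concatenate the sampled features of every map along the channel axis."""
    n = feats[0].shape[0]
    loc_x, loc_y = make_locations(n, h, w)
    cols = [gather_nearest(f, loc_x, loc_y, (h, w)) for f in feats]
    net = np.concatenate(cols, axis=-1)
    return net.reshape(n, h, w, -1)


if __name__ == "__main__":
    feats = [np.random.rand(2, 4, 4, 3), np.random.rand(2, 2, 2, 5)]
    print(hypercolumns(feats, 4, 4).shape)
```

Flattening the grid row by row is what lets the final reshape to `[n, h, w, channels]` put each hypercolumn back at the pixel it was sampled for.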