resnet.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:convolutional-vqa 作者: paarthneekhara 项目源码 文件源码
def create_resnet_model(img_dim, checkpoint_path='Data/CNNModels/resnet_v2_152.ckpt'):
    """Build a ResNet-v2-152 inference graph, restore its weights and open a session.

    Two independent sub-graphs are created:

    * a preprocessing path: ``pre_image`` (a single H x W x 3 float image of
      arbitrary size) is divided by 255.0 and passed to
      ``cnn_preprocessing.preprocess_for_eval`` -> ``processed_image``;
    * the network path: ``images_placeholder`` (a batch of
      ``img_dim`` x ``img_dim`` x 3 images) -> ``block4`` / ``probs``.

    The two paths are not connected in the graph; callers presumably run
    ``processed_image`` first and feed its value into ``images_placeholder``
    (TODO confirm against callers).

    Args:
        img_dim: side length in pixels of the square network input; also the
            target height and width passed to ``preprocess_for_eval``.
        checkpoint_path: path of the ``resnet_v2_152`` checkpoint to restore.
            Defaults to the location that was previously hard-coded.

    Returns:
        dict with keys:
            'images_placeholder': float32 placeholder [None, img_dim, img_dim, 3].
            'block4': the ``resnet_v2_152/block4`` end-point tensor.
            'session': a ``tf.Session`` with the ``resnet_v2_152`` model
                variables restored. The caller owns it and must close it.
            'processed_image': output tensor of the preprocessing path.
            'pre_image': float32 placeholder [None, None, 3] for one raw image.
            'probs': first output of ``resnet_v2_152`` (num_classes=1001).
                NOTE(review): with TF-Slim's resnet_v2 this is presumably the
                logits tensor rather than softmax probabilities — verify.

    Raises:
        Whatever ``init_fn`` raises if the checkpoint cannot be restored; the
        session is closed before the exception propagates.
    """
    # Standalone preprocessing graph for a single raw image of any size.
    pre_image = tf.placeholder(tf.float32, [None, None, 3])
    # Dividing by 255.0 presumably maps 0-255 pixel values into [0, 1], the
    # range preprocess_for_eval is assumed to expect (TODO confirm).
    processed_image = cnn_preprocessing.preprocess_for_eval(
        pre_image / 255.0, img_dim, img_dim)

    # Network input: a batch of already-preprocessed square images.
    images = tf.placeholder(tf.float32, [None, img_dim, img_dim, 3])
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        probs, endpoints = resnet_v2.resnet_v2_152(
            images, num_classes=1001, is_training=False)
    block4 = endpoints['resnet_v2_152/block4']
    # Function-call form runs on both Python 2 and 3; the original
    # `print x` statement is a SyntaxError on Python 3.
    print(block4)

    # Restore only the variables created under the resnet_v2_152 scope.
    init_fn = slim.assign_from_checkpoint_fn(
        checkpoint_path,
        slim.get_model_variables('resnet_v2_152'))

    sess = tf.Session()
    try:
        init_fn(sess)
    except Exception:
        # Do not leak the session if the checkpoint cannot be restored.
        sess.close()
        raise

    return {
        'images_placeholder': images,
        'block4': block4,
        'session': sess,
        'processed_image': processed_image,
        'pre_image': pre_image,
        'probs': probs,
    }
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号