infer.py source code

python

Project: lighting-augmentation  Author: GemHunt
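import os

import caffe
import numpy as np


def get_classifier(model_name, crop_size):
    # Files expected inside the model directory
    model_file = os.path.join(model_name, 'deploy.prototxt')
    pretrained = os.path.join(model_name, 'snapshot.caffemodel')
    mean_file = os.path.join(model_name, 'mean.binaryproto')

    # Load the mean image stored in mean.binaryproto
    blob = caffe.proto.caffe_pb2.BlobProto()
    with open(mean_file, 'rb') as f:
        blob.ParseFromString(f.read())
    # blobproto_to_array returns (num, channels, height, width); this reshape
    # assumes a single-channel mean of size crop_size x crop_size
    mean_arr = np.array(caffe.io.blobproto_to_array(blob)).reshape(1, crop_size, crop_size)
    print(mean_arr.shape)

    # raw_scale=255 rescales [0, 1] input images to the 0-255 range of the mean
    net = caffe.Classifier(model_file, pretrained, image_dims=(crop_size, crop_size),
                           mean=mean_arr, raw_scale=255)
    return net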
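The mean handling in `get_classifier` can be illustrated without Caffe installed. The following is a minimal NumPy sketch, assuming a grayscale (single-channel) model. The helper names `mean_to_input_shape` and `preprocess_like_caffe` are hypothetical. The second helper only approximates what, as far as we understand it, `caffe.io.Transformer` does to each image for `caffe.Classifier`: transpose to channels-first, multiply by `raw_scale`, then subtract the mean. Resizing, channel swapping and `input_scale` are omitted.

```python
import numpy as np


def mean_to_input_shape(mean_blob_array, crop_size):
    """Hypothetical helper: reshape a (num, channels, height, width) mean array,
    like the one caffe.io.blobproto_to_array returns, to (1, crop_size, crop_size).

    Only a single-channel mean of exactly crop_size x crop_size fits; any other
    size (e.g. a 3-channel RGB mean) makes numpy's reshape raise ValueError.
    """
    arr = np.asarray(mean_blob_array, dtype=np.float32)
    return arr.reshape(1, crop_size, crop_size)


def preprocess_like_caffe(image_hwc, mean_chw, raw_scale=255.0):
    """Hypothetical approximation of the per-image preprocessing:
    HWC -> CHW transpose, multiply by raw_scale, subtract the mean."""
    x = np.asarray(image_hwc, dtype=np.float32)
    x = x.transpose(2, 0, 1)   # (H, W, C) -> (C, H, W)
    x = x * raw_scale          # [0, 1] floats -> 0-255, the scale of the mean
    return x - mean_chw        # element-wise mean subtraction


if __name__ == "__main__":
    crop = 4
    fake_blob = np.full((1, 1, crop, crop), 100.0)   # made-up grayscale mean
    mean = mean_to_input_shape(fake_blob, crop)
    fake_img = np.full((crop, crop, 1), 0.5)          # made-up image in [0, 1]
    out = preprocess_like_caffe(fake_img, mean)
    print(out.shape, out[0, 0, 0])                    # (1, 4, 4) 27.5

    try:
        mean_to_input_shape(np.zeros((1, 3, crop, crop)), crop)
    except ValueError:
        print("a 3-channel mean cannot be reshaped to (1, 4, 4)")
```

The ordering is why `raw_scale=255` matters: an image loaded in the [0, 1] range (as `caffe.io.load_image` returns it) has to be brought to 0-255 before a mean computed on 8-bit training images is subtracted from it.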