extractFeatures.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:WechatForwardBot 作者: grapeot 项目源码 文件源码
def main(argv):
    """Extract a Caffe feature vector for every image listed in an input file.

    Command line (``argv`` excludes the program name):
        -i / --ifile <path>   text file with one image path per line
        -o / --ofile <path>   output file; one line is written per image
        -h                    print usage and exit

    For each image the network's top prediction is printed with its class
    label, and a line ``<image path> TAB <sha224 hex digest of the file
    bytes> TAB <features>`` is written, where ``<features>`` are the
    space-separated values of the ``layer_name`` blob for that image.

    Relies on module-level names defined elsewhere in this file (not visible
    here): ``caffe``, ``np``, ``model_prototxt``, ``model_trained``,
    ``mean_path``, ``imagenet_labels`` and ``layer_name``.

    Exits with status 2 on an unknown option or when -i/-o is missing.
    Images that fail to load or process are reported and skipped.
    """
    usage = 'caffe_feature_extractor.py -i <inputfile> -o <outputfile>'
    inputfile = ''
    outputfile = ''

    try:
        opts, _ = getopt.getopt(argv, "hi:o:", ["ifile=", "ofile="])
    except getopt.GetoptError:
        print(usage)
        sys.exit(2)

    for opt, arg in opts:
        if opt == '-h':
            print(usage)
            sys.exit()
        # Real tuples (a bare ("-i") is just the string "-i") so that both
        # the short and the long form registered with getopt are recognised.
        elif opt in ('-i', '--ifile'):
            inputfile = arg
        elif opt in ('-o', '--ofile'):
            outputfile = arg

    # Fail fast, before the slow model load, if a required path is missing.
    if not inputfile or not outputfile:
        print(usage)
        sys.exit(2)

    print('Reading images from "{0}"'.format(inputfile))
    print('Writing vectors to "{0}"'.format(outputfile))

    # CPU mode; switch to GPU mode if CUDA is installed.
    caffe.set_mode_cpu()
    # Load the model and set preprocessing parameters. The mean array is
    # averaged over axes 1 and 2, presumably reducing a (channels, H, W)
    # mean image to one value per channel -- confirm against mean_path.
    net = caffe.Classifier(model_prototxt, model_trained,
                           mean=np.load(mean_path).mean(1).mean(1),
                           channel_swap=(2, 1, 0),
                           raw_scale=255,
                           image_dims=(256, 256))

    # One label per line, indexed by the argmax of the prediction scores.
    with open(imagenet_labels) as f:
        labels = f.readlines()

    # To choose layer_name, inspect the network's blobs (names and shapes):
    #     print([(k, v.data.shape) for k, v in net.blobs.items()])

    # Process one image at a time: print the top prediction and write the
    # feature vector. Opening with 'w' already truncates the output file.
    with open(inputfile, 'r') as reader, open(outputfile, 'w') as writer:
        for image_path in reader:
            image_path = image_path.strip()
            if not image_path:
                # Blank lines (e.g. a trailing newline) are not images.
                continue
            try:
                with open(image_path, 'rb') as fp:
                    cachekey = hashlib.sha224(fp.read()).hexdigest()
                input_image = caffe.io.load_image(image_path)
                prediction = net.predict([input_image], oversample=False)
                scores = prediction[0]
                top = scores.argmax()
                print(os.path.basename(image_path), ' : ',
                      labels[top].strip(), ' (', scores[top], ')')
                feature = net.blobs[layer_name].data[0].ravel()
                featureTxt = ' '.join(map(str, feature.tolist()))
                writer.write('{0}\t{1}\t{2}\n'.format(
                    image_path, cachekey, featureTxt))
            except Exception as e:
                # Best effort: one bad image must not abort the whole run.
                print(e)
                print('ERROR: skip {0}.'.format(image_path))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号