myalexnet_feature.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:visual-search 作者: GYXie 项目源码 文件源码
def main():
    """Extract ``fc6`` features for every image in ``args.input_data_dir``.

    Builds the graph via ``initModel()`` and writes two files:
      * ``args.output_image_name_file``: the image names, one per line.
      * ``args.output_feature_file``: a single ``.npy`` array made by
        vertically stacking the ``fc6`` outputs of each batch, in the same
        order as the names file (presumably one row per image -- depends
        on ``initModel``/``load_images``).

    Prints the index of each batch as it is processed. At the end it prints
    the elapsed seconds for inference plus saving.

    Raises:
        ValueError: if ``load_image_names`` returns no images.
    """
    x, fc6 = initModel()
    init = tf.global_variables_initializer()
    # Number of images fed through the network per sess.run call.
    batch_size = 100

    img_names = load_image_names(args.input_data_dir)
    # len() rather than truthiness so this also works if a NumPy array is
    # returned. Without this check, np.vstack([]) fails later with an
    # unhelpful message.
    if len(img_names) == 0:
        raise ValueError('no images found in %s' % args.input_data_dir)

    with open(args.output_image_name_file, 'w') as img_names_file:
        img_names_file.writelines(name + '\n' for name in img_names)

    # The context manager closes the session and releases its resources
    # even if inference fails part-way through.
    with tf.Session() as sess:
        sess.run(init)

        t = time.time()
        # Run the images through the network in fixed-size batches. The
        # slice clamps at the end of the list, so the last, possibly
        # shorter, batch needs no special case.
        features = []
        for batch_idx, start in enumerate(range(0, len(img_names), batch_size)):
            print('batch: %d' % batch_idx)
            img_batch = load_images(img_names[start:start + batch_size])
            features.append(sess.run(fc6, feed_dict={x: img_batch}))
        features = np.vstack(features)

    # np.save writes bytes, so the file must be opened in binary mode.
    # Text mode raises TypeError on Python 3. The file is opened only after
    # inference succeeds, so a failed run does not leave a truncated file.
    with open(args.output_feature_file, 'wb') as output_file:
        np.save(output_file, features)

    print(time.time() - t)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号