symbol_classification.py source code

python

Project: indus-script-ocr   Author: tpsatish95
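import logging
import os

import caffe
import numpy as np

LOGGER = logging.getLogger(__name__)


def get_symbol_classifications(symbols):
    """Label each symbol as "jar" or "no-jar".

    Each element of `symbols` must hold its image at index 1. Returns a list of
    [symbol, label] pairs in input order.
    """
    # Treat IS_GPU as a boolean flag: unset, empty, "0", "false" or "no" selects
    # the CPU, and any other value selects GPU 0. (A plain truthiness test on the
    # string would send "0" or "False" to the GPU, and a bare lookup raises KeyError when unset.)
    if os.environ.get("IS_GPU", "").strip().lower() not in ("", "0", "false", "no"):
        caffe.set_device(0)
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()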

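    models_dir = os.environ["JAR_NOJAR_MODELS_DIR"]
    # Inputs are resized to 64x64, and raw_scale=255.0 maps [0, 1] pixel values
    # to the [0, 255] range the network expects.
    classifier = caffe.Classifier(os.path.join(models_dir, "deploy.prototxt"),
                                  os.path.join(models_dir, "weights.caffemodel"),
                                  image_dims=[64, 64],
                                  raw_scale=255.0)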

    LOGGER.info("Classifying %d inputs.", len(symbols))

    # predict() returns one row of class scores per image:
    # column 0 is "no-jar", column 1 is "jar".
    predictions = classifier.predict([s[1] for s in symbols])

    labels = {0: "no-jar", 1: "jar"}
    symbol_sequence = []

    for symbol, prediction in zip(symbols, predictions):
        # The highest-scoring class index is the predicted label.
        top_class = int(np.argmax(prediction))
        symbol_sequence.append([symbol, labels[top_class]])

    return symbol_sequence
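The label-decoding step can be tried without Caffe or trained weights. The sketch below makes several assumptions. `decode_jar_labels` and `top_class_via_argsort` are hypothetical helpers. The score matrix is made-up data standing in for the (N, 2) output of `classifier.predict`, and column 0 is read as "no-jar" and column 1 as "jar", as in the function above. The sketch also checks that ranking classes with `argsort` (the original loop's approach) selects the same top class as `np.argmax`.

```python
import numpy as np

# Hypothetical label order mirroring get_symbol_classifications:
# column 0 -> "no-jar", column 1 -> "jar".
LABELS = ("no-jar", "jar")


def decode_jar_labels(symbols, predictions):
    """Pair each symbol with the label of its highest-scoring class (hypothetical helper)."""
    predictions = np.asarray(predictions)
    top = predictions.argmax(axis=1)  # index of the best class in each row
    return [[symbol, LABELS[int(k)]] for symbol, k in zip(symbols, top)]


def top_class_via_argsort(prediction):
    """Top class via the original approach: rank class indices by descending score."""
    classes = np.array([0, 1])
    ranked = classes[(-np.asarray(prediction)).argsort()]
    return int(ranked[0])


if __name__ == "__main__":
    # Made-up scores standing in for classifier.predict output (one row per symbol).
    scores = np.array([[0.9, 0.1],
                       [0.2, 0.8],
                       [0.4, 0.6]])
    symbols = [("a", None), ("b", None), ("c", None)]  # hypothetical (id, image) pairs
    for symbol, label in decode_jar_labels(symbols, scores):
        print(symbol[0], label)  # a no-jar / b jar / c jar
    # Both strategies agree on the top class for every row.
    assert all(top_class_via_argsort(row) == int(row.argmax()) for row in scores)
```

`argmax` returns only the top class, which is the only element of the full ranking that the original loop uses.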