def _env_flag(name, default=False):
    """Return the boolean value of environment variable *name*.

    An unset variable yields *default*. The values "", "0", "false", "no"
    and "off" (case-insensitive, surrounding whitespace ignored) yield False;
    any other value yields True. This avoids the pitfall of testing the raw
    string, where "0" or "false" would be truthy.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in ("", "0", "false", "no", "off")


def get_symbol_classifications(symbols):
    """Classify each symbol image as "jar" or "no-jar" with the Caffe model.

    The network definition (``deploy.prototxt``) and weights
    (``weights.caffemodel``) are loaded from the directory named by the
    ``JAR_NOJAR_MODELS_DIR`` environment variable. Caffe runs on GPU device 0
    when ``IS_GPU`` is set to a truthy value (see ``_env_flag``), otherwise on
    the CPU.

    Args:
        symbols: sequence of items whose element ``[1]`` is the image handed
            to ``caffe.Classifier.predict`` (presumably an HxWxC array in
            [0, 1] given ``raw_scale=255.0`` — confirm against callers).

    Returns:
        A list with one ``[symbol, label]`` entry per input, in input order,
        where ``symbol`` is the original element of *symbols* and ``label`` is
        ``"jar"`` when class 1 scores highest and ``"no-jar"`` otherwise
        (ties resolve to class 0, "no-jar").

    Raises:
        KeyError: if ``JAR_NOJAR_MODELS_DIR`` is not set (and input is
            non-empty).
    """
    # caffe.Classifier.predict indexes inputs[0], so an empty batch would
    # crash; also avoids loading the model for no work.
    if len(symbols) == 0:
        return []

    if _env_flag("IS_GPU"):
        caffe.set_device(0)
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()

    models_dir = os.environ["JAR_NOJAR_MODELS_DIR"]
    classifier = caffe.Classifier(
        os.path.join(models_dir, "deploy.prototxt"),
        os.path.join(models_dir, "weights.caffemodel"),
        image_dims=[64, 64],
        raw_scale=255.0,
    )

    LOGGER.info("Classifying %d inputs.", len(symbols))
    predictions = classifier.predict([s[1] for s in symbols])

    # Position in this tuple is the class id produced by the network.
    labels = ("no-jar", "jar")
    # argmax returns the first maximal index, matching the original
    # argsort-based selection (ties -> class 0).
    return [
        [symbol, labels[int(np.argmax(scores))]]
        for symbol, scores in zip(symbols, predictions)
    ]
symbol_classification.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录