def predict(limit):
    """Run the trained model on dataset samples and print its predictions.

    Builds a ``TrainingData`` from the module-level ``LABEL_FILE``,
    ``IMAGES_ROOT``, ``MEAN_IMAGE_FILE`` and ``IMAGE_PROP`` settings.
    Constructs an ``alex.Alex`` network sized by the label definition read
    from ``LABEL_DEF_FILE`` and restores its weights from ``MODEL_FILE``.
    For each ``(arr, im)`` pair yielded by ``td.generate()`` it then:

    * prints the predicted and actual label;
    * opens the sample image via ``im.image.show()``.

    Args:
        limit: Maximum number of samples to evaluate. ``None`` or any value
            ``<= 0`` falls back to 5 samples.
    """
    # Missing or non-positive limits fall back to 5 samples.
    # (Checking for None first avoids a TypeError on Python 3.)
    _limit = limit if limit is not None and limit > 0 else 5

    td = TrainingData(LABEL_FILE, img_root=IMAGES_ROOT,
                      mean_image_file=MEAN_IMAGE_FILE,
                      image_property=IMAGE_PROP)
    label_def = LabelingMachine.read_label_def(LABEL_DEF_FILE)
    # presumably len(label_def) is the number of output classes -- confirm
    # against alex.Alex's constructor.
    model = alex.Alex(len(label_def))
    serializers.load_npz(MODEL_FILE, model)

    for count, (arr, im) in enumerate(td.generate(), start=1):
        # Prepend a batch axis of size 1.
        # .copy() gives a fresh C-contiguous array of the same dtype,
        # matching the original explicit allocation + assignment.
        batch = arr[np.newaxis].copy()
        # NOTE(review): ``volatile`` is the Chainer v1 way to skip graph
        # construction at inference time. Chainer v2+ removed it in favour
        # of ``chainer.no_backprop_mode()`` / ``using_config('train', False)``.
        # Confirm the pinned Chainer version before upgrading.
        x = chainer.Variable(batch, volatile="on")
        y = model.predict(x)
        # Batch size is 1, so the flat argmax is the predicted class index.
        p = np.argmax(y.data)
        print("predict {0}, actual {1}".format(label_def[p], label_def[im.label]))
        im.image.show()
        if count >= _limit:
            break
评论列表
文章目录