def _clip_box(box, height, width):
    """Clamp the first four entries of *box* (x1, y1, x2, y2) to the frame.

    Detector boxes may extend past the image border; a negative slice start
    would silently wrap around in NumPy, so coordinates are clamped to
    [0, width] / [0, height] before cropping. Returns an int 4-tuple.
    """
    x1, y1, x2, y2 = (int(v) for v in box[:4])
    x1 = min(max(x1, 0), width)
    x2 = min(max(x2, 0), width)
    y1 = min(max(y1, 0), height)
    y2 = min(max(y2, 0), height)
    return x1, y1, x2, y2


def main(args):
    """Run live face recognition on the default webcam until 'q' is pressed.

    Two TensorFlow graphs/sessions are built: one for MTCNN detection
    (loaded from ``args.caffe_model_dir``) and one for the identity
    classifier (restored from ``args.ckpt_dir`` using the model description
    returned by ``_get_classifier_model_info(args.model_version)``).
    For every captured frame, each detected face is cropped, resized to the
    classifier's input size and classified. The candidate name is looked up
    through ``db_utils``, and a box plus label is drawn on the displayed frame.

    Args:
        args: parsed command-line namespace; reads ``model_version``,
            ``caffe_model_dir``, ``ckpt_dir``, ``threshold`` and ``factor``.
    """
    infos = _get_classifier_model_info(args.model_version)
    # Use the same input size the classifier was built with, instead of a
    # separately hard-coded value that could drift from it.
    image_size = int(infos['image_size'])

    # Detection and classification are loaded into separate graphs, each
    # with its own session.
    with tf.Graph().as_default():
        detect_sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
        with detect_sess.as_default():
            pnet, rnet, onet = mtcnn.create_mtcnn(detect_sess, args.caffe_model_dir)
    with tf.Graph().as_default():
        classify_sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
        with classify_sess.as_default():
            recognize = csair_classifier.create_classifier(
                classify_sess,
                model_def=infos['model_def'],
                image_size=image_size,
                embedding_size=int(infos['embedding_size']),
                nrof_classes=int(infos['nrof_classes']),
                ckpt_dir=args.ckpt_dir)

    font = cv2.FONT_HERSHEY_TRIPLEX
    conn = db_utils.open_connection()
    cap = cv2.VideoCapture(0)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                # Camera unavailable or stream ended; frame would be None.
                break
            # NOTE(review): OpenCV frames are BGR; confirm the detector and
            # classifier were trained on BGR input rather than RGB.
            bounding_boxes, _points = mtcnn.detect_face(
                frame, 20, pnet, rnet, onet, args.threshold, args.factor)
            height, width = frame.shape[:2]

            # Classify every face before drawing anything, so each crop comes
            # from the clean frame rather than one already annotated.
            detections = []
            for raw_box in bounding_boxes:
                x1, y1, x2, y2 = _clip_box(raw_box, height, width)
                if x2 <= x1 or y2 <= y1:
                    continue  # box lies entirely outside the frame
                face = cv2.resize(frame[y1:y2, x1:x2], (image_size, image_size),
                                  interpolation=cv2.INTER_CUBIC)
                value, index = csair_classifier.classify(np.expand_dims(face, 0), recognize)
                name = db_utils.get_candidate_info(conn, int(index[0][0]))[0]
                text = 'person: ' + name + ' probability: ' + str(value[0][0])
                detections.append(((x1, y1, x2, y2), text))

            for (x1, y1, x2, y2), text in detections:
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
                cv2.putText(frame, text, (x1, y1), font, 0.42, (255, 255, 0))

            cv2.imshow('frame', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release everything even if detection/classification raised.
        cap.release()
        cv2.destroyAllWindows()
        db_utils.close_connection(conn)
        classify_sess.close()
        detect_sess.close()
评论列表
文章目录