def inference(predictions_op, true_labels_op, display, sess):
    """Perform inference on one batch with the pre-trained model.

    Evaluates the prediction and true-label ops in a single ``sess.run``
    call, converts both results to labels with ``sparse_to_labels`` and
    computes the character error rate (CER) of every utterance as the edit
    distance between reference and prediction divided by the length of the
    reference.

    Args:
        predictions_op: Prediction op.
        true_labels_op: True labels op.
        display: Print sample predictions if true.
        sess: Session used to evaluate the ops.

    Returns:
        List of CER values, one per utterance of the batch. An utterance
        whose reference label is empty is scored as if the reference had
        length 1 (so its CER equals the raw edit distance) instead of
        raising ZeroDivisionError.
    """
    # Perform inference of a batch worth of data at a time.
    predictions, true_labels = sess.run([predictions_op, true_labels_op])

    # NOTE(review): predictions[0][0] presumably selects the first decoded
    # path of the decoder output -- confirm against the op definition.
    pred_label = sparse_to_labels(predictions[0][0])
    actual_label = sparse_to_labels(true_labels)

    # Pair every reference with its prediction once; both the scoring and
    # the optional display below work on these pairs.
    pairs = list(zip(actual_label, pred_label))

    # max(..., 1) only matters for an empty reference label, where a plain
    # division would raise ZeroDivisionError and abort the whole batch.
    char_err_rate = [distance(label, pred) / max(len(label), 1)
                     for label, pred in pairs]

    if display:
        # Print sample responses. Slicing (rather than indexing up to
        # ARGS.batch_size) keeps a batch shorter than the configured size
        # from raising IndexError.
        for label, pred in pairs[:ARGS.batch_size]:
            print(label + ' vs ' + pred)

    return char_err_rate
# Web-page residue, not part of the program: "Comment list" / "Table of contents".