def evaluate(graph, mels, label, mapping):
    """Check whether the graph classifies a single audio sample correctly.

    Args:
        graph: TensorFlow graph exposing the tensors ``prefix/input:0`` (fed
            the features) and ``prefix/softmax_tensor:0`` (fetched as the
            class scores). No variables are initialised or restored here, so
            the graph must already carry everything it needs (e.g. a frozen
            graph_def).
        mels: Feature array for one sample (presumably a mel spectrogram --
            TODO confirm). It is flattened into one float32 vector before
            being fed to the model.
        label: Key looked up in ``mapping`` to obtain the expected class.
        mapping: Maps ``label`` to the class index that is compared against
            the arg-max of the model output.

    Returns:
        True if the highest-scoring class equals ``mapping[label]``,
        False otherwise.

    Raises:
        KeyError: If ``label`` is missing from a dict-like ``mapping``.
    """
    logging.info('Evaluating audio classification')
    # Flatten straight to a float32 vector without a per-element Python list.
    audio_feature = np.asarray(mels, dtype=np.float32).ravel()
    true_result = mapping[label]

    x = graph.get_tensor_by_name('prefix/input:0')
    y = graph.get_tensor_by_name('prefix/softmax_tensor:0')
    with tf.Session(graph=graph) as sess:
        # Note: we didn't initialize/restore anything, everything is stored in the graph_def
        # Wrap the single sample in a list so the model sees a batch of one.
        y_out = sess.run(y, feed_dict={x: [audio_feature]})

    # Scores of the only sample in the batch; compute the winning class once.
    predicted = y_out[0].argmax()
    logging.info('true value:%s', true_result)
    logging.info('predicted value:%s', predicted)
    logging.info('predictions:%s', y_out)
    return bool(predicted == true_result)
评论列表
文章目录