def debug_mixture_classifier(opts, step, probs, points, num_plot=320, real=True):
    """Log and plot the most and least confident outputs of the mixture classifier.

    Points are ranked by their classifier probability.

    - For real points (``real=True``), the ``num_plot`` highest-probability
      points are treated as correctly classified. The ``num_plot``
      lowest-probability points are treated as misclassified.
    - For fake points (``real=False``), the two roles are swapped.

    The probabilities of both groups are logged at DEBUG level. Both groups
    are plotted with ``metrics_lib.Metrics().make_plots``.

    Args:
        opts: Options object, passed unchanged to ``make_plots``.
        step: Step value, passed unchanged to ``make_plots``.
        probs: Sized, indexable sequence of classifier outputs, one per point.
            Each entry may be a scalar or a length-1 sequence/array. For the
            latter, its first element is what gets logged.
        points: Points in the same order as ``probs``. Must support indexing
            with a list of integer ids (e.g. a NumPy array).
        num_plot: Number of points in each of the two plotted groups.
        real: True if ``points`` are real; False if they are fake.

    Returns:
        None. Returns early, without logging the groups or plotting, when:

        - ``probs`` and ``points`` differ in length;
        - ``num_plot`` is not positive;
        - there are fewer than ``2 * num_plot`` points, so the two groups
          would overlap.
    """
    num = len(points)
    if len(probs) != num:
        logging.debug('debug_mixture_classifier: got %d probs for %d points; '
                      'skipping.', len(probs), num)
        return
    if num_plot <= 0 or num < 2 * num_plot:
        # The two groups must be non-empty and disjoint. Note that a slice
        # such as seq[-0:] would select everything, not nothing.
        return

    # sorted() is stable, so ties keep ascending-index order. This matches
    # the ordering of a sort over (prob, index) pairs.
    order = sorted(range(num), key=lambda i: probs[i])
    lowest = order[:num_plot]
    highest = order[num - num_plot:]
    if real:
        correct_ids, wrong_ids = highest, lowest
    else:
        correct_ids, wrong_ids = lowest, highest

    idstring = 'real' if real else 'fake'
    logging.debug('Correctly classified %s points probs:', idstring)
    logging.debug([_first_value(probs[i]) for i in correct_ids])
    logging.debug('Incorrectly classified %s points probs:', idstring)
    logging.debug([_first_value(probs[i]) for i in wrong_ids])

    metrics = metrics_lib.Metrics()
    metrics.make_plots(opts, step,
                       None, points[correct_ids],
                       prefix='c_%s_correct_' % idstring)
    metrics.make_plots(opts, step,
                       None, points[wrong_ids],
                       prefix='c_%s_wrong_' % idstring)


def _first_value(val):
    """Return ``val[0]`` for indexable values, or ``val`` itself for scalars."""
    try:
        return val[0]
    except (TypeError, IndexError):
        # Python floats raise TypeError; 0-d NumPy scalars raise IndexError.
        return val
# Web-page residue removed from code ("评论列表" = "comment list", "文章目录" = "table of contents").