def debug_updated_weights(opts, steps, weights, data):
    """Save debug plots for the updated weights of the training points.

    Everything happens through side effects; nothing is returned:

    * ``Metrics.make_plots`` is called with prefix ``'d_least_'`` on the
      320 points with the smallest weights,
    * ``Metrics.make_plots`` is called with prefix ``'d_most_'`` on the
      320 points with the largest weights,
    * ``data_w<steps>.png`` is written into ``opts['work_dir']``. It shows
      the weights in increasing order and, when ``data.labels`` is not
      None, the total weight assigned to each distinct label.

    If there are fewer than 320 weights the function returns right away
    and produces none of the plots above.

    Args:
        opts: Options mapping. ``opts['work_dir']`` is read here; the whole
            mapping is also forwarded to ``make_plots``.
        steps: Integer step counter, forwarded to ``make_plots`` and
            formatted with ``{:02d}`` into the output file name.
        weights: One weight per training point. Converted with
            ``np.asarray`` so that plain sequences are accepted too.
        data: Object exposing ``num_points``, ``data`` (indexed with a list
            of integer ids) and ``labels`` (an array, or None).

    Raises:
        AssertionError: If ``data.num_points`` differs from
            ``len(weights)``.
    """
    assert data.num_points == len(weights), 'Length mismatch'
    # The per-label sums below index `weights` with an integer array, which
    # a plain list does not support.
    weights = np.asarray(weights)
    num_points = len(weights)

    num_plot = 20 * 16
    if num_plot > num_points:
        return

    # Point ids ordered by increasing weight. `sorted` is stable, so equal
    # weights keep their original index order.
    order = sorted(range(num_points), key=weights.__getitem__)

    for prefix, ids in (('d_least_', order[:num_plot]),
                        ('d_most_', order[-num_plot:])):
        plot_points = data.data[ids]
        metrics = metrics_lib.Metrics()
        metrics.make_plots(opts, steps,
                           None, plot_points,
                           prefix=prefix)

    plt.clf()
    ax1 = plt.subplot(211)
    ax1.set_title('Weights over data points')
    plt.plot(range(num_points), weights[order])
    plt.axis([0, num_points, 0., 2. * np.max(weights)])

    if data.labels is not None:
        all_labels = np.unique(data.labels)
        w_per_label = np.array(
            [np.sum(weights[np.where(data.labels == y)[0]])
             for y in all_labels],
            dtype=float)
        ax2 = plt.subplot(212)
        ax2.set_title('Weights over labels')
        plt.scatter(range(len(all_labels)), w_per_label, s=30)

    filename = 'data_w{:02d}.png'.format(steps)
    create_dir(opts['work_dir'])
    out_file = o_gfile((opts["work_dir"], filename), 'wb')
    try:
        # A file object carries no extension for matplotlib to inspect, so
        # state the format explicitly to match the '.png' file name.
        plt.savefig(out_file, format='png')
    finally:
        # Close the handle so the image is flushed and not leaked.
        # NOTE(review): assumes o_gfile returns a file-like object with
        # close() -- confirm against its definition.
        out_file.close()
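# NOTE: trailing web-page navigation text ("comment list", "article table
# of contents") left over from the page this snippet was copied from; it is
# not part of the program.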