utils.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:adagan 作者: tolstikhin 项目源码 文件源码
def debug_updated_weights(opts, steps, weights, data):
    """Save debug plots describing the updated weights of training points.

    Produces, in this order:
      * a ``metrics_lib.Metrics.make_plots`` call for the ``20 * 16`` points
        with the smallest weights (prefix ``'d_least_'``);
      * the same for the ``20 * 16`` points with the largest weights
        (prefix ``'d_most_'``);
      * a figure ``data_w<steps>.png`` written to ``opts['work_dir']`` with
        the sorted weights curve and, when ``data.labels`` is not None, the
        total weight accumulated by every distinct label.

    The function returns without plotting anything when there are fewer
    than ``20 * 16`` weights.

    Args:
        opts: options mapping; ``opts['work_dir']`` is read here and the
            whole mapping is forwarded to ``make_plots``.
        steps: integer step counter, used in the output file name and
            forwarded to ``make_plots``.
        weights: one weight per training point (any sequence or 1-D array).
        data: object exposing ``num_points``, ``data`` (indexable with a
            list of integer ids) and ``labels`` (array-like or None).

    Raises:
        AssertionError: if ``data.num_points != len(weights)``.
    """
    assert data.num_points == len(weights), 'Length mismatch'
    num_plot = 20 * 16
    if num_plot > len(weights):
        return

    # Point ids ordered by increasing weight (ties broken by id).
    ranked_ids = [idx for _, idx in sorted(zip(weights,
                                               range(len(weights))))]
    for prefix, ids in (('d_least_', ranked_ids[:num_plot]),
                        ('d_most_', ranked_ids[-num_plot:])):
        metrics = metrics_lib.Metrics()
        metrics.make_plots(opts, steps,
                           None, data.data[ids],
                           prefix=prefix)

    plt.clf()
    ax1 = plt.subplot(211)
    ax1.set_title('Weights over data points')
    plt.plot(range(len(weights)), sorted(weights))
    plt.axis([0, len(weights), 0., 2. * np.max(weights)])
    if data.labels is not None:
        # Index through an ndarray so that a plain list of weights works too.
        weights_arr = np.asarray(weights)
        all_labels = np.unique(data.labels)
        w_per_label = np.array(
            [np.sum(weights_arr[np.where(data.labels == y)[0]])
             for y in all_labels],
            dtype=float)
        ax2 = plt.subplot(212)
        ax2.set_title('Weights over labels')
        plt.scatter(range(len(all_labels)), w_per_label, s=30)

    filename = 'data_w{:02d}.png'.format(steps)
    create_dir(opts['work_dir'])
    # Close the handle explicitly so the image is flushed and not leaked.
    # NOTE(review): assumes o_gfile returns a file-like object with
    # .close() -- confirm against its definition.
    out_file = o_gfile((opts["work_dir"], filename), 'wb')
    try:
        plt.savefig(out_file)
    finally:
        out_file.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号