frontier2.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:learning-to-prune 作者: timvieira 项目源码 文件源码
def learning_curve_handler(_args, args, log, jobid):
    """Plot one job's train/dev learning curves and optionally mark it for killing.

    Each curve is a ``*_new_policy_reward`` column of ``log`` plotted against
    ``log['elapsed']``. A scatter of the running maxima returned by
    ``running_max`` is overlaid on it. A curve is skipped when its column is
    absent from ``log``.

    Args:
        _args: command-line options. Reads ``show_each``: when true, the job
            is drawn in its own new figure and ``pl.show()`` is called.
            Reads ``kill_mode``: when true, the user is prompted and, on an
            answer starting with ``'y'``, ``jobid`` is appended to the global
            ``KILL`` list.
        args: the job's hyperparameter record. When ``show_each`` is set it
            is indexed by ``'args_roll_out'``, ``'args_accuracy'`` and
            ``'args_runtime'``. Otherwise ``args.get('args_C')`` selects
            which shared axes in ``AX`` to draw on.
        log: the job's training log. It must support ``log['elapsed']`` and
            ``log.get(column)``; presumably a pandas DataFrame -- TODO confirm.
        jobid: identifier shown in the figure title and appended to ``KILL``.

    Raises:
        ValueError: in ``show_each`` mode, if any of the three hyperparameters
            above has more than one distinct value (the unpacking fails).
    """
    # Note: iterations appear to start at 1.
    if _args.show_each:
        ax_train = ax_dev = pl.figure().add_subplot(111)

        # Unpacking into one-element lists doubles as an assertion that each
        # hyperparameter takes exactly one distinct value for this job.
        [[ro], [acc], [run]] = [np.unique(args['args_roll_out']),
                                np.unique(args['args_accuracy']),
                                np.unique(args['args_runtime'])]

        ax_train.set_title('jobid: %s. ro/acc/run %s/%s/%s'
                           % (jobid, ro, acc, run))
    else:
        # Group learning curves by regularizer: jobs that share the same
        # `args_C` value draw onto the same pair of axes.
        c_value = args.get('args_C')
        ax_train = AX['trainlc-%s' % c_value]
        ax_dev = AX['devlc-%s' % c_value]
        ax_train.set_title('lc train %s' % c_value)
        ax_dev.set_title('lc dev %s' % c_value)

    # x-axis is elapsed time (use log.iteration instead to plot by iteration).
    # NOTE(review): the label assumes 'elapsed' is measured in days -- confirm.
    X = log['elapsed']
    ax_train.set_xlabel('days')
    ax_dev.set_xlabel('days')

    for ax, column, color in [(ax_train, 'train_new_policy_reward', 'b'),
                              (ax_dev, 'dev_new_policy_reward', 'r')]:
        # Guard on the column that is actually plotted. The original guard
        # tested '*_accuracy' and crashed on logs that lack the reward column.
        values = log.get(column)
        if values is None:
            continue
        ax.plot(X, values, alpha=1, c=color)
        maxes = running_max(list(X), list(values))
        ax.scatter(maxes[:, 0], maxes[:, 1], lw=0)

    if _args.show_each:
        pl.ioff()
        pl.show()

    if _args.kill_mode:
        if raw_input('kill?').startswith('y'):
            KILL.append(jobid)
            # A single formatted string prints the same under Python 2 and 3.
            print('KILL %s' % ' '.join(KILL))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号