optimise.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tdlstm 作者: bluemonk482 项目源码 文件源码
def TUNE(args, model, mode, n_calls=5):
    """Search hyperparameters, persist the best setting, and report its scores.

    Each candidate configuration is written onto ``args`` attribute by
    attribute, the network is run once with it, and the configuration with
    the highest ``eval_score[0]`` is kept.  The winning values are written
    to a file named ``checkpoint`` in the directory of
    ``args.checkpoint_file`` and printed together with the matching test
    scores.

    Args:
        args: Attribute container for the run.  Reads ``args.data``,
            ``args.load_data``, ``args.tune``, ``args.checkpoint_file`` and
            ``args.model``.  Its hyperparameter attributes are overwritten
            for every candidate and keep the values of the last candidate
            evaluated once this function returns.
        model: Passed through unchanged to ``run_network``.
        mode: ``'rand'`` draws candidates with ``random_search``; any other
            value enumerates them with ``expand_grid``.
        n_calls: Passed to ``random_search``; unused in grid mode.

    Raises:
        ValueError: If the sampler produced no configuration at all.
    """
    hyperparameters_all = {
        'batch_size': range(40, 130, 20),
        'seq_len': [42],
        # Ten random hidden sizes drawn from [100, 500] on every call.
        'num_hidden': np.random.randint(100, 501, 10),
        'learning_rate': [0.0005],
        'dropout_output': np.arange(0.3, 1.1, 0.1),
        'dropout_input': np.arange(0.3, 1.1, 0.1),
        'clip_norm': np.arange(0.5, 1.01, 0.1),
    }

    data = load_data(args, args.data, saved=args.load_data)
    if mode == 'rand':
        samp = random_search(hyperparameters_all, n_calls)  # random search
    else:
        samp = expand_grid(hyperparameters_all)  # grid-search

    maxx = None
    best_score = None
    hyperparameters_best = None
    for hyperparameters in samp:
        print("Evaluating hyperparameters:", hyperparameters)
        for attr, value in hyperparameters.items():
            setattr(args, attr, value)
        scores = run_network(args, data, model, tuning=args.tune)
        test_score, eval_score = scores
        # The first candidate is always accepted, so a best configuration
        # exists even when every evaluation score is zero or negative.
        if hyperparameters_best is None or eval_score[0] > maxx:
            maxx = eval_score[0]
            best_score = test_score
            # Copy: the sampler may reuse or mutate the dict it yielded.
            hyperparameters_best = dict(hyperparameters)
        # Drop the graph built for this candidate before building the next.
        tf.reset_default_graph()

    if hyperparameters_best is None:
        # Fail before the checkpoint file is opened (and truncated).
        raise ValueError("No hyperparameter configurations were evaluated.")

    print()
    print("Optimisation finished..")
    print("Optimised hyperparameters:")
    # os.path.join keeps the path relative when checkpoint_file has no
    # directory part; plain concatenation would target '/checkpoint'.
    checkpoint_path = os.path.join(
        os.path.dirname(args.checkpoint_file), 'checkpoint')
    with open(checkpoint_path, 'w') as fp:
        fp.write('%s:"%s"\n' % ('model', args.model))
        for attr, value in sorted(hyperparameters_best.items()):
            print("{}={}".format(attr.upper(), value))
            fp.write('%s:"%s"\n' % (attr, value))
    print()
    print("Final Test Data Accuracy = {:.5f}; 3-class F1 = {:.5f}; 2-class F1 = {:.5f}"
                          .format(best_score[0], best_score[1], best_score[2]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号