util.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:chess-deep-rl 作者: rajpurkar 项目源码 文件源码
def train(net_type, generator_fn_str, dataset_file, build_net_fn, featurized=True):
    """Build a model and fit it from a dataset generator, checkpointing to disk.

    Args:
        net_type: Network/board variant. Forwarded as ``board`` to the dataset
            loader and generator, and as ``net_type`` to ``build_net_fn``,
            ``plot_model`` and ``get_filename_for_saving``. The value ``'to'``
            selects a different axis when reading the channel count (see below).
        generator_fn_str: Name of the ``Dataset`` attribute to use as the
            training generator; it is looked up with ``getattr``.
        dataset_file: Path prefix; ``'train.pgn'`` and ``'test.pgn'`` are
            appended to it to locate the training and validation files.
        build_net_fn: Callable invoked as
            ``build_net_fn(board_num_channels=..., net_type=...)`` that returns
            the model to train.
        featurized: Forwarded unchanged to both the validation loader and the
            training generator.

    Returns:
        None. The fitted model is not returned; it is persisted only through
        the ``ModelCheckpoint`` callback.
    """
    d = Dataset(dataset_file + 'train.pgn')
    generator_fn = getattr(d, generator_fn_str)
    d_test = Dataset(dataset_file + 'test.pgn')

    # Validation data is loaded up front from the test file, using the same
    # generator name as training; refresh=False is passed through to load().
    X_val, y_val = d_test.load(generator_fn.__name__,
        featurized = featurized,
        refresh    = False,
        board = net_type)

    # The channel count is taken from the first validation sample; for the
    # 'to' net type it is read from axis 1 rather than axis 0.
    board_num_channels = X_val[0].shape[1] if net_type == 'to' else X_val[0].shape[0]
    model = build_net_fn(board_num_channels=board_num_channels, net_type=net_type)

    # Integer timestamp string, shared by the plot and the checkpoint filename.
    start_time = str(int(time.time()))
    try:
        plot_model(model, start_time, net_type)
    except Exception:
        # Plotting is best-effort, so ordinary failures are tolerated. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate instead of being swallowed here.
        print("Skipping plot")

    # Imported locally, as in the original, so keras is only needed when
    # training actually runs.
    from keras.callbacks import ModelCheckpoint
    checkpointer = ModelCheckpoint(
        filepath       = get_filename_for_saving(start_time, net_type),
        verbose        = 2,
        save_best_only = True)

    model.fit_generator(generator_fn(featurized=featurized, board=net_type),
        samples_per_epoch = SAMPLES_PER_EPOCH,
        nb_epoch          = NUMBER_EPOCHS,
        callbacks         = [checkpointer],
        validation_data   = (X_val, y_val),
        verbose           = VERBOSE_LEVEL)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号