def train(net_type, generator_fn_str, dataset_file, build_net_fn, featurized=True):
    """Build a network and train it on a PGN dataset, checkpointing the best model.

    Args:
        net_type: Board/network variant. Forwarded as ``board`` to the dataset
            loader and training generator, and as ``net_type`` to
            ``build_net_fn``. The value ``'to'`` changes which axis of the
            first validation sample is read as the channel count.
        generator_fn_str: Name of a ``Dataset`` method that produces the
            training stream. The same name is passed to ``Dataset.load`` to
            obtain the validation set.
        dataset_file: Path prefix. ``'train.pgn'`` and ``'test.pgn'`` are
            appended to it verbatim (no separator is inserted).
        build_net_fn: Callable accepting ``board_num_channels=`` and
            ``net_type=`` keyword arguments and returning a model that supports
            ``fit_generator``.
        featurized: Forwarded to both the validation loader and the training
            generator.

    Returns:
        None. Training writes checkpoints through ``ModelCheckpoint``
        (``save_best_only=True``) to ``get_filename_for_saving(start_time,
        net_type)``.
    """
    d = Dataset(dataset_file + 'train.pgn')
    generator_fn = getattr(d, generator_fn_str)

    # Load validation data with the same generator name, board type and
    # featurization flag as the training stream below.
    d_test = Dataset(dataset_file + 'test.pgn')
    X_val, y_val = d_test.load(generator_fn.__name__,
                               featurized=featurized,
                               refresh=False,
                               board=net_type)

    # NOTE(review): for net_type 'to' the channel count is taken from axis 1 of
    # X_val[0], otherwise from axis 0 -- presumably a channels-last vs
    # channels-first layout difference; confirm against Dataset.
    board_num_channels = X_val[0].shape[1] if net_type == 'to' else X_val[0].shape[0]
    model = build_net_fn(board_num_channels=board_num_channels, net_type=net_type)

    # One Unix-seconds timestamp is shared by the plot and the checkpoint
    # filename, tying both artefacts to this run.
    start_time = str(int(time.time()))
    try:
        plot_model(model, start_time, net_type)
    except Exception as err:
        # Plotting is optional, so failures stay best-effort. Catching
        # Exception (not a bare except) still lets KeyboardInterrupt and
        # SystemExit propagate, and the reason is reported instead of hidden.
        print("Skipping plot ({}: {})".format(type(err).__name__, err))

    from keras.callbacks import ModelCheckpoint
    # save_best_only=True overwrites the checkpoint only when the monitored
    # quantity improves; validation_data is supplied below, so validation
    # metrics are available to monitor.
    checkpointer = ModelCheckpoint(
        filepath=get_filename_for_saving(start_time, net_type),
        verbose=2,
        save_best_only=True)

    # NOTE(review): samples_per_epoch / nb_epoch are Keras 1.x argument names;
    # Keras 2 renamed them to steps_per_epoch (counted in batches, not
    # samples) and epochs.
    model.fit_generator(generator_fn(featurized=featurized, board=net_type),
                        samples_per_epoch=SAMPLES_PER_EPOCH,
                        nb_epoch=NUMBER_EPOCHS,
                        callbacks=[checkpointer],
                        validation_data=(X_val, y_val),
                        verbose=VERBOSE_LEVEL)
评论列表
文章目录