import os
import pickle

import numpy as np
import matplotlib.pyplot as plt
import keras


def train(trainset, valset, path_data, path_session, hyper_param):
"""Execute a single training task.
Returns:
model: /path/to/best_model as measured by validation's loss
loss: the loss computed on the validation set
acc: the accuracy computed on the validation set
"""
    session_id = os.path.basename(path_session)
    model_cp = keras.callbacks.ModelCheckpoint(
        os.path.join(path_session, "{}_model.hdf5".format(session_id)),
        monitor="val_loss",
        save_best_only=True)
    # train
    model = _get_model(hyper_param["optimizer"],
                       hyper_param["batch_norm"],
                       pool_type=hyper_param["pool_type"],
                       dropout_rate=hyper_param["dropout_rate"])
    history = model.fit_generator(
        _sample_generator(trainset, path_data, hyper_param["batch_sz"]),
        steps_per_epoch=int(len(trainset) / hyper_param["batch_sz"]),
        epochs=hyper_param["epochs"],
        validation_data=_sample_generator(valset, path_data, 2),
        validation_steps=int(len(valset) / 2),
        callbacks=[model_cp, hyper_param["lr_schedule"]],
        verbose=1,
        workers=4)
    # plot training curves
    def plot_history(metric):
        plt.ioff()
        str_metric = "accuracy" if metric == "acc" else "loss"
        plt.plot(history.history[metric])
        plt.plot(history.history["val_{}".format(metric)])
        plt.title("model {}".format(str_metric))
        plt.ylabel(str_metric)
        plt.xlabel("epoch")
        plt.legend(["train", "validation"], loc="upper left")
        plt.savefig(os.path.join(path_session,
                                 "{}_{}.png".format(session_id, str_metric)))
plot_history("loss")
plt.cla()
plot_history("acc")
with open(os.path.join(path_session,
"{}_history.pkl".format(session_id)),
'wb') as output:
pickle.dump(history.history, output, pickle.HIGHEST_PROTOCOL)
    # output model and performance measures
    # return the same path the ModelCheckpoint callback wrote to above
    ind_min_loss = np.argmin(history.history["val_loss"])
    return (os.path.join(path_session, "{}_model.hdf5".format(session_id)),
            history.history["val_loss"][ind_min_loss],
            history.history["val_acc"][ind_min_loss])