classifier.py source code

python

Project: kaggle_dsb    Author: syagev
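import os
import pickle

import keras
import matplotlib.pyplot as plt
import numpy as np

# _get_model and _sample_generator are helpers defined elsewhere in classifier.py.


def train(trainset, valset, path_data, path_session, hyper_param):
    """Execute a single training task.

    Args:
        trainset: training samples, fed to _sample_generator in batches.
        valset: validation samples, fed to _sample_generator in batches of 2.
        path_data: data path passed through to _sample_generator.
        path_session: output directory; its basename is used as the session id.
        hyper_param: dict with the keys "optimizer", "batch_norm", "pool_type",
            "dropout_rate", "batch_sz", "epochs" and "lr_schedule" (a Keras callback).

    Returns:
        model: path to the checkpoint of the best model, i.e. the one with the
            lowest validation loss
        loss: the validation loss of that model
        acc: the validation accuracy of that model
    """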

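    session_id = os.path.basename(path_session)
    # Save the model only when the validation loss improves, so the file on
    # disk always holds the best epoch seen so far.
    model_cp = keras.callbacks.ModelCheckpoint(
        os.path.join(path_session, "{}_model.hdf5".format(session_id)),
        monitor="val_loss",
        save_best_only=True)
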
    # train
    model = _get_model(hyper_param["optimizer"],
                       hyper_param["batch_norm"],
                       pool_type=hyper_param["pool_type"],
                       dropout_rate=hyper_param["dropout_rate"])
    history = model.fit_generator(
        _sample_generator(trainset, path_data, hyper_param["batch_sz"]),
        # one epoch = as many full batches as fit in the training set
        steps_per_epoch=len(trainset) // hyper_param["batch_sz"],
        epochs=hyper_param["epochs"],
        # validation runs in fixed batches of 2 samples
        validation_data=_sample_generator(valset, path_data, 2),
        validation_steps=len(valset) // 2,
        callbacks=[model_cp, hyper_param["lr_schedule"]],
        verbose=1,
        workers=4)

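    # plot training curves, one figure per metric
    def plot_history(metric):
        plt.ioff()
        str_metric = "accuracy" if metric == "acc" else "loss"
        # a fresh figure per call keeps curves from earlier plots (or from
        # earlier train() calls in the same process) out of this one
        fig = plt.figure()
        plt.plot(history.history[metric])
        plt.plot(history.history["val_{}".format(metric)])
        plt.title("model {}".format(str_metric))
        plt.ylabel(str_metric)
        plt.xlabel("epoch")
        # the second curve is the validation set, not a held-out test set
        plt.legend(["train", "validation"], loc="upper left")
        plt.savefig(os.path.join(path_session,
                                 "{}_{}.png".format(session_id, str_metric)))
        plt.close(fig)

    plot_history("loss")
    # newer Keras versions (2.3 onwards) log this metric as "accuracy" instead of "acc"
    plot_history("acc")

    # save the raw per-epoch metrics for later analysis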
    with open(os.path.join(path_session,
                           "{}_history.pkl".format(session_id)),
              'wb') as output:
        pickle.dump(history.history, output, pickle.HIGHEST_PROTOCOL)

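    # output model and performance measures at the epoch with the lowest
    # validation loss, i.e. the epoch ModelCheckpoint saved; the path must
    # match the filename given to ModelCheckpoint above
    ind_min_loss = np.argmin(history.history["val_loss"])
    return (os.path.join(path_session, "{}_model.hdf5".format(session_id)),
            history.history["val_loss"][ind_min_loss],
            history.history["val_acc"][ind_min_loss])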
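The helper _sample_generator is not shown on this page. The sketch below is a hypothetical stand-in that illustrates the contract fit_generator relies on: the generator loops forever, and steps_per_epoch tells Keras how many batches make up one epoch. The names sample_generator, steps_per_epoch and load_fn, and the (input, label) pair that load_fn returns, are assumptions for illustration, not the project's code. Only full batches are yielded, so one pass over the shuffled samples is exactly len(samples) // batch_size steps, which matches the steps_per_epoch computation in train.

```python
import random
from typing import Callable, Iterator, List, Sequence, Tuple


def sample_generator(samples: Sequence, batch_size: int, load_fn: Callable,
                     seed: int = 0) -> Iterator[Tuple[List, List]]:
    """Hypothetical stand-in for _sample_generator.

    Yields (inputs, labels) batches forever, reshuffling on every pass.
    load_fn(sample) is assumed to return an (input, label) pair.
    """
    if batch_size <= 0 or batch_size > len(samples):
        raise ValueError("batch_size must be between 1 and len(samples)")
    rng = random.Random(seed)
    order = list(samples)
    while True:  # Keras-style training generators must never stop on their own
        rng.shuffle(order)
        # Only full batches are emitted: one pass == len(samples) // batch_size steps.
        for start in range(0, len(order) - batch_size + 1, batch_size):
            pairs = [load_fn(s) for s in order[start:start + batch_size]]
            inputs = [x for x, _ in pairs]
            labels = [y for _, y in pairs]
            yield inputs, labels


def steps_per_epoch(n_samples: int, batch_size: int) -> int:
    """Number of full batches per epoch, as computed in train()."""
    return n_samples // batch_size


if __name__ == "__main__":
    samples = list(range(10))
    gen = sample_generator(samples, 4, load_fn=lambda s: (s, s % 2))
    for _ in range(steps_per_epoch(len(samples), 4)):
        print(next(gen))
```

With 10 samples and batch_sz=4, each epoch consumes 2 batches of 4; the 2 samples left over in that pass are skipped until a later reshuffle places them in a full batch.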
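The return value of train picks the epoch with the lowest validation loss via np.argmin. Below is a stdlib-only sketch of the same selection as a hypothetical helper, best_epoch, that works on the history dict train pickles to {session_id}_history.pkl. Ties go to the earliest epoch, which is also the epoch a ModelCheckpoint with save_best_only=True keeps, because it only overwrites the file on a strict improvement. The fallback to the "val_accuracy" key is an assumption meant to cover newer Keras versions, which log accuracy under that name instead of "val_acc".

```python
from typing import Dict, List, Optional, Tuple


def best_epoch(history: Dict[str, List[float]],
               monitor: str = "val_loss",
               report: Optional[str] = None) -> Tuple[int, float, float]:
    """Return (epoch index, monitored value, reported value) at the best epoch.

    The best epoch is the one with the lowest monitored value; min() returns
    the first of several equal minima, just like np.argmin.
    """
    values = history[monitor]
    if not values:
        raise ValueError("history[{!r}] is empty".format(monitor))
    if report is None:
        # assumption: older Keras logs "val_acc", newer versions "val_accuracy"
        report = "val_acc" if "val_acc" in history else "val_accuracy"
    idx = min(range(len(values)), key=values.__getitem__)
    return idx, values[idx], history[report][idx]


if __name__ == "__main__":
    hist = {"val_loss": [0.9, 0.5, 0.5, 0.7], "val_acc": [0.6, 0.8, 0.81, 0.7]}
    print(best_epoch(hist))
```

Loading the pickled history back with pickle.load and passing it to best_epoch reproduces the loss and accuracy that train returned for that session.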