validation.py source code

python

Project: cnn_basics  Author: kootenpv
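import numpy as np
from keras.callbacks import EarlyStopping
from sklearn.model_selection import ShuffleSplit

# NOTE: shuffle_weights is not defined in this snippet; it is assumed to be
# provided elsewhere in the project.


def validate(model, X, y, nb_epoch=25, batch_size=128,
             stop_early=True, folds=10, test_size=None, shuffle=True, verbose=True):
    # Stop training once val_loss has not improved for 3 consecutive epochs.
    early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=0, mode='auto')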

    total_score = []
    if test_size is None:
        if folds == 1:
            test_size = 0.25
        else:
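            # Hold out 1/folds of the data per split, as in k-fold validation
            # (1 - 1/folds would be the training fraction, not the test fraction).
            test_size = 1. / folds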
    kf = ShuffleSplit(n_splits=folds, test_size=test_size)
    for fold, (train_index, test_index) in enumerate(kf.split(X, y)):
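        # Re-initialise the model before each split so folds do not share learned weights.
        shuffle_weights(model)
        if fold > 0:
            print("FOLD:", fold)
            print("-" * 40)
            model.reset_states()
        # Apply early stopping only when requested via stop_early.
        callbacks = [early_stopping] if stop_early else None
        hist = model.fit(X[train_index], y[train_index], batch_size=batch_size,
                         epochs=nb_epoch,  # Keras 1.x names this argument nb_epoch
                         shuffle=shuffle,
                         validation_data=(X[test_index], y[test_index]),
                         callbacks=callbacks, verbose=verbose)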
        total_score.append(hist.history["val_acc"][-1])
    return np.mean(total_score)