utils.py source code

python

Project: text-classification-with-convnets  Author: osmanbaskaya
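```python
import numpy as np
from sklearn.model_selection import KFold  # scikit-learn >= 0.18


def cross_validate(model, X, y, n_folds, batch_size, num_epoch, func_for_evaluation=None):
    if func_for_evaluation is None:
        raise ValueError("func_for_evaluation is required to score each fold")

    X = np.asarray(X)
    y = np.asarray(y)

    # Shuffle X and y with a single shared permutation so every sample stays
    # paired with its label (and the caller's data is not shuffled in place).
    seed = 5
    perm = np.random.RandomState(seed).permutation(len(y))
    X, y = X[perm], y[perm]

    # Remember the untrained weights so each fold starts from the same state;
    # otherwise later folds keep training a model that has already seen their
    # test samples. Optimizer state (e.g. Adam moments) is not reset by this.
    initial_weights = model.get_weights()

    scores = np.zeros(n_folds)
    kf = KFold(n_splits=n_folds)
    for i, (train_index, test_index) in enumerate(kf.split(X)):
        X_train, y_train = X[train_index], y[train_index]
        X_test, y_test = X[test_index], y[test_index]

        model.set_weights(initial_weights)
        model.fit(X_train, y_train,
                  batch_size=batch_size,
                  epochs=num_epoch)  # `nb_epoch` in Keras 1.x

        predictions = model.predict(X_test)
        score = func_for_evaluation(predictions[:, 0].tolist(), y_test)
        # Some metrics return a tuple (e.g. scipy.stats.pearsonr -> (r, p));
        # indexing a plain float raises TypeError, a NumPy scalar IndexError.
        try:
            scores[i] = score[0]
        except (IndexError, TypeError):
            scores[i] = score

    print("{}-Fold cross validation score: {}".format(n_folds, scores.mean()))
    return scores
```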
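The function above can be sketched in plain Python to show what each fold sees, without NumPy, scikit-learn or Keras. This is a minimal illustration under stated assumptions, not the project's code. `kfold_indices`, `cross_validate_sketch`, `MeanModel` and `mean_abs_error` are hypothetical helpers. The model is passed as a factory that returns an object with `fit`/`predict`. Folds are contiguous blocks, as in an unshuffled scikit-learn `KFold`, with the first `n % k` folds one sample larger.

```python
import random
from typing import Callable, List, Sequence, Tuple


def kfold_indices(n: int, k: int) -> List[Tuple[List[int], List[int]]]:
    """Split range(n) into k contiguous test folds (hypothetical helper).

    As in an unshuffled KFold, the first n % k folds get one extra sample.
    """
    if not 2 <= k <= n:
        raise ValueError("need 2 <= k <= n")
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds = []
    start = 0
    for size in sizes:
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n))
        folds.append((train, test))
        start += size
    return folds


class MeanModel:
    """Toy stand-in model (hypothetical): predicts the mean training target."""

    def fit(self, X, y):
        self.mean = sum(y) / len(y)

    def predict(self, X):
        return [self.mean for _ in X]


def mean_abs_error(preds, truth):
    """Example evaluation function: mean absolute error over one fold."""
    return sum(abs(p - t) for p, t in zip(preds, truth)) / len(truth)


def cross_validate_sketch(make_model: Callable, X: Sequence, y: Sequence,
                          k: int, score_fn: Callable, seed: int = 5) -> List[float]:
    # One shared permutation keeps every X[i] paired with its y[i].
    perm = list(range(len(y)))
    random.Random(seed).shuffle(perm)
    X = [X[i] for i in perm]
    y = [y[i] for i in perm]

    scores = []
    for train, test in kfold_indices(len(y), k):
        model = make_model()  # fresh model per fold: nothing carried over
        model.fit([X[i] for i in train], [y[i] for i in train])
        preds = model.predict([X[i] for i in test])
        scores.append(score_fn(preds, [y[i] for i in test]))
    return scores
```

For example, `cross_validate_sketch(MeanModel, [[float(i)] for i in range(10)], [float(i) for i in range(10)], k=5, score_fn=mean_abs_error)` returns one score per fold. Passing a factory instead of a single model instance guarantees that every fold trains from scratch. With a Keras model, the same effect needs a freshly built model per fold, or weights restored with `set_weights` as in the function above.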