test_regularizers.py source code

python

Project: keras  Author: NVIDIA
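import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils

# Module-level constants used by get_data(). The original test module defines
# them outside this excerpt; values other than nb_classes are assumed here.
nb_classes = 10           # MNIST has 10 digit classes
weighted_class = 9        # assumed: class whose test-set indices are returned
max_train_samples = 5000  # assumed training subset size
max_test_samples = 1000   # assumed test subset size

def get_data():
    # the data, shuffled and split between train and test sets
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # flatten each 28x28 image into a 784-vector and keep only a subset
    X_train = X_train.reshape(60000, 784)[:max_train_samples]
    X_test = X_test.reshape(10000, 784)[:max_test_samples]
    # scale pixel values from [0, 255] to [0, 1]
    X_train = X_train.astype("float32") / 255
    X_test = X_test.astype("float32") / 255

    # truncate the labels to match, then convert class vectors to one-hot matrices
    y_train = y_train[:max_train_samples]
    y_test = y_test[:max_test_samples]
    Y_train = np_utils.to_categorical(y_train, nb_classes)
    Y_test = np_utils.to_categorical(y_test, nb_classes)
    # indices of the test samples that belong to weighted_class
    test_ids = np.where(y_test == np.array(weighted_class))[0]

    return (X_train, Y_train), (X_test, Y_test), test_ids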
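The preprocessing steps in get_data() (flatten, subset, scale to [0, 1], one-hot encode, locate one class) can be checked without Keras or a dataset download. The following is a minimal NumPy-only sketch under stated assumptions: `one_hot` and `preprocess` are hypothetical helpers, not part of Keras; `one_hot` is a rough stand-in for `np_utils.to_categorical`; and a small synthetic array replaces MNIST.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Hypothetical stand-in for np_utils.to_categorical: row i one-hot encodes labels[i]."""
    labels = np.asarray(labels, dtype=int).ravel()
    return np.eye(num_classes, dtype="float32")[labels]


def preprocess(images, labels, max_samples, num_classes, weighted_class):
    """Hypothetical helper mirroring get_data() on a single split.

    images: uint8 array of shape (n, 28, 28); labels: int array of shape (n,).
    Returns flattened float32 images in [0, 1], one-hot labels, and the
    indices of the kept samples whose label equals weighted_class.
    """
    n = images.shape[0]
    # Flatten each 28x28 image into a 784-vector, then keep the first max_samples rows.
    x = images.reshape(n, -1)[:max_samples]
    # Scale pixel intensities from [0, 255] to [0, 1].
    x = x.astype("float32") / 255
    # Truncate the labels the same way so they stay aligned with x.
    y = labels[:max_samples]
    y_onehot = one_hot(y, num_classes)
    # Positions of the kept samples that belong to the class of interest.
    ids = np.where(y == weighted_class)[0]
    return x, y_onehot, ids


# Synthetic stand-in for MNIST: 20 deterministic 28x28 uint8 "images"
# with labels cycling through 0..9.
images = (np.arange(20 * 28 * 28) % 256).astype(np.uint8).reshape(20, 28, 28)
labels = np.arange(20) % 10

x, y, ids = preprocess(images, labels, max_samples=12, num_classes=10, weighted_class=9)
print(x.shape, y.shape, ids)  # (12, 784) (12, 10) [9]
```

Only the first 12 labels (0 through 9, then 0 and 1) are kept, so label 9 appears once, at index 9. Building the one-hot matrix by indexing rows of an identity matrix avoids an explicit loop.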