aegans.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:jamespy_py3 作者: jskDr 项目源码 文件源码
def get_data(data_name='mnist', test_flag=False):
    if data_name == 'daudi':
        (X_train, y_train), (X_test, y_test) = daudi_load_data()
        if test_flag:
            X_train = X_test
        # approximately -0.2+1 to 0.2+1 --> -1. 1
        X_train = (X_train - 1.0) * 5.0
        X_train = X_train.reshape((X_train.shape[0], 1) + X_train.shape[1:])
    else:
        (X_train, y_train), (X_test, y_test) = mnist.load_data()
        if test_flag:
            X_train = X_test
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = X_train.reshape((X_train.shape[0], 1) + X_train.shape[1:])

    return X_train
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号