mlp_training.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:SelfDrivingRCCar 作者: sidroopdaska 项目源码 文件源码
def retrieve_data_set():
    """Load every ``data_set/*.npz`` archive and stack it into one data set.

    Each archive must provide an ``images`` array and a ``labels`` array.
    The arrays of all archives are stacked row-wise, in sorted file-name
    order, and converted to ``float32`` for MLP training.

    Returns:
        tuple: ``(X, Y)`` where ``X`` holds the stacked image rows
        (38400 columns) and ``Y`` the stacked label rows (4 columns).

    Raises:
        SystemExit: with a non-zero status if ``data_set/`` contains no
            ``.npz`` file.
        ValueError: if the arrays cannot be stacked or do not have the
            expected number of columns.
    """
    image_width = 38400
    label_width = 4

    start_time = cv2.getTickCount()

    print("Loading data set...")

    # glob returns paths in arbitrary order; sorting keeps the row order of
    # the resulting data set reproducible between runs.
    data_set = sorted(glob.glob("data_set/*.npz"))

    if not data_set:
        print("No data set in directory, exiting!")
        # A missing data set is a failure, so report a non-zero exit status.
        sys.exit(1)

    # Collect the per-file arrays and stack them once at the end: stacking
    # inside the loop would re-copy every previously loaded row per file.
    image_chunks = []
    label_chunks = []
    for single_npz in data_set:
        with np.load(single_npz) as data:
            image_chunks.append(data["images"])
            label_chunks.append(data["labels"])

    X = np.float32(np.vstack(image_chunks))
    Y = np.float32(np.vstack(label_chunks))

    # Fail early on malformed archives instead of handing arrays of the
    # wrong shape to the training code.
    for name, stacked, width in (("images", X, image_width),
                                 ("labels", Y, label_width)):
        if stacked.ndim != 2 or stacked.shape[1] != width:
            raise ValueError(
                "Expected '{0}' rows of width {1}, got array of shape {2}"
                .format(name, width, stacked.shape))

    print("Image array shape: {0}".format(X.shape))
    print("Label array shape: {0}".format(Y.shape))

    end_time = cv2.getTickCount()
    print("Data set load duration: {0}"
          .format((end_time - start_time) // cv2.getTickFrequency()))

    return X, Y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号