data_iterators.py source code

python

Project: multi-gpu-keras-tf    Author: sallamander
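import keras
import numpy as np
from keras.utils import to_categorical


# Method excerpt from data_iterators.py; self.name is the name of a
# keras.datasets module such as 'mnist' or 'cifar10'.
def _load_data(self, nb_obs=None):
    """Load the dataset specified by self.name.

    :param nb_obs: optional; int for the number of observations to retain
     from the training & testing sets; if None, retain the full training
     and testing sets
    :return: a tuple of 4 np.ndarrays (x_train, y_train, x_test, y_test)
    """

    # Look up the dataset loader by name, e.g. keras.datasets.mnist.
    dataset = getattr(keras.datasets, self.name)
    train_data, test_data = dataset.load_data()
    # Scale uint8 pixel intensities from [0, 255] into [0, 1].
    x_train, y_train = train_data[0] / 255., train_data[1]
    x_test, y_test = test_data[0] / 255., test_data[1]

    # One-hot encode the integer class labels.
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    # MNIST images load as (N, 28, 28); add a trailing channel axis so
    # they have the (N, H, W, C) layout that 2D conv layers expect.
    if self.name == 'mnist':
        x_train = np.expand_dims(x_train, axis=-1)
        x_test = np.expand_dims(x_test, axis=-1)

    # Optionally keep only the first nb_obs examples of each split.
    if nb_obs:
        x_train = x_train[:nb_obs]
        y_train = y_train[:nb_obs]

        x_test = x_test[:nb_obs]
        y_test = y_test[:nb_obs]

    return x_train, y_train, x_test, y_test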
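Because `_load_data` downloads its data through `keras.datasets`, its preprocessing is easiest to inspect on synthetic arrays. The sketch below is a minimal, network-free approximation of the same steps (scaling, one-hot encoding, adding a channel axis, subsampling). The helpers `preprocess` and `one_hot` are hypothetical names introduced for illustration, `one_hot` is a plain NumPy stand-in for Keras' `to_categorical`, and the random "images" are made up rather than taken from any real dataset.

```python
import numpy as np


def one_hot(labels, num_classes=None):
    """Hypothetical NumPy stand-in for to_categorical: int labels -> one-hot rows."""
    labels = np.asarray(labels, dtype=int).ravel()
    if num_classes is None:
        # Like to_categorical's default, infer the class count from the labels.
        num_classes = labels.max() + 1
    out = np.zeros((labels.size, num_classes), dtype=np.float32)
    out[np.arange(labels.size), labels] = 1.0
    return out


def preprocess(train_data, test_data, grayscale, nb_obs=None):
    """Hypothetical helper mirroring the preprocessing steps of _load_data."""
    (x_train, y_train), (x_test, y_test) = train_data, test_data
    # Scale uint8 pixel intensities from [0, 255] into [0, 1].
    x_train, x_test = x_train / 255., x_test / 255.
    # One-hot encode the integer class labels.
    y_train, y_test = one_hot(y_train), one_hot(y_test)
    # Grayscale images arrive as (N, H, W); add a trailing channel axis.
    if grayscale:
        x_train = np.expand_dims(x_train, axis=-1)
        x_test = np.expand_dims(x_test, axis=-1)
    # Optionally keep only the first nb_obs examples of each split.
    if nb_obs:
        x_train, y_train = x_train[:nb_obs], y_train[:nb_obs]
        x_test, y_test = x_test[:nb_obs], y_test[:nb_obs]
    return x_train, y_train, x_test, y_test


# Synthetic MNIST-shaped data: 100 training and 20 test "images" with 10 classes.
rng = np.random.default_rng(0)
train = (rng.integers(0, 256, size=(100, 28, 28)).astype(np.uint8), np.arange(100) % 10)
test = (rng.integers(0, 256, size=(20, 28, 28)).astype(np.uint8), np.arange(20) % 10)

x_tr, y_tr, x_te, y_te = preprocess(train, test, grayscale=True, nb_obs=50)
print(x_tr.shape, y_tr.shape, x_te.shape, y_te.shape)
# prints (50, 28, 28, 1) (50, 10) (20, 28, 28, 1) (20, 10)
```

Note that slicing with `nb_obs` larger than a split (here 50 against 20 test examples) simply returns the whole split, which matches the behavior of `_load_data`.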