data_handler.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目: began | 作者: davidtellez | 项目源码 | 文件源码
def get_batch(self, subset, batch_size, use_target_distribution=False):
        """Sample a random mini-batch of RGB digit images and their labels.

        Samples are drawn with replacement from the requested split. Each
        28x28 image is resized to ``self.image_size`` x ``self.image_size``
        with order-1 (bilinear) spline interpolation and then replicated
        into three channels.

        Args:
            subset: Split to sample from: ``'train'``, ``'valid'`` or
                ``'test'``.
            batch_size: Number of samples to draw (with replacement).
            use_target_distribution: If True, binarize each digit at 0.5 and
                composite it onto a randomly cropped, randomly tinted patch
                of ``self.lena``. Pixels covered by the digit have their
                colour inverted.

        Returns:
            ``(batch, labels)``. ``batch`` is a float32 array of shape
            ``(batch_size, 3, image_size, image_size)``. ``labels`` is the
            selected rows of the split's label array, cast to int32.

        Raises:
            ValueError: If ``subset`` is not one of the known split names.
        """
        # Select a subset; fail loudly instead of leaving X / y unbound.
        if subset == 'train':
            X, y = self.X_train, self.y_train
        elif subset == 'valid':
            X, y = self.X_val, self.y_val
        elif subset == 'test':
            X, y = self.X_test, self.y_test
        else:
            raise ValueError(
                "subset must be 'train', 'valid' or 'test', got %r" % (subset,))

        # Random choice of samples (np.random.choice samples with replacement
        # by default).
        idx = np.random.choice(X.shape[0], batch_size)
        # NOTE(review): assumes X is (N, C, 784), i.e. flattened 28x28 images,
        # and only channel 0 is used - confirm against the data loader.
        batch = X[idx, 0, :].reshape((batch_size, 28, 28))

        # Resize every 28x28 image to image_size x image_size (bilinear).
        factor = self.image_size / 28.0
        batch = np.stack(
            [scipy.ndimage.zoom(img, factor, order=1) for img in batch])
        batch = batch.reshape((batch_size, 1, self.image_size, self.image_size))

        # Replicate the single grey channel into three (RGB) channels.
        batch = np.repeat(batch, 3, axis=1)

        if use_target_distribution:
            # Binarize the digits so the "== 1" mask below is exact.
            batch[batch >= 0.5] = 1
            batch[batch < 0.5] = 0

            size = self.image_size
            # self.lena is presumably a PIL image: .size is (width, height)
            # and .crop takes (left, upper, right, lower). The +1 below is
            # needed because randint's upper bound is exclusive, and an
            # offset of exactly (dimension - size) is still a valid crop.
            max_x = self.lena.size[0] - size
            max_y = self.lena.size[1] - size

            for i in range(batch_size):
                # Take a random crop of the background image.
                x_c = np.random.randint(0, max_x + 1)
                y_c = np.random.randint(0, max_y + 1)
                image = self.lena.crop((x_c, y_c, x_c + size, y_c + size))
                # HWC uint8 -> CHW in [0, 1].
                image = np.asarray(image).transpose((2, 0, 1)) / 255.0

                # Randomly tint each channel by averaging it with a random
                # constant in [0, 1).
                for j in range(3):
                    image[j] = (image[j] + np.random.uniform(0, 1)) / 2.0

                # Invert the colour of the pixels covered by the digit.
                mask = batch[i] == 1
                image[mask] = 1 - image[mask]
                batch[i] = image

        # No rescaling to [-1, +1] is applied here; values are returned as-is.
        labels = y[idx]

        return batch.astype('float32'), labels.astype('int32')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号