def add_channels(self):
    """Give ``self.X`` an explicit channel axis holding ``self.n_channels`` channels.

    When ``self.n_channels == 1`` the work is delegated to the parent class's
    ``add_channels``. Otherwise:

    * a 3-D ``X`` of shape ``(N, rows, cols)`` gets a singleton channel axis
      inserted: axis 1 when ``K.image_dim_ordering() == 'th'``, the last axis
      otherwise;
    * a single-channel batch is then replicated ``n_channels`` times along
      the channel axis;
    * a 4-D batch that already has more than one channel is left unchanged;
    * if ``self.preprocessing_flag`` is truthy, ``preprocess_input`` is
      applied to the result.

    Side effects:
        Sets ``self.X`` to the processed array and ``self.input_shape`` to
        its per-sample shape (the shape without the leading batch axis).

    Raises:
        ValueError: if ``self.X`` is neither 3-D nor 4-D.
    """
    n_channels = self.n_channels
    if n_channels == 1:
        super().add_channels()
        return

    X = self.X
    if X.ndim not in (3, 4):
        raise ValueError(
            'add_channels expects a 3-D (N, rows, cols) or 4-D image batch, '
            'got an array with shape {}'.format(X.shape))

    # 'th' ordering puts channels right after the batch axis; otherwise they
    # come last.
    # NOTE(review): image_dim_ordering() is the Keras 1 API (Keras 2 uses
    # image_data_format()); confirm against the Keras version in use.
    channel_axis = 1 if K.image_dim_ordering() == 'th' else 3

    if X.ndim == 3:
        # (N, rows, cols) -> (N, 1, rows, cols) or (N, rows, cols, 1).
        X = np.expand_dims(X, axis=channel_axis)

    if X.shape[channel_axis] == 1:
        # Replicate the single channel so the batch has n_channels channels
        # (previously hard-coded to exactly three copies).
        X = np.repeat(X, n_channels, axis=channel_axis)

    # Take the per-sample shape from the data itself so it always matches X,
    # even when a 4-D input already had some other channel count.
    input_shape = X.shape[1:]

    if self.preprocessing_flag:
        X = preprocess_input(X)
    self.X = X
    self.input_shape = input_shape
评论列表
文章目录