def _save_data(which, X, y, data_source):
    """Serialize a sample matrix and its labels as gzip-compressed ``.npy`` files.

    Two files are written with :func:`numpy.save` through a
    :class:`gzip.GzipFile`, both located via ``resource_filename`` against the
    ``sudokuextract.data`` package:

    * ``{prefix}-{which}-data.gz``   -- the samples *X*
    * ``{prefix}-{which}-labels.gz`` -- the labels *y*

    :param which: Tag embedded in both file names (presumably a split name
        such as ``'train'``/``'test'`` -- confirm against callers).
    :param X: Samples; ndarray or anything ``numpy`` can convert to one.
        The first axis is taken to index samples.
    :param y: Labels; ndarray or anything ``numpy`` can convert to one.
    :param data_source: ``'mnist'`` (case-insensitive) selects the ``mnist``
        prefix; every other value is stored under the ``se`` prefix.
    :raises TypeError: If ``X.shape[0]`` differs from ``len(y)``.
    """
    # Only MNIST is recognised by name; all other sources share the 'se' prefix.
    data_source = 'mnist' if data_source.lower() == 'mnist' else 'se'

    # Convert *before* validating: the length check below uses ``X.shape``,
    # which plain lists/tuples do not have. ``asanyarray`` leaves ndarrays
    # (including subclasses) untouched and builds a new array otherwise.
    X = np.asanyarray(X)
    y = np.asanyarray(y)
    if X.shape[0] != len(y):
        raise TypeError("Length of data samples ({0}) was not identical "
                        "to length of labels ({1})".format(X.shape[0], len(y)))

    # Features and labels go to sibling files that differ only in their suffix.
    for suffix, array in (('data', X), ('labels', y)):
        fname = resource_filename(
            'sudokuextract.data',
            "{0}-{1}-{2}.gz".format(data_source, which, suffix))
        with gzip.GzipFile(fname, mode='wb') as f:
            np.save(f, array)