def _maybe_download_and_extract(self):
    """Download and extract the MNIST dataset.

    Loads the dataset with ``mnist.read_data_sets`` rooted at
    ``self._data_dir`` (``dtype=tf.uint8``, ``reshape=False``,
    ``validation_size=self._num_examples_per_epoch_for_eval``). Then, for
    each of the ``train``, ``validation`` and ``test`` splits, calls
    ``convert_to_tfrecords`` unless ``<split>.tfrecords`` already exists
    under ``self._data_dir``.
    """
    target_dir = self._data_dir
    data_sets = mnist.read_data_sets(
        target_dir,
        dtype=tf.uint8,
        reshape=False,
        validation_size=self._num_examples_per_epoch_for_eval)

    # Convert to Examples and write the result to TFRecords, one split at a
    # time; a split whose record file is already present is skipped.
    for split in ('train', 'validation', 'test'):
        record_path = os.path.join(target_dir, split + '.tfrecords')
        if tf.gfile.Exists(record_path):
            continue
        convert_to_tfrecords(getattr(data_sets, split), split, target_dir)
评论列表
文章目录