def read_data_sets(train_dir,
                   source_data=Source.NUMBER_IMAGES,
                   fake_data=False,
                   one_hot=True):
    """Assemble the train/test ``DataSet`` objects into one container.

    Args:
        train_dir: Forwarded to ``maybe_download`` along with
            ``source_data``.
        source_data: Source selector handed to ``maybe_download``;
            defaults to ``Source.NUMBER_IMAGES``.
        fake_data: When true, nothing is downloaded or extracted; every
            split is a ``DataSet`` built from empty inputs with
            ``fake_data=True``.
        one_hot: Forwarded to ``extract_labels`` (and to ``DataSet`` in
            the fake-data case).

    Returns:
        A ``DataSets`` instance. In the fake-data case it carries
        ``train``, ``validation`` and ``test``; otherwise only ``train``
        and ``test`` are set.
    """
    class DataSets(object):
        # Plain attribute holder; the splits are attached below.
        pass

    bundle = DataSets()

    if fake_data:
        # Same construction for every split, in the order
        # train -> validation -> test.
        for split_name in ("train", "validation", "test"):
            setattr(bundle, split_name,
                    DataSet([], [], fake_data=True, one_hot=one_hot))
        return bundle

    # The return value is not used in this function; the call is kept for
    # its side effects (presumably fetching the data into train_dir --
    # TODO confirm against maybe_download).
    maybe_download(source_data, train_dir)

    # Extraction order is kept as train images, train labels, test images,
    # test labels in case the helpers are order-sensitive.
    images_for_training = extract_images(TRAIN_INDEX, train=True)
    labels_for_training = extract_labels(TRAIN_INDEX, train=True,
                                         one_hot=one_hot)
    images_for_testing = extract_images(TEST_INDEX, train=False)
    labels_for_testing = extract_labels(TEST_INDEX, train=False,
                                        one_hot=one_hot)

    bundle.train = DataSet(images_for_training, labels_for_training,
                           load=False)
    bundle.test = DataSet(images_for_testing, labels_for_testing, load=True)

    # NOTE(review): no validation split is created on this path, so
    # `bundle.validation` does not exist here, unlike in the fake-data
    # branch -- confirm that callers do not rely on it.
    return bundle
评论列表
文章目录