def load_batch(fpath, label_key='labels'):
    f = open(fpath, 'rb')
    if sys.version_info < (3,):
        d = cPickle.load(f)
    else:
        d = cPickle.load(f, encoding="bytes")
        # decode utf8
        d_decoded = {}
        for k, v in d.items():
            d_decoded[k.decode("utf8")] = v
        d = d_decoded
    f.close()
    data = d["data"]
    labels = d[label_key]
    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels
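For reference, a minimal usage sketch (not part of the snippet above) that reads the extracted CIFAR-10 batch files with load_batch; the directory name and array shapes are assumptions based on the standard cifar-10-batches-py layout:

import os
import numpy as np

dirname = 'cifar-10-batches-py'  # assumed location of the extracted archive
x_train = np.empty((50000, 3, 32, 32), dtype='uint8')
y_train = []
for i in range(1, 6):
    data, labels = load_batch(os.path.join(dirname, 'data_batch_%d' % i))
    x_train[(i - 1) * 10000:i * 10000] = data
    y_train += list(labels)
x_test, y_test = load_batch(os.path.join(dirname, 'test_batch'))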
Python load() example source code
def load(self, config_data):
    """
    Method to load the configuration file, the configuration schema, and select the correct validator and backend

    Args:
        config_data(dict): The configuration dictionary

    Returns:
        None
    """
    self.config_data = config_data

    # Load the schema file based on the config that was provided
    try:
        schema_name = self.config_data['schema']['name']
    except KeyError as err:
        raise ConfigFileError("The specified schema was not found: {}. Try to update your ingest client library or double check your ingest job configuration file".format(err))

    with open(os.path.join(resource_filename("ingestclient", "schema"), "{}.json".format(schema_name)), 'rt') as schema_file:
        self.schema = json.load(schema_file)
def load_plugins(self):
    """Method to load the plugins

    Returns:
        None
    """
    # Create plugin instances
    package, class_name = self.config_data["client"]["tile_processor"]["class"].rsplit('.', 1)
    tile_module = importlib.import_module(package)
    tile_class = getattr(tile_module, class_name)
    self.tile_processor_class = tile_class()

    package, class_name = self.config_data["client"]["path_processor"]["class"].rsplit('.', 1)
    path_module = importlib.import_module(package)
    path_class = getattr(path_module, class_name)
    self.path_processor_class = path_class()
def load_and_display_pickle(datasets, sample_size, title=None):
    fig = plt.figure()
    if title: fig.suptitle(title, fontsize=16, fontweight='bold')
    num_of_images = []
    for pickle_file in datasets:
        with open(pickle_file, 'rb') as f:
            data = pickle.load(f)
            print('Total images in', pickle_file, ':', len(data))
            for index, image in enumerate(data):
                if index == sample_size: break
                ax = fig.add_subplot(len(datasets), sample_size,
                                     sample_size * datasets.index(pickle_file) + index + 1)
                ax.imshow(image)
                ax.set_axis_off()
        num_of_images.append(len(data))
    balance_check(num_of_images)
    plt.show()
    return num_of_images
def predict():
    """
    An example of how to load a trained model and use it
    to predict labels.
    """
    # load the saved model
    classifier = pickle.load(open('best_model.pkl', 'rb'))

    # compile a predictor function
    predict_model = theano.function(
        inputs=[classifier.input],
        outputs=classifier.y_pred)

    # We can test it on some examples from the test set
    dataset = 'mnist.pkl.gz'
    datasets = load_data(dataset)
    test_set_x, test_set_y = datasets[2]
    test_set_x = test_set_x.get_value()

    predicted_values = predict_model(test_set_x[:10])
    print("Predicted values for the first 10 examples in test set:")
    print(predicted_values)
def restore_snapshot(self, filename=None):
    """
    Restore a saved snapshot of current process from file
    Warning: this is not thread safe, do not use with multithread program

    Args:
        - filename: saved snapshot file

    Returns:
        - Bool
    """
    if not filename:
        filename = self.get_config_filename("snapshot")
    fd = open(filename, "rb")
    snapshot = pickle.load(fd)
    return self.give_snapshot(snapshot)
#########################
# Memory Operations #
#########################
def read_dataset(data_dir):
    pickle_filename = "MITSceneParsing.pickle"
    pickle_filepath = os.path.join(data_dir, pickle_filename)
    if not os.path.exists(pickle_filepath):
        utils.maybe_download_and_extract(data_dir, DATA_URL, is_zipfile=True)
        SceneParsing_folder = os.path.splitext(DATA_URL.split("/")[-1])[0]
        result = create_image_lists(os.path.join(data_dir, SceneParsing_folder))
        print("Pickling ...")
        with open(pickle_filepath, 'wb') as f:
            pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
    else:
        print("Found pickle file!")

    with open(pickle_filepath, 'rb') as f:
        result = pickle.load(f)
        training_records = result['training']
        validation_records = result['validation']
        del result

    return training_records, validation_records
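A brief usage sketch (the data directory path is an assumption, not from the snippet): the first call downloads the archive, builds the record lists, and caches them in MITSceneParsing.pickle; later calls just reload the pickle.

train_records, valid_records = read_dataset("Data_zoo/MIT_SceneParsing")  # hypothetical data_dir
print(len(train_records), "training records,", len(valid_records), "validation records")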
def str_to_func(s, sandbox=None):
    if isinstance(s, (tuple, list)):
        code, closure, defaults = s
    elif isinstance(s, string_types):  # path to file
        if os.path.isfile(s):
            with open(s, 'rb') as f:
                code, closure, defaults = cPickle.load(f)
        else:  # pickled string
            code, closure, defaults = cPickle.loads(s)
    else:
        raise ValueError("Unsupported type for str_to_func: %s" % type(s))
    code = marshal.loads(cPickle.loads(code).tostring())
    func = types.FunctionType(code=code, name=code.co_name,
                              globals=sandbox if isinstance(sandbox, Mapping) else globals(),
                              closure=closure, argdefs=defaults)
    return func
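The snippet does not show how the (code, closure, defaults) triple is produced. Below is a minimal sketch of a matching serializer, assuming the code object is marshalled, wrapped in a uint8 numpy array, and pickled so that cPickle.loads(code).tostring() in str_to_func recovers the marshalled bytes; func_to_tuple is a hypothetical helper, not part of the original source:

import marshal

import numpy as np
from six.moves import cPickle

def func_to_tuple(fn):
    # Hypothetical counterpart of str_to_func: marshal the code object,
    # wrap it in a uint8 array, and pickle it alongside closure and defaults.
    code = cPickle.dumps(np.frombuffer(marshal.dumps(fn.__code__), dtype=np.uint8))
    return (code, fn.__closure__, fn.__defaults__)

def add(a, b=1):
    return a + b

restored = str_to_func(func_to_tuple(add))
assert restored(2) == 3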
def load_npy_to_any(path='', name='file.npy'):
    """Load .npy file.

    Examples
    ---------
    - see save_any_to_npy()
    """
    file_path = os.path.join(path, name)
    try:
        npy = np.load(file_path).item()
    except:
        npy = np.load(file_path)
    finally:
        try:
            return npy
        except:
            print("[!] Fail to load %s" % file_path)
            exit()
# Visualizing npz files
def read_data_files(self, subset='train'):
    """Reads from data file and return images and labels in a numpy array."""
    if subset == 'train':
        filenames = [os.path.join(self.data_dir, 'data_batch_%d' % i)
                     for i in xrange(1, 6)]
    elif subset == 'validation':
        filenames = [os.path.join(self.data_dir, 'test_batch')]
    else:
        raise ValueError('Invalid data subset "%s"' % subset)

    inputs = []
    for filename in filenames:
        with gfile.Open(filename, 'r') as f:
            inputs.append(cPickle.load(f))

    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
    # input format.
    all_images = np.concatenate(
        [each_input['data'] for each_input in inputs]).astype(np.float32)
    all_labels = np.concatenate(
        [each_input['labels'] for each_input in inputs])

    return all_images, all_labels
def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            ret, hidden = model.sample(sess, chars, vocab, args.n, args.prime,
                                       args.sample)  # .encode('utf-8')
            print("Number of characters generated: ", len(ret))
            for i in range(len(ret)):
                print("Generated character: ", ret[i])
                print("Associated hidden state:", hidden[i])
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='save',
                        help='model directory to load stored checkpointed models from')
    parser.add_argument('-n', type=int, default=200,
                        help='number of words to sample')
    parser.add_argument('--prime', type=str, default=' ',
                        help='prime text')
    parser.add_argument('--pick', type=int, default=1,
                        help='1 = weighted pick, 2 = beam search pick')
    parser.add_argument('--width', type=int, default=4,
                        help='width of the beam search')
    parser.add_argument('--sample', type=int, default=1,
                        help='0 to use max at each timestep, 1 to sample at each timestep, 2 to sample on spaces')
    args = parser.parse_args()
    sample(args)
def load(self, local_dir_=None):
    '''
    load dataset from local disk

    Args:
        local_dir_: string or None
            if None, will use default Dataset.DEFAULT_DIR
    '''
    if local_dir_ is None:
        local_dir = self.DEFAULT_DIR
    else:
        local_dir = Path(local_dir_)
    data_di = np.load(str(local_dir / 'cifar10.npz'))
    self.datum[:] = data_di['images']
    self.labels[:] = data_di['labels']
def install(
        self, local_dst_dir_=None, local_src_dir_=None, clean_install_=False):
    '''
    Install the dataset into directly usable format,
    requires downloading for public dataset.

    Args:
        local_dst_dir_: string or None
            where to install the dataset, None -> "%(default_dir)s"
        local_src_dir_: string or None
            where to find the raw downloaded files, None -> "%(default_dir)s"
        clean_install_: bool
            remove the raw archive after a successful install
    '''
    local_dst_dir = self.DEFAULT_DIR if local_dst_dir_ is None else Path(local_dst_dir_)
    local_src_dir = self.DEFAULT_DIR if local_src_dir_ is None else Path(local_src_dir_)
    local_dst_dir.mkdir(parents=True, exist_ok=True)
    assert local_src_dir.exists()
    images = np.empty((60000, 3, 32, 32), dtype=np.uint8)
    labels = np.empty((60000,), dtype=np.uint8)
    tarfile_name = str(local_src_dir / 'cifar-10-python.tar.gz')
    with tarfile.open(tarfile_name, 'r:gz') as tf:
        for i in range(5):
            with tf.extractfile('cifar-10-batches-py/data_batch_%d' % (i + 1)) as f:
                data_di = pickle.load(f, encoding='bytes')
            images[(10000 * i):(10000 * (i + 1))] = data_di[b'data'].reshape((10000, 3, 32, 32))
            labels[(10000 * i):(10000 * (i + 1))] = np.asarray(data_di[b'labels'], dtype=np.uint8)
        with tf.extractfile('cifar-10-batches-py/test_batch') as f:
            data_di = pickle.load(f, encoding='bytes')
        images[50000:60000] = data_di[b'data'].reshape((10000, 3, 32, 32))
        labels[50000:60000] = data_di[b'labels']
    np.savez_compressed(str(local_dst_dir / 'cifar10.npz'), images=images, labels=labels)
    if clean_install_:
        os.remove(tarfile_name)
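A minimal usage sketch for the two methods above. The class name CIFAR10 and the source directory are hypothetical; the snippets only show the methods, not the enclosing dataset class:

dataset = CIFAR10()  # hypothetical Dataset subclass defining DEFAULT_DIR, datum, labels
dataset.install(local_src_dir_='/tmp/cifar10_raw')  # expects cifar-10-python.tar.gz there
dataset.load()  # reads the cifar10.npz cache written by install()
print(dataset.datum.shape, dataset.labels.shape)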
def load(self, local_dir_=None):
    if local_dir_ is None:
        local_dir = self.DEFAULT_DIR
    else:
        local_dir = Path(local_dir_)
    data = np.load(str(local_dir / 'mnist.npz'))
    self.labels = data['labels']
    self.datum = data['images']
    self.label_map = np.arange(10)
    self.imsize = (1, 28, 28)
def pickle_load(filename):
    """Deserialize data from file using gzip compression."""
    if filename.endswith('.pkz'):
        with gzip.open(filename, 'rb') as f:
            return pickle.load(f)
    elif filename.endswith('.jz'):
        with gzip.open(filename, 'rt') as f:
            return json_loads(f.read())
    else:
        raise ValueError(
            'Cannot determine format: {}'.format(os.path.basename(filename)))
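The matching writer is not shown above. A minimal sketch of one, mirroring the two formats pickle_load accepts (gzip-compressed pickle for '.pkz', gzip-compressed JSON for '.jz'); pickle_dump is a hypothetical helper, not part of the original source:

import gzip
import json
import os
import pickle

def pickle_dump(data, filename):
    """Serialize data to file using gzip compression (hypothetical counterpart of pickle_load)."""
    if filename.endswith('.pkz'):
        with gzip.open(filename, 'wb') as f:
            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    elif filename.endswith('.jz'):
        with gzip.open(filename, 'wt') as f:
            f.write(json.dumps(data))
    else:
        raise ValueError(
            'Cannot determine format: {}'.format(os.path.basename(filename)))

# Round trip:
# pickle_dump({'a': 1}, 'example.pkz')
# assert pickle_load('example.pkz') == {'a': 1}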