def loop_records(self, num_records=0, init_key=None):
    """Yield (data, label, key) tuples from the LMDB at self.fn.

    Args:
        num_records: maximum number of records to yield; 0 means iterate all.
        init_key: optional key to start iterating from. Raises ValueError
            if the key is not present in the database.

    Yields:
        (data, label, key) where data is the squeezed numpy array decoded
        from the stored Datum, label is the Datum's label field, and key
        is the raw LMDB key.

    Raises:
        ValueError: if init_key is given but not found.
    """
    env = lmdb.open(self.fn, readonly=True)
    try:
        datum = Datum()
        with env.begin() as txn:
            cursor = txn.cursor()
            if init_key is not None:
                if not cursor.set_key(init_key):
                    raise ValueError('key ' + init_key + ' not found in lmdb ' + self.fn + '.')
            num_read = 0
            for key, value in cursor:
                datum.ParseFromString(value)
                label = datum.label
                data = datum_to_array(datum).squeeze()
                yield (data, label, key)
                num_read += 1
                if num_records != 0 and num_read == num_records:
                    break
    finally:
        # Close the environment even when the consumer abandons the
        # generator early (GeneratorExit) or an exception is raised;
        # the original only closed it after a fully-exhausted loop,
        # leaking the LMDB handle otherwise.
        env.close()
python类Datum()的实例源码
def _add_record(self, data, label=None, key=None):
    """Serialize a numpy array as a caffe Datum and store it in self.env.

    Args:
        data: numpy array with 1, 2 or 3 dimensions; 1-D and 2-D inputs
            are padded with trailing singleton dimensions so the Datum is
            always interpreted as (channels, height, width).
        label: optional label, stored as int; -1 is stored when omitted.
        key: optional record key; defaults to the zero-padded value of
            self.num. Encoded to ASCII bytes for LMDB.
    """
    data_dims = data.shape
    if data.ndim == 1:
        data_dims = np.array([data_dims[0], 1, 1], dtype=int)
    elif data.ndim == 2:
        data_dims = np.array([data_dims[0], data_dims[1], 1], dtype=int)
    datum = Datum()
    datum.channels, datum.height, datum.width = data_dims[0], data_dims[1], data_dims[2]
    if data.dtype == np.uint8:
        # tobytes() replaces tostring(), which was deprecated and removed
        # in NumPy 2.0.
        datum.data = data.tobytes()
    else:
        # Flatten before extending: tolist() on a 2-D/3-D array yields
        # nested lists, which a repeated float field cannot accept.
        datum.float_data.extend(data.flatten().tolist())
    datum.label = int(label) if label is not None else -1
    key = ('{:08}'.format(self.num) if key is None else key).encode('ascii')
    with self.env.begin(write=True) as txn:
        txn.put(key, datum.SerializeToString())
    self.num += 1
def save_to_lmdb(images, labels, lmdb_file):
    """Resize each image file and write it with its label into an LMDB.

    Args:
        images: array of image file paths (indexable, exposes .shape[0]).
        labels: per-image labels, aligned with images.
        lmdb_file: output LMDB path; the function is a no-op (with a
            message) if the path already exists.
    """
    if not os.path.exists(lmdb_file):
        batch_size = 256
        lmdb_env = lmdb.open(lmdb_file, map_size=int(1e12))
        lmdb_txn = lmdb_env.begin(write=True)
        item_id = 0
        datum = caffe_pb2.Datum()
        for i in range(images.shape[0]):
            im = cv2.imread(images[i])
            if im is None:
                # Unreadable/corrupt image file -- skip it (matches the
                # other save_to_lmdb variant in this file).
                continue
            im = cv2.resize(im, (IM_HEIGHT, IM_WIDTH))
            datum.channels = im.shape[2]
            datum.height = im.shape[0]
            datum.width = im.shape[1]
            datum.data = im.tobytes()
            datum.label = labels[i]
            # LMDB keys must be bytes under Python 3.
            keystr = '{:0>8d}'.format(item_id).encode('ascii')
            lmdb_txn.put(keystr, datum.SerializeToString())
            item_id += 1
            # Commit every batch_size records to bound transaction size.
            if item_id % batch_size == 0:
                lmdb_txn.commit()
                lmdb_txn = lmdb_env.begin(write=True)
                print('converted {} images'.format(item_id))
        # Commit the final partial batch, if any. The original tested
        # (item_id + 1) % batch_size AFTER the increment, which both
        # mis-detected the last batch and over-reported the count by one.
        if item_id % batch_size != 0:
            lmdb_txn.commit()
            print('converted {} images'.format(item_id))
        lmdb_env.close()
        print('Generated ' + lmdb_file)
    else:
        print(lmdb_file + ' already exists')
# Snippet source: convert_facescrub_data.py
# Project: fast-image-retrieval
# Author: xueeinstein
def save_to_lmdb(images, labels, lmdb_file):
    """Resize each image file and write it with its label into an LMDB.

    Unreadable images (cv2.imread returning None) are skipped.

    Args:
        images: array of image file paths (indexable, exposes .shape[0]).
        labels: per-image labels, aligned with images.
        lmdb_file: output LMDB path; the function is a no-op (with a
            message) if the path already exists.
    """
    if not os.path.exists(lmdb_file):
        batch_size = 256
        lmdb_env = lmdb.open(lmdb_file, map_size=int(1e12))
        lmdb_txn = lmdb_env.begin(write=True)
        item_id = 0
        datum = caffe_pb2.Datum()
        for i in range(images.shape[0]):
            im = cv2.imread(images[i])
            if im is None:
                # Unreadable/corrupt image file -- skip it.
                continue
            im = cv2.resize(im, (IM_HEIGHT, IM_WIDTH))
            datum.channels = im.shape[2]
            datum.height = im.shape[0]
            datum.width = im.shape[1]
            datum.data = im.tobytes()
            datum.label = labels[i]
            # LMDB keys must be bytes under Python 3.
            keystr = '{:0>8d}'.format(item_id).encode('ascii')
            lmdb_txn.put(keystr, datum.SerializeToString())
            item_id += 1
            # Commit every batch_size records to bound transaction size.
            if item_id % batch_size == 0:
                lmdb_txn.commit()
                lmdb_txn = lmdb_env.begin(write=True)
                print('converted {} images'.format(item_id))
        # Commit the final partial batch, if any. The original tested
        # (item_id + 1) % batch_size AFTER the increment, which both
        # mis-detected the last batch and over-reported the count by one.
        if item_id % batch_size != 0:
            lmdb_txn.commit()
            print('converted {} images'.format(item_id))
        lmdb_env.close()
        print('Generated ' + lmdb_file)
    else:
        print(lmdb_file + ' already exists')
# Snippet source: create_lmdb.py
# Project: deeplearning-cats-dogs-tutorial
# Author: adilmoujahid
def make_datum(img, label):
    """Pack a BGR HxWx3 numpy image into a caffe Datum with CHW byte data."""
    # np.rollaxis(img, 2) moves the channel axis first (HWC -> CHW).
    # tobytes() replaces tostring(), deprecated and removed in NumPy 2.0.
    return caffe_pb2.Datum(
        channels=3,
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        data=np.rollaxis(img, 2).tobytes())
def make_datum(img, label):
    """Pack a BGR HxWx3 numpy image into a caffe Datum with CHW byte data."""
    # np.transpose(img, (2, 0, 1)) moves the channel axis first (HWC -> CHW).
    # tobytes() replaces tostring(): the original comment had the version
    # relationship backwards -- tobytes() was ADDED in NumPy 1.9, and
    # tostring() was deprecated and removed in NumPy 2.0.
    return caffe_pb2.Datum(
        channels=3,
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        data=np.transpose(img, (2, 0, 1)).tobytes())
# key = 0
# env = lmdb.open(img_lmdb_path, map_size=int(1e12))
# with env.begin(write=True) as txn:
# for idx in xrange(numSample):
# info = data[idx].split(" ")
# OriImg = cv2.imread(datadir + info[0])
# img = cv2.resize(OriImg,(IMAGE_WIDTH,IMAGE_HEIGHT))
# label = int(info[1])
# img = np.transpose(img, (2, 0, 1))
# datum = caffe.io.array_to_datum(img, label)
# key_str = '{:08}'.format(key)
# # txn.put(key_str.encode('ascii'), datum.SerializeToString())
# txn.put(key_str, datum.SerializeToString())
# key += 1
# for idx in xrange(numSample):
# info = data[idx].split(" ")
# OriImg = cv2.imread(datadir + info[0])
# img = cv2.resize(OriImg,(IMAGE_WIDTH,IMAGE_HEIGHT))
# label = int(info[1])
# img = cv2.flip(img,1)
# img = np.transpose(img, (2, 0, 1))
# datum = caffe.io.array_to_datum(img, label)
# key_str = '{:08}'.format(key)
# # txn.put(key_str.encode('ascii'), datum.SerializeToString())
# txn.put(key_str, datum.SerializeToString())
# key += 1
# print key
def main(args):
    """Dump the float_data of the first args.truncate LMDB records to .npy.

    Args:
        args: namespace with input_lmdb (LMDB path), truncate (max number
            of records to read), and output_npy (destination file).
    """
    datum = Datum()
    data = []
    env = lmdb.open(args.input_lmdb)
    try:
        with env.begin() as txn:
            cursor = txn.cursor()
            for i, (key, value) in enumerate(cursor):
                if i >= args.truncate:
                    break
                datum.ParseFromString(value)
                # Copy the repeated field: ParseFromString refills the SAME
                # container object in place, so appending it directly (as the
                # original did) leaves every list entry aliasing the last
                # record's data.
                data.append(list(datum.float_data))
    finally:
        # Close the environment even if reading raises.
        env.close()
    data = np.squeeze(np.asarray(data))
    np.save(args.output_npy, data)
def make_datum(img, label):
    """Pack a single-channel (grayscale) numpy image into a caffe Datum."""
    # tobytes() replaces tostring(), deprecated and removed in NumPy 2.0.
    return caffe_pb2.Datum(
        channels=1,  # images are in black and white
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        data=img.tobytes())