python类Datum()的实例源码

caffe_lmdb.py 文件源码 项目:score-zeroshot 作者: pedro-morgado 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def loop_records(self, num_records=0, init_key=None):
        """Yield ``(data, label, key)`` for records read from the LMDB at ``self.fn``.

        Args:
            num_records: Maximum number of records to yield; ``0`` means
                no limit.
            init_key: If not ``None``, the cursor is positioned on this key
                before iteration starts.

        Yields:
            Tuples ``(data, label, key)`` where ``data`` is the squeezed result
            of ``datum_to_array``, ``label`` is the datum's label and ``key``
            is the raw LMDB key.

        Raises:
            ValueError: If ``init_key`` is given but ``cursor.set_key`` cannot
                find it. Because this is a generator, the error surfaces on
                the first iteration step rather than at call time.
        """
        env = lmdb.open(self.fn, readonly=True)
        # try/finally guarantees the environment is closed even when the
        # consumer stops iterating early or an exception propagates.
        try:
            datum = Datum()
            with env.begin() as txn:
                cursor = txn.cursor()
                if init_key is not None and not cursor.set_key(init_key):
                    # Keys may be bytes; decode so building the message cannot
                    # itself fail with a TypeError.
                    if isinstance(init_key, bytes):
                        shown_key = init_key.decode('ascii', 'replace')
                    else:
                        shown_key = str(init_key)
                    raise ValueError('key ' + shown_key + ' not found in lmdb ' + self.fn + '.')

                num_read = 0
                for key, value in cursor:
                    datum.ParseFromString(value)
                    label = datum.label
                    data = datum_to_array(datum).squeeze()
                    yield (data, label, key)
                    num_read += 1
                    if num_records != 0 and num_read == num_records:
                        break
        finally:
            env.close()
caffe_lmdb.py 文件源码 项目:score-zeroshot 作者: pedro-morgado 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _add_record(self, data, label=None, key=None):
        """Serialize one array as a ``Datum`` and store it in ``self.env``.

        Args:
            data: NumPy array with 1 to 3 dimensions. Missing trailing
                dimensions are padded with 1 to form
                ``(channels, height, width)``.
            label: Integer-convertible label; ``-1`` is stored when ``None``.
            key: Record key (text or bytes). When ``None`` the zero-padded
                running counter ``self.num`` is used.

        Side effects:
            Writes one record in its own write transaction and increments
            ``self.num``.
        """
        data_dims = data.shape
        if data.ndim == 1:
            data_dims = np.array([data_dims[0], 1, 1], dtype=int)
        elif data.ndim == 2:
            data_dims = np.array([data_dims[0], data_dims[1], 1], dtype=int)

        datum = Datum()
        datum.channels, datum.height, datum.width = data_dims[0], data_dims[1], data_dims[2]
        if data.dtype == np.uint8:
            # tobytes() replaces tostring(), which was removed in NumPy 2.0.
            datum.data = data.tobytes()
        else:
            # Flatten first: tolist() on a multi-dimensional array yields
            # nested lists, which a repeated scalar field cannot accept.
            datum.float_data.extend(data.ravel().tolist())
        datum.label = int(label) if label is not None else -1

        if key is None:
            key = '{:08}'.format(self.num)
        if not isinstance(key, bytes):
            # LMDB keys are bytes; keys that already are bytes pass through.
            key = key.encode('ascii')
        with self.env.begin(write=True) as txn:
            txn.put(key, datum.SerializeToString())
        self.num += 1
convert_shoes7k_data.py 文件源码 项目:fast-image-retrieval 作者: xueeinstein 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def save_to_lmdb(images, labels, lmdb_file):
    """Write images and labels into a new LMDB of Caffe ``Datum`` records.

    Nothing is written when ``lmdb_file`` already exists.

    Args:
        images: Sequence of image file paths, read with ``cv2.imread``.
        labels: Sequence of integer-convertible labels, indexed in parallel
            with ``images``.
        lmdb_file: Path of the LMDB to create.

    Each readable image is resized to ``IM_WIDTH`` x ``IM_HEIGHT`` and stored
    under a zero-padded sequential key. Unreadable images are skipped without
    consuming a key. Records are committed in batches of 256.
    """
    if os.path.exists(lmdb_file):
        print(lmdb_file + ' already exists')
        return

    batch_size = 256
    lmdb_env = lmdb.open(lmdb_file, map_size=int(1e12))
    try:
        lmdb_txn = lmdb_env.begin(write=True)
        item_id = 0  # number of records written so far; also the next key
        datum = caffe_pb2.Datum()

        for i, path in enumerate(images):
            im = cv2.imread(path)
            if im is None:
                # cv2.imread returns None for missing or undecodable files.
                print('skipping unreadable image {}'.format(path))
                continue
            # cv2.resize expects dsize as (width, height).
            im = cv2.resize(im, (IM_WIDTH, IM_HEIGHT))
            datum.channels = im.shape[2]
            datum.height = im.shape[0]
            datum.width = im.shape[1]
            datum.data = im.tobytes()
            datum.label = int(labels[i])
            keystr = '{:0>8d}'.format(item_id)
            # LMDB keys must be bytes.
            lmdb_txn.put(keystr.encode('ascii'), datum.SerializeToString())
            item_id += 1

            # write batch
            if item_id % batch_size == 0:
                lmdb_txn.commit()
                lmdb_txn = lmdb_env.begin(write=True)
                print('converted {} images'.format(item_id))

        # write last batch: records are pending exactly when the written
        # count is not a multiple of the batch size.
        if item_id % batch_size != 0:
            lmdb_txn.commit()
            print('converted {} images'.format(item_id))
        else:
            lmdb_txn.abort()
    finally:
        lmdb_env.close()
    print('Generated ' + lmdb_file)
convert_facescrub_data.py 文件源码 项目:fast-image-retrieval 作者: xueeinstein 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def save_to_lmdb(images, labels, lmdb_file):
    """Write images and labels into a new LMDB of Caffe ``Datum`` records.

    Nothing is written when ``lmdb_file`` already exists.

    Args:
        images: Sequence of image file paths, read with ``cv2.imread``.
        labels: Sequence of integer-convertible labels, indexed in parallel
            with ``images``.
        lmdb_file: Path of the LMDB to create.

    Each readable image is resized to ``IM_WIDTH`` x ``IM_HEIGHT`` and stored
    under a zero-padded sequential key. Unreadable images are skipped without
    consuming a key. Records are committed in batches of 256.
    """
    if os.path.exists(lmdb_file):
        print(lmdb_file + ' already exists')
        return

    batch_size = 256
    lmdb_env = lmdb.open(lmdb_file, map_size=int(1e12))
    try:
        lmdb_txn = lmdb_env.begin(write=True)
        item_id = 0  # number of records written so far; also the next key
        datum = caffe_pb2.Datum()

        for i, path in enumerate(images):
            im = cv2.imread(path)
            if im is None:
                # cv2.imread returns None for missing or undecodable files.
                print('skipping unreadable image {}'.format(path))
                continue
            # cv2.resize expects dsize as (width, height).
            im = cv2.resize(im, (IM_WIDTH, IM_HEIGHT))
            datum.channels = im.shape[2]
            datum.height = im.shape[0]
            datum.width = im.shape[1]
            datum.data = im.tobytes()
            datum.label = int(labels[i])
            keystr = '{:0>8d}'.format(item_id)
            # LMDB keys must be bytes.
            lmdb_txn.put(keystr.encode('ascii'), datum.SerializeToString())
            item_id += 1

            # write batch
            if item_id % batch_size == 0:
                lmdb_txn.commit()
                lmdb_txn = lmdb_env.begin(write=True)
                print('converted {} images'.format(item_id))

        # write last batch: records are pending exactly when the written
        # count is not a multiple of the batch size.
        if item_id % batch_size != 0:
            lmdb_txn.commit()
            print('converted {} images'.format(item_id))
        else:
            lmdb_txn.abort()
    finally:
        lmdb_env.close()
    print('Generated ' + lmdb_file)
create_lmdb.py 文件源码 项目:deeplearning-cats-dogs-tutorial 作者: adilmoujahid 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def make_datum(img, label):
    """Wrap an image and its label in a Caffe ``Datum``.

    Args:
        img: NumPy image array. Assumed to be laid out as
            ``(IMAGE_HEIGHT, IMAGE_WIDTH, 3)`` in BGR order; this is not
            checked here, so verify against the caller.
        label: Label stored in the datum.

    Returns:
        A ``caffe_pb2.Datum`` whose ``data`` holds the image bytes with the
        third axis moved to the front.
    """
    # image is numpy.ndarray format. BGR instead of RGB
    return caffe_pb2.Datum(
        channels=3,
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        # tobytes() replaces tostring(), which was removed in NumPy 2.0.
        data=np.rollaxis(img, 2).tobytes())
GenImageLmdb.py 文件源码 项目:Market1501-CVLab 作者: Lizw14 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def make_datum(img, label):
    """Wrap an image and its label in a Caffe ``Datum``.

    Args:
        img: NumPy image array. Assumed to be laid out as
            ``(IMAGE_HEIGHT, IMAGE_WIDTH, 3)`` in BGR order; this is not
            checked here, so verify against the caller.
        label: Label stored in the datum.

    Returns:
        A ``caffe_pb2.Datum`` whose ``data`` holds the image bytes with the
        axes reordered by ``np.transpose(img, (2, 0, 1))``.
    """
    # image is numpy.ndarray format. BGR instead of RGB
    return caffe_pb2.Datum(
        channels=3,
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        # tobytes() needs NumPy >= 1.9; the older tostring() alias was
        # removed in NumPy 2.0.
        data=np.transpose(img, (2, 0, 1)).tobytes())

# key = 0
# env = lmdb.open(img_lmdb_path, map_size=int(1e12))
# with env.begin(write=True) as txn:
#     for idx in xrange(numSample):
#         info = data[idx].split(" ")
#         OriImg = cv2.imread(datadir + info[0])
#         img = cv2.resize(OriImg,(IMAGE_WIDTH,IMAGE_HEIGHT))
#         label = int(info[1])
#         img = np.transpose(img, (2, 0, 1))
#         datum = caffe.io.array_to_datum(img, label)
#         key_str = '{:08}'.format(key)
# #        txn.put(key_str.encode('ascii'), datum.SerializeToString())
#         txn.put(key_str, datum.SerializeToString())
#         key += 1
#     for idx in xrange(numSample):
#         info = data[idx].split(" ")
#         OriImg = cv2.imread(datadir + info[0])
#         img = cv2.resize(OriImg,(IMAGE_WIDTH,IMAGE_HEIGHT))
#         label = int(info[1])
#         img = cv2.flip(img,1)
#         img = np.transpose(img, (2, 0, 1))
#         datum = caffe.io.array_to_datum(img, label)
#         key_str = '{:08}'.format(key)
# #        txn.put(key_str.encode('ascii'), datum.SerializeToString())
#         txn.put(key_str, datum.SerializeToString())
#         key += 1
# print key
convert_lmdb_to_numpy.py 文件源码 项目:SpindleNet 作者: yokattame 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def main(args):
  """Dump the ``float_data`` of LMDB ``Datum`` records into a ``.npy`` file.

  Args:
    args: Object with the attributes ``input_lmdb`` (path of the LMDB to
      read), ``truncate`` (maximum number of records to read) and
      ``output_npy`` (destination passed to ``np.save``).

  The collected rows are stacked with ``np.asarray`` and passed through
  ``np.squeeze`` before saving.
  """
  datum = Datum()
  data = []
  # Read-only: this function never writes, and a missing input path should
  # fail instead of silently creating an empty database.
  env = lmdb.open(args.input_lmdb, readonly=True)
  try:
    with env.begin() as txn:
      cursor = txn.cursor()
      for i, (_key, value) in enumerate(cursor):
        if i >= args.truncate:
          break
        datum.ParseFromString(value)
        # Copy the values: `datum` is reused, so storing its repeated-field
        # container could alias data that the next parse overwrites.
        data.append(list(datum.float_data))
  finally:
    env.close()
  data = np.squeeze(np.asarray(data))
  np.save(args.output_npy, data)
lmbd_creation.py 文件源码 项目:GitImpact 作者: ludovicdmt 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def make_datum(img, label):
    """Wrap a single-channel image and its label in a Caffe ``Datum``.

    Args:
        img: NumPy image array. Assumed to be a grayscale image of size
            ``IMAGE_HEIGHT`` x ``IMAGE_WIDTH``; this is not checked here, so
            verify against the caller.
        label: Label stored in the datum.

    Returns:
        A ``caffe_pb2.Datum`` whose ``data`` holds the raw bytes of ``img``.
    """
    # image is numpy.ndarray format. BGR instead of RGB
    return caffe_pb2.Datum(
        channels=1,  # images are in black and white
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        # tobytes() replaces tostring(), which was removed in NumPy 2.0.
        data=img.tobytes())


问题


面经


文章

微信
公众号

扫码关注公众号