python类open()的实例源码

pyxis.py 文件源码 项目:ml-pyxis 作者: vicolab 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def __init__(self, dirpath, map_size_limit, ram_gb_limit=2):
    """Open (or create) an LMDB environment at *dirpath* with two databases.

    Args:
        dirpath: Directory of the LMDB environment.
        map_size_limit: Maximum map size in megabytes (MB); converted with
            ``int()`` and stored on ``self.map_size_limit``.
        ram_gb_limit: RAM limit in gigabytes (GB) per write; converted with
            ``float()`` and stored on ``self.ram_gb_limit``.

    Raises:
        ValueError: If either limit is not strictly positive.
    """
    self.map_size_limit = int(map_size_limit)  # Megabytes (MB)
    self.ram_gb_limit = float(ram_gb_limit)  # Gigabytes (GB)
    self.keys = []
    self.nb_samples = 0

    # Minor sanity checks
    if self.map_size_limit <= 0:
        raise ValueError('The LMDB map size must be positive: '
                         '{}'.format(self.map_size_limit))
    if self.ram_gb_limit <= 0:
        raise ValueError('The RAM limit (GB) per write must be '
                         'positive: {}'.format(self.ram_gb_limit))

    # Convert the validated integer MB value to bytes. Shifting the raw
    # argument instead would raise TypeError for floats or numeric strings
    # that int() accepted above.
    map_size_bytes = self.map_size_limit << 20

    # Open LMDB environment
    self._lmdb_env = lmdb.open(dirpath,
                               map_size=map_size_bytes,
                               max_dbs=NB_DBS)

    # Open the default database(s) associated with the environment
    self.data_db = self._lmdb_env.open_db(DATA_DB)
    self.meta_db = self._lmdb_env.open_db(META_DB)
dataset.py 文件源码 项目:crnn 作者: wulivicte 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def __init__(self, root=None, transform=None, target_transform=None):
    """Open a read-only LMDB dataset and read its sample count.

    Args:
        root: Path to the LMDB environment.
        transform: Optional callable applied to each image.
        target_transform: Optional callable applied to each label.
    """
    self.env = lmdb.open(
        root,
        max_readers=1,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False)

    # NOTE(review): lmdb.open appears to raise on failure rather than return
    # a falsy value, so this branch is presumably never taken — confirm.
    if not self.env:
        print('cannot creat lmdb from %s' % (root))
        sys.exit(0)

    with self.env.begin(write=False) as txn:
        # py-lmdb keys must be bytes on Python 3; b'' is a plain str on Py2.
        nSamples = int(txn.get(b'num-samples'))
        self.nSamples = nSamples

    self.transform = transform
    self.target_transform = target_transform
dataset.py 文件源码 项目:sceneReco 作者: bear63 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def __init__(self, root=None, transform=None, target_transform=None):
    """Open a read-only LMDB dataset and read its sample count.

    Args:
        root: Path to the LMDB environment.
        transform: Optional callable applied to each image.
        target_transform: Optional callable applied to each label.
    """
    self.env = lmdb.open(
        root,
        max_readers=1,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False)

    # NOTE(review): lmdb.open appears to raise on failure rather than return
    # a falsy value, so this branch is presumably never taken — confirm.
    if not self.env:
        print('cannot creat lmdb from %s' % (root))
        sys.exit(0)

    with self.env.begin(write=False) as txn:
        # py-lmdb keys must be bytes on Python 3; b'' is a plain str on Py2.
        nSamples = int(txn.get(b'num-samples'))
        self.nSamples = nSamples

    self.transform = transform
    self.target_transform = target_transform
dataset.py 文件源码 项目:sceneReco 作者: bear63 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def __getitem__(self, index):
    """Return the ``(image, label)`` pair for zero-based *index*.

    Records are stored under 1-based keys ``image-%09d`` / ``label-%09d``.
    The image is decoded as greyscale (``convert('L')``). If decoding fails,
    the next record (``index + 1``) is returned instead.

    Raises:
        AssertionError: If *index* is not below ``len(self)``.
    """
    assert index < len(self), 'index range error'
    key_index = index + 1  # database keys are 1-based
    with self.env.begin(write=False) as txn:
        # py-lmdb keys must be bytes on Python 3 (no-op on Python 2).
        img_key = ('image-%09d' % key_index).encode('ascii')
        imgbuf = txn.get(img_key)

        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        try:
            img = Image.open(buf).convert('L')
        except IOError:
            print('Corrupted image for %d' % key_index)
            # Next zero-based sample; the old ``self[key_index + 1]`` skipped
            # one good record as well.
            return self[index + 1]

        if self.transform is not None:
            img = self.transform(img)

        label_key = ('label-%09d' % key_index).encode('ascii')
        label = txn.get(label_key)
        # Python 3 returns bytes, and str(bytes) would give "b'...'".
        if isinstance(label, bytes) and not isinstance(label, str):
            label = label.decode('utf-8')
        label = str(label)
        if self.target_transform is not None:
            label = self.target_transform(label)

    return (img, label)
gen_lmdb.py 文件源码 项目:FCN-VOC2012-Training-Config 作者: voidrank 项目源码 文件源码 阅读 16 收藏 0 点赞 0 评论 0
def gen_input(lmdbname, file_list):
    """Write every ``DIR/JPEGImages/<name>.jpg`` into a Caffe Datum LMDB.

    Each image array is transposed with ``(2, 0, 1)`` (channel-first) and
    stored, serialized as a ``Datum``, under the ASCII-encoded key ``name``.
    One write transaction is committed per image.

    Args:
        lmdbname: Path of the LMDB environment to open/create.
        file_list: Image base names without extension.
    """
    # 5x the bytes of a float32 (N, 3, HEIGHT, WIDTH) block, computed
    # directly instead of allocating that array just to read ``nbytes``.
    map_size = (len(file_list) * 3 * HEIGHT * WIDTH
                * np.dtype(np.float32).itemsize * 5)

    env = lmdb.open(lmdbname, map_size=map_size)
    try:
        for count, name in enumerate(file_list):
            print(count)  # progress; valid on both Python 2 and 3
            with env.begin(write=True) as txn:
                filename = os.path.join(DIR, "JPEGImages", name + ".jpg")
                m = np.asarray(Image.open(filename)).transpose((2, 0, 1))
                datum = caffe.proto.caffe_pb2.Datum()
                datum.channels = m.shape[0]
                datum.height = m.shape[1]
                datum.width = m.shape[2]
                datum.data = m.tobytes()
                txn.put(name.encode("ascii"), datum.SerializeToString())
    finally:
        env.close()
gen_lmdb.py 文件源码 项目:FCN-VOC2012-Training-Config 作者: voidrank 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def gen_output(lmdbname, file_list):
    """Write every ``DIR/SegmentationClass/<name>.png`` into a Datum LMDB.

    Pixels equal to 255 are set to 0 before storing. Each mask is stored as a
    single-channel ``Datum`` under the ASCII-encoded key ``name``, with one
    write transaction per image.

    Args:
        lmdbname: Path of the LMDB environment to open/create.
        file_list: Image base names without extension.
    """
    # 3x the bytes of a uint8 (N, 1, HEIGHT, WIDTH) block, computed directly
    # instead of allocating that array just to read ``nbytes``.
    map_size = (len(file_list) * HEIGHT * WIDTH
                * np.dtype(np.uint8).itemsize * 3)

    env = lmdb.open(lmdbname, map_size=map_size)
    try:
        for count, name in enumerate(file_list):
            print(count)  # progress; valid on both Python 2 and 3
            with env.begin(write=True) as txn:
                filename = os.path.join(DIR, "SegmentationClass", name + ".png")
                # np.array copies, so the in-place edit below is safe.
                m = np.array(Image.open(filename))
                # Vectorised replacement of the former per-pixel double loop.
                m[m == 255] = 0
                datum = caffe.proto.caffe_pb2.Datum()
                datum.channels = 1
                datum.height = m.shape[0]
                datum.width = m.shape[1]
                datum.data = m.tobytes()
                txn.put(name.encode("ascii"), datum.SerializeToString())
    finally:
        env.close()
data2lmdb.py 文件源码 项目:train-CRF-RNN 作者: martinkersner 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def split_train_test_imgs(class_names, test_ratio):
    """Split the names listed in ``<class>.txt`` files into train/test lists.

    For each class, the first ``test_ratio * get_num_lines(file)`` lines
    (rounded down) are test images and the remaining lines are training
    images. Files are read in binary mode, so entries are stripped bytes.

    Returns:
        ``(train_imgs, test_imgs)``.
    """
    train_imgs = []
    test_imgs = []

    for class_name in class_names:
        file_name = class_name + '.txt'
        num_test_imgs = test_ratio * get_num_lines(file_name)

        with open(file_name, 'rb') as f:
            # 1-based line numbers compared with ``<=``: the previous ``<``
            # produced one test image too few when the product was whole.
            for line_no, line in enumerate(f, 1):
                if line_no <= num_test_imgs:
                    test_imgs.append(line.strip())
                else:
                    train_imgs.append(line.strip())

    print(str(len(train_imgs)) + ' train images')
    print(str(len(test_imgs)) + ' test images')

    return train_imgs, test_imgs
data2lmdb.py 文件源码 项目:train-CRF-RNN 作者: martinkersner 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def convert2lmdb(path_src, src_imgs, ext, path_dst, class_ids, preprocess_mode, im_sz, data_mode):
    """Convert images (or label maps) into a Caffe Datum LMDB at *path_dst*.

    Does nothing and returns ``None`` if *path_dst* already exists as a
    directory. Each image ``path_src + name + ext`` is loaded as uint8,
    preprocessed according to *data_mode* ('label' or 'image'), and stored
    under a zero-padded 10-digit index key in a single write transaction.
    """
    if os.path.isdir(path_dst):
        print('DB ' + path_dst + ' already exists.\n'
              'Skip creating ' + path_dst + '.', file=sys.stderr)
        return None

    if data_mode == 'label':
        lut = create_lut(class_ids)

    db = lmdb.open(path_dst, map_size=int(1e12))
    try:
        with db.begin(write=True) as in_txn:
            for idx, img_name in enumerate(src_imgs):
                # Plain concatenation (as before): path_src needs a trailing
                # separator.
                img = np.array(Image.open(path_src + img_name + ext))
                img = img.astype(np.uint8)

                if data_mode == 'label':
                    img = preprocess_label(img, lut, preprocess_mode, im_sz)
                elif data_mode == 'image':
                    img = preprocess_image(img, preprocess_mode, im_sz)

                img_dat = caffe.io.array_to_datum(img)
                # py-lmdb keys must be bytes on Python 3.
                key = '{:0>10d}'.format(idx).encode('ascii')
                in_txn.put(key, img_dat.SerializeToString())
    finally:
        db.close()
caffe_lmdb.py 文件源码 项目:score-zeroshot 作者: pedro-morgado 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def loop_records(self, num_records=0, init_key=None):
    """Yield ``(data, label, key)`` for records of the LMDB at ``self.fn``.

    Args:
        num_records: Stop after this many records; 0 reads to the end.
        init_key: Optional key to start from; it must exist in the database.

    Raises:
        ValueError: If *init_key* is given but not found.
    """
    env = lmdb.open(self.fn, readonly=True)
    try:
        datum = Datum()
        with env.begin() as txn:
            cursor = txn.cursor()
            if init_key is not None and not cursor.set_key(init_key):
                # format() also works when init_key is bytes (Python 3).
                raise ValueError('key {} not found in lmdb {}.'.format(init_key, self.fn))

            for num_read, (key, value) in enumerate(cursor, 1):
                datum.ParseFromString(value)
                label = datum.label
                data = datum_to_array(datum).squeeze()
                yield (data, label, key)
                if num_records != 0 and num_read == num_records:
                    break
    finally:
        # Also runs on early generator close or error, unlike the old
        # trailing env.close().
        env.close()
mc_tasks.py 文件源码 项目:mediachain-indexer 作者: mediachain 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def verify_img(buf):
    """
    Verify that *buf* holds a complete, decodable image.

    Returns the fully loaded PIL image on success. On any decoding or
    verification error, prints a marker plus the first 100 bytes and
    returns False.
    """

    sbuf = StringIO(buf)

    try:
        ## Basic check:
        img = Image.open(sbuf)
        img.verify()

        ## Detect truncated: verify() leaves the image unusable, so rewind
        ## and force a full decode of a fresh image object.
        sbuf.seek(0)
        img = Image.open(sbuf)
        img.load()
    except Exception:
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate instead of being reported as a bad image.
        print ('VERIFY_IMG_FAILED', buf[:100])
        return False
    return img
mc_ingest.py 文件源码 项目:mediachain-indexer 作者: mediachain 项目源码 文件源码 阅读 16 收藏 0 点赞 0 评论 0
def verify_img(buf):
    """
    Return True if *buf* holds a complete, decodable image, else False.
    """
    from PIL import Image
    try:
        from cStringIO import StringIO  # Python 2
    except ImportError:
        from io import BytesIO as StringIO  # Python 3: image data is bytes

    sbuf = StringIO(buf)

    try:
        ## Basic check:
        img = Image.open(sbuf)
        img.verify()

        ## Detect truncated: rewind, then force a full decode.
        sbuf.seek(0)
        img = Image.open(sbuf)
        img.load()
    except Exception:
        # KeyboardInterrupt/SystemExit are BaseException and still propagate.
        return False
    return True
data_process.py 文件源码 项目:pytorch_crowd_count 作者: BingzheWu 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def process_dump_tohdf5data(X,Y, path, phase):
    """Write (X, Y) into HDF5 shards of at most 7000 samples each.

    Shard files are named ``<path>/<phase>_<start>.h5`` and contain the
    datasets ``data`` (each X item transposed with ``(2, 0, 1)``, float32)
    and ``label`` (each Y item passed through ``density_resize``). Every
    shard path is appended as a line to ``<path>/<phase>.txt``.
    """
    chunk = 7000
    data_buf = np.zeros((chunk, 3, patch_h, patch_w), dtype = np.float32)
    label_buf = np.zeros((chunk, net_density_h, net_density_w), dtype = np.float32)
    scale_x = float(net_density_w) / patch_w
    scale_y = float(net_density_h) / patch_h

    with open(os.path.join(path, phase+'.txt'), 'w') as listing:
        for start in range(0, len(X), chunk):
            stop = min(start + chunk, len(X))
            count = stop - start
            shard = os.path.join(path, phase+'_'+str(start)+'.h5')
            with h5py.File(shard, 'w') as hf:
                for offset, img in enumerate(X[start:stop]):
                    data_buf[offset] = img.copy().transpose(2,0,1).astype(np.float32)
                    label_buf[offset] = density_resize(Y[start + offset], fx=scale_x, fy=scale_y)
                hf['data'] = data_buf[:count]
                hf['label'] = label_buf[:count]
            listing.write(shard+'\n')
data_process.py 文件源码 项目:pytorch_crowd_count 作者: BingzheWu 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def read_lmdb(lmdb_path):
    """Print and plot the first record of an LMDB as a 27x27 float32 map.

    Debug helper: the raw value is interpreted as float32, reshaped to
    (27, 27), printed, and shown with ``plt.imshow(..., cmap='hot')``.
    Does nothing for an empty database.
    """
    # readonly: a viewer must not silently create an empty DB on a bad path.
    env = lmdb.open(lmdb_path, readonly=True)
    try:
        with env.begin() as txn:
            cursor = txn.cursor()
            if not cursor.first():
                return
            # frombuffer replaces the deprecated np.fromstring for binary data.
            image = np.frombuffer(cursor.value(), dtype=np.float32)
            image = np.reshape(image, (27, 27))
            print(image)
            plt.imshow(image, cmap='hot')
            plt.show()
    finally:
        env.close()
loader.py 文件源码 项目:pytorch_crowd_count 作者: BingzheWu 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def __init__(self, lmdb_image_datapath, lmdb_label_datapath):
    """Open the image and label LMDBs and start a read cursor on each.

    Args:
        lmdb_image_datapath: Path of the image LMDB.
        lmdb_label_datapath: Path of the label LMDB.
    """
    # The first body line was previously over-indented (IndentationError).
    super(UCF_CC_50, self).__init__()
    self.lmdb_image_datapath = lmdb_image_datapath
    self.lmdb_label_datapath = lmdb_label_datapath
    self.images = []
    self.gts = []
    self.total_patches = 0
    self.limits = []
    self.num_files = 0
    self.file_list = []
    self.env_image = lmdb.open(self.lmdb_image_datapath)
    self.env_label = lmdb.open(self.lmdb_label_datapath)
    self.txn_image = self.env_image.begin()
    self.txn_label = self.env_label.begin()
    # Image cursor is consumed as an iterator; the label cursor is kept raw.
    self.cursor_image = iter(self.txn_image.cursor())
    self.cursor_label = self.txn_label.cursor()
caffe_class_utils.py 文件源码 项目:Triplet_Loss_SBIR 作者: TuBui 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def read(self, in_path):
    """
    Read an LMDB of Caffe Datums keyed '00000000', '00000001', ...

    Returns:
        (data, label): a uint8 array stacking every datum_to_array() result
        along a new first axis, and an int64 array of datum labels. For an
        empty database both arrays are empty.

    Raises:
        KeyError: If an expected key is missing.
    """
    print('Reading ' + in_path)
    env = lmdb.open(in_path, readonly=True)
    try:
        N = env.stat()['entries']
        if N == 0:
            # Previously `data` was never bound and the return line crashed.
            return (np.zeros((0, 0, 0, 0), dtype=np.uint8),
                    np.zeros(0, dtype=np.int64))
        with env.begin() as txn:
            for i in range(N):
                # py-lmdb keys must be bytes on Python 3.
                str_id = '{:08}'.format(i).encode('ascii')
                raw_datum = txn.get(str_id)
                if raw_datum is None:
                    raise KeyError('missing key %r in %s' % (str_id, in_path))
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(raw_datum)
                feature = caffe.io.datum_to_array(datum)
                if i == 0:
                    # Output buffers are sized from the first record's shape.
                    data = np.zeros((N, feature.shape[0], feature.shape[1],
                                     feature.shape[2]), dtype=np.uint8)
                    label = np.zeros(N, dtype=np.int64)
                data[i] = feature
                label[i] = datum.label
    finally:
        env.close()
    return data, label
lmdb_utils.py 文件源码 项目:pytorch-yolo2 作者: marvis 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def __init__(self, lmdb_root, shape=None, shuffle=True, transform=None, target_transform=None, train=False, seen=0):
    """Open a read-only LMDB dataset and build its (optionally shuffled) index list.

    The sample count is read from the ``num-samples`` key. A read
    transaction is kept open on ``self.txn`` for later access.
    """
    self.env = lmdb.open(lmdb_root,
                         max_readers=1,
                         readonly=True,
                         lock=False,
                         readahead=False,
                         meminit=False)
    self.txn = self.env.begin(write=False)
    # py-lmdb keys must be bytes on Python 3.
    self.nSamples = int(self.txn.get(b'num-samples'))
    # list(): range() is immutable on Python 3 and random.shuffle would fail.
    self.indices = list(range(self.nSamples))
    if shuffle:
        random.shuffle(self.indices)

    self.transform = transform
    self.target_transform = target_transform
    self.train = train
    self.shape = shape
    self.seen = seen
lmdb_detect.py 文件源码 项目:DeepID2 作者: chenzeyuczy 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def detect_lmdb(path):
    """Print an LMDB's stat() fields and the array shape of its first Datum."""
    # Inspection needs no write access and must not create a new empty DB.
    env = lmdb.open(path, readonly=True)
    try:
        print("Info of lmdb at %s" % path)
        for key, value in env.stat().items():
            print("%s : %s" % (key, value))

        datum = caffe.proto.caffe_pb2.Datum()
        with env.begin() as txn:
            cursor = txn.cursor()
            # first() returns False for an empty DB; the old next() result
            # was ignored and an empty value was parsed.
            if not cursor.first():
                print("Database is empty")
                return
            datum.ParseFromString(cursor.value())
            data = caffe.io.datum_to_array(datum)
            print("Data shape: %s" % (data.shape,))
    finally:
        env.close()
create_lmdb_kitti_dataset.py 文件源码 项目:sun-bcnn 作者: utiasSTARS 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def readGroundTruth(datasetTxtFilepath):
    """Parse a listing of image names and sun-direction vectors.

    Each whitespace-separated line holds an image file name followed by
    three floats. When the module-level flag ``azZenTarget`` is truthy, the
    vector (x, y, z) is converted to ``[degrees(atan2(x, z)),
    degrees(acos(-y))]``; otherwise the float list is kept as-is.

    Returns:
        ``(sunDirList, imageFileNames)`` as two parallel lists.
    """
    targets = []
    names = []
    with open(datasetTxtFilepath) as listing:
        for row in listing:
            fields = row.split()
            name = fields[0]
            direction = list(map(float, fields[1:4]))

            if not azZenTarget:
                targets.append(direction)
            else:
                azimuth = math.degrees(math.atan2(direction[0], direction[2]))
                zenith = math.degrees(math.acos(-direction[1]))
                targets.append([azimuth, zenith])
            names.append(name)

    return targets, names
gen_lmdb.py 文件源码 项目:brain-tumor 作者: voidrank 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def make_lmdb_input(lmdbname, channel_directories, range_set):
    """Stack same-named PNGs from several directories into multi-channel Datums.

    For the k-th element ``i`` of *range_set*, ``<dir>/<i>.png`` is read from
    every directory in *channel_directories* (one channel each) as float64.
    The stacked (C, HEIGHT, WIDTH) array is stored as Datum bytes under the
    key ``'%08d' % k``, with one write transaction per sample.
    """
    n_channels = len(channel_directories)
    # np.asarray(PIL image) is rows x cols, i.e. (HEIGHT, WIDTH). The old
    # (WIDTH, HEIGHT) slot only worked for square images and swapped the
    # datum's height/width fields.
    sample_shape = (n_channels, HEIGHT, WIDTH)
    # Same sizing as before (10x the float64 payload), without allocating
    # the whole (N, C, H, W) array up front.
    map_size = (len(range_set) * n_channels * HEIGHT * WIDTH
                * np.dtype(np.double).itemsize * 10)

    env = lmdb.open(lmdbname, map_size=map_size)
    try:
        for count, i in enumerate(range_set):
            filename = str(i) + ".png"
            sample = np.zeros(sample_shape, dtype=np.double)
            for j, dirname in enumerate(channel_directories):
                sample[j] = np.asarray(PIL.Image.open(os.path.join(dirname, filename)), dtype=np.double)
            datum = caffe.proto.caffe_pb2.Datum()
            datum.channels = n_channels
            datum.height = HEIGHT
            datum.width = WIDTH
            datum.data = sample.tobytes()
            str_id = '{:08}'.format(count)
            with env.begin(write=True) as txn:
                txn.put(str_id.encode("ascii"), datum.SerializeToString())
    finally:
        env.close()
lmdb_access.py 文件源码 项目:caffe-materials 作者: kyehyeon 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def write_lmdb(db_path, list_filename, height, width):
    """Write resized, JPEG-encoded images listed in a text file to an LMDB.

    Each line of *list_filename* is ``<image path> <integer label>``. Images
    are read in colour, resized to *height* x *width*, JPEG-encoded and
    stored as encoded Caffe Datums under zero-padded 10-digit index keys.

    Raises:
        IOError: If OpenCV cannot read an image.
    """
    map_size = 9999999999
    db = lmdb.open(db_path, map_size=map_size)
    try:
        datum = caffe.proto.caffe_pb2.Datum()
        # The txn context commits on success and aborts on error; the list
        # file is closed either way.
        with db.begin(write=True) as writer, open(list_filename, 'r') as listing:
            for index, line in enumerate(listing):
                img_filename, label = line.strip().split(' ')
                img = cv2.imread(img_filename, 1)
                if img is None:
                    raise IOError('cannot read image: %s' % img_filename)
                # cv2.resize takes dsize as (width, height); passing
                # (height, width) transposed the output for non-square sizes.
                img = cv2.resize(img, (width, height))
                _, img_jpg = cv2.imencode('.jpg', img)
                datum.channels = 3
                datum.height = height
                datum.width = width
                datum.label = int(label)
                datum.encoded = True
                datum.data = img_jpg.tobytes()
                datum_byte = datum.SerializeToString()
                # append=True requires ascending keys; zero padding ensures it.
                index_byte = ('%010d' % index).encode('ascii')
                writer.put(index_byte, datum_byte, append=True)
    finally:
        db.close()
batch_manager.py 文件源码 项目:SSD-Keras_Tensorflow 作者: jedol 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def __init__(self, source, batch_size, shuffle=True, use_prefetch=True, capacity=32):
    """Open the LMDB at *source* read-only and optionally start a prefetcher.

    When *use_prefetch* is set, ``self._worker`` runs in a separate
    ``Process`` feeding a ``Queue`` of size *capacity*. The process is
    terminated and joined at interpreter exit.
    """
    ## open LMDB and keep one read transaction plus cursor
    env = lmdb.open(source, readonly=True)
    self.env = env
    self.txn = env.begin()
    self.cur = self.txn.cursor()

    self.batch_size, self.shuffle = batch_size, shuffle
    self.use_prefetch, self.capacity = use_prefetch, capacity
    self.num_data = int(self.txn.stat()['entries'])

    self.reset_inds()

    if not self.use_prefetch:
        return

    import atexit
    self.batch_queue = Queue(capacity)
    self.proc = Process(target=self._worker)
    self.proc.start()

    def _stop_worker():
        self.proc.terminate()
        self.proc.join()
    atexit.register(_stop_worker)
mrbi_to_lmdb.py 文件源码 项目:hyperband_benchmarks 作者: lishal 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def make_test(src='/home/lisha/school/Projects/hyperband_nnet/hyperband2/mrbi/mnist_rotation_back_image_new/mnist_all_background_images_rotation_normalized_test.amat',
              dst='/home/lisha/school/Projects/hyperband_nnet/hyperband2/mrbi/mrbi_test'):
    """Convert the MRBI test split (.amat) at *src* into a Datum LMDB at *dst*.

    Samples ``X[i]`` with labels ``Y[i]`` (from ``get_data``) are stored in
    one write transaction under zero-padded 10-digit keys. The defaults keep
    the original hard-coded paths.
    """
    print('Loading Matlab data.')
    X, Y = get_data(src)
    N = Y.shape[0]
    map_size = X.nbytes * 2
    env = lmdb.open(dst, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            for i in range(N):
                im_dat = caffe.io.array_to_datum(X[i], Y[i])
                # py-lmdb keys must be bytes on Python 3.
                txn.put('{:0>10d}'.format(i).encode('ascii'), im_dat.SerializeToString())
    finally:
        env.close()
mrbi_to_lmdb.py 文件源码 项目:hyperband_benchmarks 作者: lishal 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def view_lmdb_data(lmdb_path='/home/lisha/school/Projects/hyperband_nnet/hyperband2/svhn/svhn_train/'):
    """Decode every Caffe Datum in *lmdb_path* and print how many were read.

    The default path is the original hard-coded SVHN training LMDB.
    """
    # readonly: reading must not create an empty DB for a mistyped path.
    lmdb_env = lmdb.open(lmdb_path, readonly=True)
    datum = caffe.proto.caffe_pb2.Datum()
    x = []
    y = []
    try:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                data = caffe.io.datum_to_array(datum)
                x.append(data)
                y.append(label)
    finally:
        lmdb_env.close()
    print(len(y))
mat_to_lmdb.py 文件源码 项目:hyperband_benchmarks 作者: lishal 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def view_lmdb_data(lmdb_path='/home/lisha/school/Projects/hyperband_nnet/hyperband2/svhn/svhn_train/'):
    """Decode every Caffe Datum in *lmdb_path* and print how many were read.

    The default path is the original hard-coded SVHN training LMDB.
    """
    # readonly: reading must not create an empty DB for a mistyped path.
    lmdb_env = lmdb.open(lmdb_path, readonly=True)
    datum = caffe.proto.caffe_pb2.Datum()
    x = []
    y = []
    try:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                data = caffe.io.datum_to_array(datum)
                x.append(data)
                y.append(label)
    finally:
        lmdb_env.close()
    print(len(y))
misc.py 文件源码 项目:pytorch-playground 作者: aaron-xichen 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def load_lmdb(lmdb_file, n_records=None):
    """Decode colour images from an LMDB into memory.

    Keys are decoded as ASCII and split on ':'; the middle field is parsed
    as the integer target. Values are decoded with ``cv2.imdecode``.

    Args:
        lmdb_file: Path to the LMDB (passed through ``expand_user``).
        n_records: Optional maximum number of records to load.

    Returns:
        A list of ``(img, target)`` tuples, or None if the path is missing.
    """
    import lmdb
    import numpy as np
    lmdb_file = expand_user(lmdb_file)
    if not os.path.exists(lmdb_file):
        # The old message had no {} placeholder, so the path was never shown.
        print("Not found lmdb file {}".format(lmdb_file))
        return None

    data = []
    env = lmdb.open(lmdb_file, readonly=True, max_readers=512)
    try:
        with env.begin() as txn:
            cursor = txn.cursor()
            begin_st = time.time()
            print("Loading lmdb file {} into memory".format(lmdb_file))
            for key, value in cursor:
                _, target, _ = key.decode('ascii').split(':')
                target = int(target)
                # frombuffer replaces the deprecated np.fromstring.
                img = cv2.imdecode(np.frombuffer(value, np.uint8), cv2.IMREAD_COLOR)
                data.append((img, target))
                if n_records is not None and len(data) >= n_records:
                    break
    finally:
        env.close()
    print("=> Done ({:.4f} s)".format(time.time() - begin_st))
    return data
dataset.py 文件源码 项目:sceneReco 作者: yijiuzai 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def __init__(self, root=None, transform=None, target_transform=None):
    """Open a read-only LMDB dataset and read its sample count.

    Args:
        root: Path to the LMDB environment.
        transform: Optional callable applied to each image.
        target_transform: Optional callable applied to each label.
    """
    self.env = lmdb.open(
        root,
        max_readers=1,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False)

    # NOTE(review): lmdb.open appears to raise on failure rather than return
    # a falsy value, so this branch is presumably never taken — confirm.
    if not self.env:
        print('cannot creat lmdb from %s' % (root))
        sys.exit(0)

    with self.env.begin(write=False) as txn:
        # py-lmdb keys must be bytes on Python 3; b'' is a plain str on Py2.
        nSamples = int(txn.get(b'num-samples'))
        self.nSamples = nSamples

    self.transform = transform
    self.target_transform = target_transform
dataset.py 文件源码 项目:sceneReco 作者: yijiuzai 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def __getitem__(self, index):
    """Return the ``(image, label)`` pair for zero-based *index*.

    Records are stored under 1-based keys ``image-%09d`` / ``label-%09d``.
    The image is decoded as greyscale (``convert('L')``). If decoding fails,
    the next record (``index + 1``) is returned instead.

    Raises:
        AssertionError: If *index* is not below ``len(self)``.
    """
    assert index < len(self), 'index range error'
    key_index = index + 1  # database keys are 1-based
    with self.env.begin(write=False) as txn:
        # py-lmdb keys must be bytes on Python 3 (no-op on Python 2).
        img_key = ('image-%09d' % key_index).encode('ascii')
        imgbuf = txn.get(img_key)

        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        try:
            img = Image.open(buf).convert('L')
        except IOError:
            print('Corrupted image for %d' % key_index)
            # Next zero-based sample; the old ``self[key_index + 1]`` skipped
            # one good record as well.
            return self[index + 1]

        if self.transform is not None:
            img = self.transform(img)

        label_key = ('label-%09d' % key_index).encode('ascii')
        label = txn.get(label_key)
        # Python 3 returns bytes, and str(bytes) would give "b'...'".
        if isinstance(label, bytes) and not isinstance(label, str):
            label = label.decode('utf-8')
        label = str(label)
        if self.target_transform is not None:
            label = self.target_transform(label)

    return (img, label)
lmdb_creator.py 文件源码 项目:phocnet 作者: ssudholt 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def open_single_lmdb_for_write(self, lmdb_path, max_lmdb_size=1024**4, create=True, label_map=None):
    '''
    Open one LMDB for inserting ndarrays (e.g. images) and begin a write transaction.

    Args:
        lmdb_path (str): Directory of the LMDB.
        max_lmdb_size (int): Map size in bytes (default: 1024**4, i.e. 1 TiB).
        create (bool): If True, anything already at lmdb_path is removed with
                       shutil.rmtree before opening.
        label_map (dictionary): Optional mapping from string labels to integer
                                indices, stored on self.label_map.
    '''
    if create and os.path.exists(lmdb_path):
        self.logger.debug('Erasing previously created LMDB at %s', lmdb_path)
        shutil.rmtree(lmdb_path)

    self.logger.info('Opening single LMDB at %s for writing', lmdb_path)
    env = lmdb.open(path=lmdb_path, map_size=max_lmdb_size)
    self.database_images = env
    self.txn_images = env.begin(write=True)
    self.label_map = label_map
format.py 文件源码 项目:ternarynet 作者: czhu95 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def __init__(self, lmdb_dir, shuffle=True):
    """Open an LMDB read-only and, when shuffling, gather its keys.

    Args:
        lmdb_dir: Path of the LMDB environment.
        shuffle: If True, ``self.keys`` is filled from the ``__keys__``
            entry when present, otherwise by scanning all keys.
    """
    self._lmdb = lmdb.open(lmdb_dir, readonly=True, lock=False,
            map_size=1099511627776 * 2, max_readers=100)
    self._txn = self._lmdb.begin()
    self._shuffle = shuffle
    self._size = self._txn.stat()['entries']
    if shuffle:
        # py-lmdb keys must be bytes on Python 3.
        # NOTE(review): a stored __keys__ value is kept as raw bytes; its
        # serialization format is not visible here — confirm callers decode it.
        self.keys = self._txn.get(b'__keys__')
        if not self.keys:
            self.keys = []
            with timed_operation("Loading LMDB keys ...", log_start=True), \
                    tqdm(total=self._size, ascii=True) as pbar:
                # Iterate keys only: iterating the cursor itself yields
                # (key, value) tuples, so the old ``k != '__keys__'`` test was
                # always true and tuples ended up in self.keys.
                for k in self._txn.cursor().iternext(keys=True, values=False):
                    if k != b'__keys__':
                        self.keys.append(k)
                        pbar.update()
create_lmdb_data.py 文件源码 项目:visimportance 作者: cvzoya 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def load_label(maindir, idx, split):
    """
    Load the label PNG for *idx* as a 1 x height x width float array.

    The image is read from GDI/gd_imp_train when split == 'train', otherwise
    from GDI/gd_imp_val. It is converted to uint8 and divided by 255.0, so
    values are floats in [0, 1] rather than integer label indices. The
    leading singleton dimension is required by the loss.
    """
    path_fmt = ('{}/GDI/gd_imp_train/{}.png' if split=='train'
                else '{}/GDI/gd_imp_val/{}.png')
    im = Image.open(path_fmt.format(maindir, idx))

    scaled = np.array(im, dtype=np.uint8) / 255.0
    return scaled[np.newaxis, ...]


问题


面经


文章

微信
公众号

扫码关注公众号