def _add_record(self, data, label=None, key=None):
    """Serialize one array into a ``Datum`` and write it to ``self.env``.

    The array shape is padded with trailing 1s to three axes and stored in
    the datum's ``channels``, ``height`` and ``width`` fields, in that
    order. ``uint8`` arrays are stored as raw bytes in ``datum.data``;
    every other dtype is stored element by element in ``datum.float_data``.

    NOTE(review): ``Datum`` looks like Caffe's protobuf message and
    ``self.env`` like an LMDB environment -- neither is defined in this
    block, confirm against the file's imports.

    Args:
        data: NumPy array with 1, 2 or 3 dimensions.
        label: Value converted with ``int()`` and stored as the datum
            label; ``-1`` is stored when it is ``None``.
        key: Record key. When ``None`` the current ``self.num`` counter,
            zero-padded to 8 digits, is used. ``str`` keys are encoded as
            ASCII; ``bytes`` keys are used unchanged.

    Raises:
        ValueError: If ``data`` does not have 1, 2 or 3 dimensions.
    """
    if not 1 <= data.ndim <= 3:
        # A 4-D+ array would otherwise be stored with a truncated shape,
        # and a 0-d array would fail with an opaque IndexError.
        raise ValueError(
            'data must have 1, 2 or 3 dimensions, got {}'.format(data.ndim))

    # Pad the missing trailing axes with 1 so there are always three sizes.
    padded_shape = tuple(data.shape) + (1,) * (3 - data.ndim)
    channels, height, width = (int(size) for size in padded_shape)

    datum = Datum()
    datum.channels, datum.height, datum.width = channels, height, width

    if data.dtype == np.uint8:
        # tobytes() replaces tostring(), which was removed in NumPy 2.0;
        # both return the same C-order byte string.
        datum.data = data.tobytes()
    else:
        # Flatten first: tolist() on a multi-dimensional array returns
        # nested lists, which a repeated scalar field cannot accept.
        datum.float_data.extend(data.ravel().tolist())

    datum.label = int(label) if label is not None else -1

    if key is None:
        key = '{:08}'.format(self.num)
    if not isinstance(key, bytes):
        key = key.encode('ascii')

    with self.env.begin(write=True) as txn:
        txn.put(key, datum.SerializeToString())

    # Advance the counter only after the transaction block has exited
    # without raising.
    self.num += 1
评论列表
文章目录