def __init__(self, data_desc, dtype=None,
batch_filter=None, batch_mode='batch',
ncpu=1, buffer_size=8, hwm=86,
mpi_backend='python'):
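"""Set up the Feeder: validate the data descriptors, the desired output
dtype, the `batch_filter` and the `batch_mode`, then configure the
multiprocessing settings (ncpu, buffer_size, hwm, mpi_backend)."""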
super(Feeder, self).__init__(data=as_tuple(data_desc, t=DataDescriptor),
read_only=True)
# find intersection of all indices in DataDescriptor
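# (only keys present in every DataDescriptor can be aligned across all of
#  them; `async` is a library helper, not the Python keyword, and presumably
#  defers this computation)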
self._indices_keys = async(
lambda: np.array(
list(set.intersection(*[set(dat.indices.keys())
for dat in self._data])),
dtype=str)
)()
# ====== desired dtype ====== #
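# note: a single dtype is (assumed to be) repeated by as_tuple so that each
# underlying data array gets its own output dtype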
nb_data = sum(len(dat._data) for dat in self._data)
self._output_dtype = as_tuple(dtype, N=nb_data)
# ====== Set default recipes ====== #
self._recipes = RecipeList()
self._recipes.set_feeder_info(nb_desc=len(self._data))
self.set_multiprocessing(ncpu, buffer_size, hwm, mpi_backend)
# ====== cache shape information ====== #
# store first dimension
self._cache_shape = None
# if the recipes changed, the shape needs to be recalculated
self._recipes_changed = False
# ====== Iteration information ====== #
self._running_iter = []
# ====== batch mode ====== #
if batch_filter is None:
batch_filter = _dummy_batch_filter
elif not callable(batch_filter):
raise ValueError('`batch_filter` must be a function that takes 1 or 2 '
'arguments: (X) or (X, y).')
# check that batch_filter is picklable (it may be shipped to worker processes)
try:
cPickle.dumps(batch_filter, protocol=2)
except Exception:
raise ValueError("`batch_filter` must be picklable, hence it must be "
"a top-level function.")
self._batch_filter = batch_filter
# check batch_mode
batch_mode = str(batch_mode).lower()
if batch_mode not in ('batch', 'file'):
raise ValueError("`batch_mode` only supports 'batch' or 'file', but "
"given value: '%s'" % batch_mode)
self._batch_mode = batch_mode
# ==================== pickling ==================== #
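
For orientation, a rough usage sketch of the constructor above: the construction of the DataDescriptor is outside this snippet, so `desc` below stands in for any already-built descriptor, and the concrete values are illustrative only.

# batch_filter must be a top-level, picklable function taking (X) or (X, y);
# this trivial one keeps every batch unchanged (same role as _dummy_batch_filter).
def keep_batch(X):
    return X

# `desc` is an assumed, already-built DataDescriptor (or a tuple of them).
feeder = Feeder(desc,
                dtype='float32',          # desired output dtype(s)
                batch_filter=keep_batch,  # must survive cPickle.dumps
                batch_mode='batch',       # or 'file'
                ncpu=4, buffer_size=8, hwm=86,
                mpi_backend='python')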