from collections import defaultdict

import numpy as np
from dask import delayed
from distributed.client import _wait
from toolz import assoc, first
from tornado import gen

# parse_host_port, start_tracker, and train_part are helpers defined
# elsewhere in this module.


@gen.coroutine
def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
    """Asynchronous version of train

    See Also
    --------
    train
    """
    # Break apart Dask.array/dataframe into chunks/parts
    data_parts = data.to_delayed()
    label_parts = labels.to_delayed()
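    # ``to_delayed`` on a dask array returns a numpy object array with one
    # Delayed per chunk (dataframes already give a flat list); the asserts
    # below require chunking along rows only before flattening to a list.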
    if isinstance(data_parts, np.ndarray):
        assert data_parts.shape[1] == 1
        data_parts = data_parts.flatten().tolist()
    if isinstance(label_parts, np.ndarray):
        assert label_parts.ndim == 1 or label_parts.shape[1] == 1
        label_parts = label_parts.flatten().tolist()
    # Arrange parts into pairs. This enforces co-locality
    parts = list(map(delayed, zip(data_parts, label_parts)))
    parts = client.compute(parts)  # Start computation in the background
    yield _wait(parts)
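    # After the wait above, every (data, labels) pair is resident in worker
    # memory, so the scheduler can report exactly where each piece lives.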
    # Because XGBoost-python doesn't yet allow iterative training, we need to
    # find the locations of all chunks and map them to particular Dask workers
    key_to_part_dict = {part.key: part for part in parts}
    who_has = yield client.scheduler.who_has(keys=[part.key for part in parts])
    worker_map = defaultdict(list)
    for key, workers in who_has.items():
        worker_map[first(workers)].append(key_to_part_dict[key])
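    # At this point ``worker_map`` maps each worker's address to the list of
    # part-futures stored on it; ``first(workers)`` picks one holder when a
    # piece happens to be replicated on several workers.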
    ncores = yield client.scheduler.ncores()  # Number of cores per worker

    # Start the XGBoost tracker on the Dask scheduler
    host, port = parse_host_port(client.scheduler.address)
    env = yield client._run_on_scheduler(start_tracker,
                                         host.strip('/:'),
                                         len(worker_map))
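    # ``env`` is expected to hold the tracker's connection settings (the
    # Rabit environment variables); every worker needs them to join the same
    # distributed training job.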
    # Tell each worker to train on the chunks/parts that it has locally.
    # ``assoc`` returns a copy of ``params`` with ``nthread`` set to that
    # worker's core count, and ``workers=worker`` pins each task to the
    # worker that already holds the data.
    futures = [client.submit(train_part, env,
                             assoc(params, 'nthread', ncores[worker]),
                             list_of_parts, workers=worker,
                             dmatrix_kwargs=dmatrix_kwargs, **kwargs)
               for worker, list_of_parts in worker_map.items()]
    # Get the results; only one worker (the one holding Rabit rank 0) returns
    # the trained Booster, the rest return None
    results = yield client._gather(futures)
    result = [v for v in results if v][0]
    raise gen.Return(result)
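
# A minimal sketch of the synchronous entry point that drives the coroutine
# above (assumed here; the real dask-xgboost ``train`` also validates its
# inputs). The core pattern is to run ``_train`` on the client's event loop:
from distributed.utils import sync


def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
    """Train an XGBoost model on data distributed across a Dask cluster.

    Blocks until training finishes and returns the trained Booster.
    """
    return sync(client.loop, _train, client, params, data, labels,
                dmatrix_kwargs=dmatrix_kwargs, **kwargs)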