core.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:dask-xgboost 作者: dask 项目源码 文件源码
def _train(client, params, data, labels, dmatrix_kwargs=None, **kwargs):
    """
    Asynchronous version of train

    Written as a generator-based coroutine: intermediate results are
    obtained with ``yield`` and the final value is delivered by raising
    ``gen.Return``.
    NOTE(review): no coroutine decorator is visible on this definition in
    this view -- confirm one is applied where the function is defined.

    Parameters
    ----------
    client
        Object used to reach the cluster.  This function calls
        ``client.compute``, ``client.submit``, ``client._gather``,
        ``client._run_on_scheduler`` and ``client.scheduler``.
    params
        Training parameters.  For each worker they are passed through
        ``assoc(params, 'nthread', ncores[worker])`` before being handed
        to ``train_part``.
    data, labels
        Collections exposing ``to_delayed()``; they are split into parts
        and paired up part-by-part.
    dmatrix_kwargs : dict, optional
        Forwarded to ``train_part`` as its ``dmatrix_kwargs`` keyword.
        ``None`` (the default) means an empty dict.
    **kwargs
        Forwarded unchanged to ``train_part``.

    Returns
    -------
    The first truthy value among the per-worker ``train_part`` results.

    Raises
    ------
    IndexError
        If no worker returned a truthy result.

    See Also
    --------
    train
    """
    # A fresh dict per call avoids sharing one mutable default object
    # between invocations.
    if dmatrix_kwargs is None:
        dmatrix_kwargs = {}

    # Break apart Dask.array/dataframe into chunks/parts
    data_parts = data.to_delayed()
    label_parts = labels.to_delayed()
    if isinstance(data_parts, np.ndarray):
        # Only a single column of data parts is accepted.
        assert data_parts.shape[1] == 1
        data_parts = data_parts.flatten().tolist()
    if isinstance(label_parts, np.ndarray):
        # Labels may be 1-d or a single column of parts.
        assert label_parts.ndim == 1 or label_parts.shape[1] == 1
        label_parts = label_parts.flatten().tolist()

    # Arrange parts into pairs.  This enforces co-locality
    parts = list(map(delayed, zip(data_parts, label_parts)))
    parts = client.compute(parts)  # Start computation in the background
    yield _wait(parts)

    # Because XGBoost-python doesn't yet allow iterative training, we need to
    # find the locations of all chunks and map them to particular Dask workers
    keys = [part.key for part in parts]
    key_to_part = {part.key: part for part in parts}
    who_has = yield client.scheduler.who_has(keys=keys)
    worker_map = defaultdict(list)
    for key, workers in who_has.items():
        # A part may be reported on several workers; assign it to the first.
        worker_map[first(workers)].append(key_to_part[key])

    ncores = yield client.scheduler.ncores()  # Number of cores per worker

    # Start the XGBoost tracker on the Dask scheduler
    host, _ = parse_host_port(client.scheduler.address)
    env = yield client._run_on_scheduler(start_tracker,
                                         host.strip('/:'),
                                         len(worker_map))

    # Tell each worker to train on the chunks/parts that it has locally
    futures = [client.submit(train_part, env,
                             assoc(params, 'nthread', ncores[worker]),
                             list_of_parts, workers=worker,
                             dmatrix_kwargs=dmatrix_kwargs, **kwargs)
               for worker, list_of_parts in worker_map.items()]

    # Get the results, only one will be non-None
    results = yield client._gather(futures)
    models = [v for v in results if v]
    if not models:
        # Same exception type as indexing an empty list, with a clearer
        # message about what went wrong.
        raise IndexError("no worker returned a trained model")
    raise gen.Return(models[0])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号