distributed_mnist.py source code

python

Project: tensorflow-basic  Author: weaponsjtu
import re

import tensorflow as tf

# FLAGS (with ps_hosts and worker_hosts) is expected at module level;
# its definition is not shown in this excerpt.

def get_device_setter(num_parameter_servers, num_workers):
    """Get a device setter given the number of servers in the cluster.

    Given the numbers of parameter servers and workers, construct a device
    setter object using ClusterSpec.

    Args:
        num_parameter_servers: Number of parameter servers.
        num_workers: Number of workers.

    Returns:
        Device setter object.
    """

    # Split the host lists into individual addresses. Any character other
    # than a word character, '.', or ':' acts as a separator.
    ps_hosts = re.findall(r'[\w\.:]+', FLAGS.ps_hosts)
    worker_hosts = re.findall(r'[\w\.:]+', FLAGS.worker_hosts)

    # The flag values must agree with the requested cluster size.
    assert num_parameter_servers == len(ps_hosts)
    assert num_workers == len(worker_hosts)

    cluster_spec = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Get the device setter from the cluster spec.
    return tf.train.replica_device_setter(cluster=cluster_spec)
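
The host-splitting step can be tried on its own, without TensorFlow. The following is a minimal sketch that reuses the same regular expression; the helper name `split_hosts` and the sample addresses are hypothetical and do not come from the original file.

```python
import re


def split_hosts(hosts):
    """Split a host-list string into individual host:port addresses.

    Hypothetical helper for illustration; it applies the same pattern
    that get_device_setter uses on FLAGS.ps_hosts and FLAGS.worker_hosts.
    """
    return re.findall(r'[\w\.:]+', hosts)


# Sample flag values (assumed, not taken from the original project).
ps_hosts = split_hosts("localhost:2222,localhost:2223")
worker_hosts = split_hosts("10.0.0.1:2224, 10.0.0.2:2224")

# This dict has the shape that tf.train.ClusterSpec accepts.
cluster = {"ps": ps_hosts, "worker": worker_hosts}

# Caveat: '-' is not in the character class, so a hyphenated
# hostname is split into two pieces.
hyphenated = split_hosts("ps-0:2222")
```

Because the pattern treats every character outside `[\w\.:]` as a separator, hostnames containing a hyphen would probably need a wider character class, or a plain `str.split(",")`, before being passed to the cluster spec.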