def get_device_setter(num_parameter_servers, num_workers):
    """Build a replica device setter for the cluster named in FLAGS.

    The host lists in ``FLAGS.ps_hosts`` and ``FLAGS.worker_hosts`` are split
    into individual addresses, checked against the expected counts, and used
    to construct a ``tf.train.ClusterSpec`` with a "ps" job and a "worker"
    job. The device setter is then derived from that cluster spec.

    Args:
        num_parameter_servers: Expected number of addresses found in
            ``FLAGS.ps_hosts``.
        num_workers: Expected number of addresses found in
            ``FLAGS.worker_hosts``.

    Returns:
        The object returned by ``tf.train.replica_device_setter`` for the
        constructed cluster spec.

    Raises:
        ValueError: If the number of parsed addresses does not match
            ``num_parameter_servers`` or ``num_workers``.
    """
    # An address is a run of word characters, dots, colons and hyphens; any
    # other character (comma, whitespace, ...) acts as a separator. Hyphens
    # are part of the class so hostnames such as "ps-0.cluster:2222" are kept
    # whole instead of being split into two bogus entries.
    address_pattern = r'[\w.:\-]+'
    ps_hosts = re.findall(address_pattern, FLAGS.ps_hosts)
    worker_hosts = re.findall(address_pattern, FLAGS.worker_hosts)

    # Explicit raises instead of `assert`: assertions are removed under
    # `python -O`, which would let a mismatched cluster through unnoticed.
    if num_parameter_servers != len(ps_hosts):
        raise ValueError(
            "Expected %d parameter server(s) but FLAGS.ps_hosts lists %d: %r"
            % (num_parameter_servers, len(ps_hosts), ps_hosts)
        )
    if num_workers != len(worker_hosts):
        raise ValueError(
            "Expected %d worker(s) but FLAGS.worker_hosts lists %d: %r"
            % (num_workers, len(worker_hosts), worker_hosts)
        )

    cluster_spec = tf.train.ClusterSpec(
        {"ps": ps_hosts, "worker": worker_hosts}
    )
    # Get device setter from the cluster spec.
    return tf.train.replica_device_setter(cluster=cluster_spec)
评论列表
文章目录