def get_rpc_port_by_rank(self, rank, num_servers, timeout=None):
    """Return the RPC port announced by the server with MPI rank *rank*.

    On the first call, polls ``self._tmpfile`` (every 0.1 s) until it contains
    exactly ``num_servers`` server-info lines of the form
    ``<LINE_TOKEN>:<rank>:<n>:<port>:<LINE_TOKEN>`` (the middle numeric field
    is not used here). The lines are sorted by rank, and the resulting port
    list is cached in ``self._rpc_ports``. The temp file is then closed.
    Later calls are answered from the cache.

    Args:
        rank: Index into the rank-ordered port list.
        num_servers: Number of server-info lines to wait for.
        timeout: Optional maximum number of seconds to wait for all
            ``num_servers`` lines to appear. ``None`` (the default) waits
            indefinitely.

    Returns:
        The port (``int``) for *rank*.

    Raises:
        RuntimeError: If ``self.mpirun_proc`` is ``None``, or if more than
            ``num_servers`` server-info lines are found.
        TimeoutError: If *timeout* elapses before ``num_servers`` lines appear.
        IndexError: If *rank* is out of range for the port list.
    """
    if self.mpirun_proc is None:
        raise RuntimeError("Launch mpirun_proc before reading of rpc ports")
    if self._rpc_ports is not None:
        return self._rpc_ports[rank]
    # Raw string: "\d" in a plain literal is an invalid escape sequence.
    # NOTE(review): LINE_TOKEN is interpolated unescaped, so it acts as a
    # regex fragment -- confirm it contains no regex metacharacters.
    server_info_pattern = re.compile(
        "^" + LINE_TOKEN + r":(\d+):(\d+):(\d+):" + LINE_TOKEN + "$"
    )
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        fcntl.lockf(self._tmpfile, fcntl.LOCK_SH)
        try:
            self._tmpfile.seek(0)
            # Match each line once and keep only real matches. Unrelated or
            # partially written lines are skipped instead of later failing
            # with ``None.group``.
            matches = [
                m for m in map(server_info_pattern.match, self._tmpfile) if m
            ]
        finally:
            # Always release the shared lock, even if reading raises.
            fcntl.lockf(self._tmpfile, fcntl.LOCK_UN)
        if len(matches) == num_servers:
            break
        if len(matches) > num_servers:
            # Waiting longer can never bring the count back down.
            raise RuntimeError(
                "Expected %d server info lines but found %d"
                % (num_servers, len(matches))
            )
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                "Timed out waiting for %d server info lines (found %d)"
                % (num_servers, len(matches))
            )
        # NOTE(review): if mpirun exits early, this keeps polling until
        # *timeout*; consider also checking whether self.mpirun_proc is alive.
        time.sleep(0.1)
    # Parse the same locked snapshot that was counted, so the result cannot
    # diverge from the check above. Stable sort on rank only.
    server_infos = sorted(
        ((int(m.group(1)), int(m.group(3))) for m in matches),
        key=lambda info: info[0],
    )
    self._rpc_ports = [port for _, port in server_infos]
    logger.debug("get_rpc_ports: ports (in MPI rank order): %s", self._rpc_ports)
    self._tmpfile.close()
    return self._rpc_ports[rank]
评论列表
文章目录