def _entries_for_driver_in_shard(self, driver_id, redis_shard_index):
    """Collect IDs of control-state entries for a driver from a shard.

    Args:
        driver_id: The ID of the driver.
        redis_shard_index: The index of the Redis shard to query.

    Returns:
        Lists of IDs: (returned_object_ids, task_ids, put_objects). The
            first two are relevant to the driver and are safe to delete.
            The last contains all "put" objects in this redis shard; each
            element is an (object_id, corresponding task_id) pair.

    Raises:
        KeyError: If a scanned hash lacks an expected field
            (b"TaskSpec", b"is_put" or b"task").
    """
    # TODO(zongheng): consider adding save & restore functionalities.
    redis = self.state.redis_clients[redis_shard_index]
    task_table_infos = {}  # task id -> TaskInfo messages

    # Scan the task table & filter to get the list of tasks belonging to
    # this driver. Use a cursor in order not to block the redis shards.
    for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
        entry = redis.hgetall(key)
        task_info = TaskInfo.GetRootAsTaskInfo(entry[b"TaskSpec"], 0)
        if driver_id != task_info.DriverId():
            # Ignore tasks that aren't from this driver.
            continue
        task_table_infos[task_info.TaskId()] = task_info

    # Get the list of objects returned by these tasks. Note these might
    # not belong to this redis shard.
    returned_object_ids = [
        task_info.Returns(i)
        for task_info in task_table_infos.values()
        for i in range(task_info.ReturnsLength())
    ]

    # Also record all the ray.put()'d objects.
    prefix_length = len(OBJECT_INFO_PREFIX)
    put_objects = []
    for key in redis.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
        entry = redis.hgetall(key)
        # Hash values are read back as bytes here, so the flag must be
        # compared against a bytes literal; comparing with the str "0" is
        # always False on Python 3 and would never skip non-put objects.
        if entry[b"is_put"] == b"0":
            continue
        # Every matched key starts with the prefix, so slice it off rather
        # than split on it: a binary object ID may itself contain the
        # prefix bytes, and split() would then truncate the ID.
        object_id = key[prefix_length:]
        task_id = entry[b"task"]
        put_objects.append((object_id, task_id))

    # Return a real list of task IDs (not a dict view) as documented.
    return returned_object_ids, list(task_table_infos), put_objects
评论列表
文章目录