def batch_tensors(cls, tensor_list: List[DataArray]) -> DataArray:  # type: ignore
    """
    Merge the per-instance outputs of ``Field.as_tensor()`` into one batched
    result for this ``Field``.

    Two input layouts are supported, selected by inspecting the first element
    of ``tensor_list``:

    * every element is a single torch tensor -- the tensors are combined with
      ``torch.stack`` and the stacked tensor is returned;
    * every element is a dictionary of single tensors -- the tensors are
      grouped by key across all instances, and a dictionary mapping each key
      to its stacked tensor is returned.

    A subclass whose ``as_tensor`` returns anything other than these two
    layouts needs to override this method.
    """
    if not isinstance(tensor_list[0], dict):
        # Simple case: one tensor per instance, stacked directly.
        return torch.stack(tensor_list)

    # Dictionary case: gather, for each key, the tensors contributed by every
    # instance. The original comment notes this layout is mostly used by
    # TextFields, where each key names a token indexer.
    tensors_by_key: Dict[str, List[torch.Tensor]] = {}
    for instance_tensors in tensor_list:
        for key, tensor in instance_tensors.items():
            tensors_by_key.setdefault(key, []).append(tensor)

    # Stack each key's tensors into a single batch tensor, preserving the
    # order in which the keys were first encountered.
    batched: Dict[str, torch.Tensor] = {}
    for key, tensors in tensors_by_key.items():
        batched[key] = torch.stack(tensors)
    return batched
评论列表
文章目录