field.py source code

python

Project: allennlp  Author: allenai
from collections import defaultdict
from typing import Dict, Generic, List, TypeVar
import torch

# Type of one instance's as_tensor() output (assumed to match allennlp's field module).
DataArray = TypeVar("DataArray", torch.Tensor, Dict[str, torch.Tensor])


class Field(Generic[DataArray]):
    # (other Field methods omitted)

    @classmethod
    def batch_tensors(cls, tensor_list: List[DataArray]) -> DataArray:  # type: ignore
        """
        Takes the output of ``Field.as_tensor()`` from a list of ``Instances`` and merges it into
        one batched tensor for this ``Field``.  The default implementation here in the base class
        handles cases where ``as_tensor`` returns a single torch tensor per instance, or a
        dictionary of single tensors.  If your subclass returns something other than this, you need
        to override this method.
        """
        if isinstance(tensor_list[0], dict):
            # Build a dict of {token_indexer_key: batch_tensor}, one entry for each
            # token indexer used to index this field. This is mostly used by TextFields.
            token_indexer_key_to_batch_dict: Dict[str, List[torch.Tensor]] = defaultdict(list)
            for encoding_name_dict in tensor_list:
                for indexer_name, tensor in encoding_name_dict.items():
                    token_indexer_key_to_batch_dict[indexer_name].append(tensor)
            # Stack each indexer's per-instance tensors along a new batch dimension 0.
            return {indexer_name: torch.stack(tensors)
                    for indexer_name, tensors in token_indexer_key_to_batch_dict.items()}
        else:
            # One tensor per instance: stack them along a new batch dimension 0.
            return torch.stack(tensor_list)
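The default batching behaviour can be tried outside allennlp with the standalone sketch below. It reimplements the same tensor-or-dict dispatch as a module-level function. The function name `batch_tensors` at module level, the `TensorOrDict` alias, and the indexer names `"tokens"` and `"token_characters"` are assumptions chosen for illustration; they are not allennlp's API.

```python
from collections import defaultdict
from typing import Dict, List, Union

import torch

# Hypothetical alias for this sketch: one instance's output of as_tensor().
TensorOrDict = Union[torch.Tensor, Dict[str, torch.Tensor]]


def batch_tensors(tensor_list: List[TensorOrDict]) -> TensorOrDict:
    """Stack per-instance tensors (or dicts of tensors) along a new batch dim 0."""
    if isinstance(tensor_list[0], dict):
        # Regroup [{key: t1}, {key: t2}, ...] into {key: [t1, t2, ...]}.
        grouped: Dict[str, List[torch.Tensor]] = defaultdict(list)
        for instance_dict in tensor_list:
            for key, tensor in instance_dict.items():
                grouped[key].append(tensor)
        # Stack each key's tensors; all tensors under one key must share a shape.
        return {key: torch.stack(tensors) for key, tensors in grouped.items()}
    # Single tensor per instance: stack directly.
    return torch.stack(tensor_list)


# Three instances, each a sequence of 5 token ids.
plain = batch_tensors([torch.zeros(5, dtype=torch.long) for _ in range(3)])
print(plain.shape)  # torch.Size([3, 5])

# Two instances of a text-like field indexed by two hypothetical indexers.
text = batch_tensors([
    {"tokens": torch.zeros(4, dtype=torch.long),
     "token_characters": torch.zeros(4, 7, dtype=torch.long)}
    for _ in range(2)
])
print({key: tuple(t.shape) for key, t in text.items()})
# {'tokens': (2, 4), 'token_characters': (2, 4, 7)}
```

Because `torch.stack` raises a `RuntimeError` when the tensors it receives differ in shape, every instance's output must already be padded to common lengths before this method runs. A field whose per-instance outputs cannot be stacked this way is exactly the case the docstring says must override the method.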