concatenate.py source code

python

Project: inferno  Author: inferno-pytorch
import numpy as np


def map_index(self, index):
    # Method of a class that holds its datasets in self.datasets.
    # The caller must ensure 0 <= index < total length: for an index past the
    # end, np.argmax over an all-False array returns 0.
    # Get a list of lengths of all datasets. Say the answer is [4, 3, 3],
    # and we're looking for index = 5.
    len_list = list(map(len, self.datasets))
    # Accumulate into a numpy array. The answer is [4, 7, 10].
    cumulative_len_list = np.cumsum(len_list)
    # When the index is subtracted, we get [-1, 2, 5]. We're looking for the
    # position of the first cumulative length that is larger than the index
    # (in this case 7, at position 1).
    offset_cumulative_len_list = cumulative_len_list - index
    dataset_index = np.argmax(offset_cumulative_len_list > 0)
    # With the dataset index, we figure out the index within that dataset.
    if dataset_index == 0:
        # First dataset - index corresponds to index_in_dataset
        index_in_dataset = index
    else:
        # Get the cumulative length of all datasets before the current one
        len_up_to_dataset = cumulative_len_list[dataset_index - 1]
        # index_in_dataset is what remains after subtracting that length
        index_in_dataset = index - len_up_to_dataset
    return dataset_index, index_in_dataset
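
The mapping above can be tried outside the class with a small standalone sketch. This is not the project's code: it swaps numpy for the standard library (`itertools.accumulate` and `bisect.bisect_right`), the names `locate` and `lengths` are hypothetical, and the explicit bounds check is an added assumption about how out-of-range indices should be handled.

```python
from bisect import bisect_right
from itertools import accumulate


def locate(lengths, index):
    """Map a global index to (dataset_index, index_in_dataset).

    `locate` and `lengths` are hypothetical names used for illustration only.
    """
    cumulative = list(accumulate(lengths))  # e.g. [4, 3, 3] -> [4, 7, 10]
    total = cumulative[-1] if cumulative else 0
    # Assumption: reject out-of-range indices instead of returning dataset 0.
    if not 0 <= index < total:
        raise IndexError(f"index {index} out of range for total length {total}")
    # Position of the first cumulative length strictly greater than index.
    dataset_index = bisect_right(cumulative, index)
    # Number of items in all datasets before the selected one.
    start = cumulative[dataset_index - 1] if dataset_index > 0 else 0
    return dataset_index, index - start


print(locate([4, 3, 3], 5))  # (1, 1)
```

With lengths [4, 3, 3], global index 5 falls in the second dataset (position 1) at local index 1, matching the worked example in the comments of `map_index`.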