ssn_ops.py source code


Project: action-detection    Author: yjxiong
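```python
import torch


def forward(self, ft, scaling, seg_split):
    # ft: per-segment features of all proposals, shape (n_sample * n_seg, ft_dim)
    # scaling: two scale factors per proposal (starting, ending)
    # seg_split: (x1, x2, n_seg); segments [:x1] = starting, [x1:x2] = course, [x2:] = ending
    # self.parts, self.norm_num and self.sc are set in the enclosing module's
    # __init__, which this excerpt does not show.
    x1, x2, n_seg = seg_split[0], seg_split[1], seg_split[2]
    ft_dim = ft.size(1)

    src = ft.view(-1, n_seg, ft_dim)  # (n_sample, n_seg, ft_dim)
    scaling = scaling.view(-1, 2)     # (n_sample, 2)
    n_sample = src.size(0)

    def get_stage_stpp(stage_ft, stage_parts, norm_num, stage_scaling):
        # Average-pool one stage at every pyramid level listed in stage_parts.
        stage_stpp = []
        stage_len = stage_ft.size(1)
        for n_part in stage_parts:
            # n_part + 1 evenly spaced cut points; the 1e-5 makes arange
            # include the end point stage_len.
            ticks = torch.arange(0, stage_len + 1e-5, stage_len / n_part)
            for i in range(n_part):
                part_ft = stage_ft[:, int(ticks[i]):int(ticks[i + 1]), :].mean(dim=1) / norm_num
                if stage_scaling is not None:
                    # One scale factor per proposal, broadcast over the feature dim.
                    # (Replaces the deprecated non-inplace Tensor.resize.)
                    part_ft = part_ft * stage_scaling.unsqueeze(1)
                stage_stpp.append(part_ft)
        return stage_stpp

    feature_parts = []
    feature_parts.extend(get_stage_stpp(src[:, :x1, :], self.parts[0], self.norm_num[0], scaling[:, 0]))  # starting
    feature_parts.extend(get_stage_stpp(src[:, x1:x2, :], self.parts[1], self.norm_num[1], None))  # course
    feature_parts.extend(get_stage_stpp(src[:, x2:, :], self.parts[2], self.norm_num[2], scaling[:, 1]))  # ending
    stpp_ft = torch.cat(feature_parts, dim=1)  # (n_sample, ft_dim * total number of parts)
    if not self.sc:
        return stpp_ft, stpp_ft
    else:
        # With self.sc set, the first output is the plain average of the
        # course stage rather than the STPP feature.
        course_ft = src[:, x1:x2, :].mean(dim=1)
        return course_ft, stpp_ft
```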
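To make the pooling arithmetic concrete, here is a dependency-free sketch of the same STPP computation for a single proposal, using plain Python lists instead of tensors. It is an illustration under stated assumptions, not the repository's code: the helper names `part_boundaries`, `mean_pool`, `stage_stpp` and `stpp_single` are hypothetical, the pyramid configuration `((1,), (1, 2), (1,))` and `norm_num = (1, 1, 1)` are assumed values, and the cut points use integer floor division instead of `torch.arange`. Floor division should agree with `int(ticks[i])` except in rare float-rounding cases.

```python
from typing import List, Optional, Sequence, Tuple

Vector = List[float]


def part_boundaries(stage_len: int, n_part: int) -> List[int]:
    # Cut points splitting a stage into n_part roughly equal parts.
    # Intended to mirror int(ticks[i]) from torch.arange in forward(),
    # but computed with integer floor division to avoid float rounding.
    return [(i * stage_len) // n_part for i in range(n_part + 1)]


def mean_pool(rows: Sequence[Vector]) -> Vector:
    # Element-wise average of equal-length feature vectors.
    if not rows:
        raise ValueError("empty part: the stage has fewer segments than parts")
    return [sum(col) / len(rows) for col in zip(*rows)]


def stage_stpp(stage: Sequence[Vector], parts: Sequence[int],
               norm_num: float, scale: Optional[float]) -> Vector:
    # Pool one stage at every pyramid level and concatenate the results.
    out: Vector = []
    s = 1.0 if scale is None else scale
    for n_part in parts:
        cuts = part_boundaries(len(stage), n_part)
        for i in range(n_part):
            pooled = mean_pool(stage[cuts[i]:cuts[i + 1]])
            out.extend((v / norm_num) * s for v in pooled)
    return out


def stpp_single(segments: Sequence[Vector], x1: int, x2: int,
                parts: Tuple[Sequence[int], ...] = ((1,), (1, 2), (1,)),  # assumed config
                norm_num: Tuple[float, float, float] = (1.0, 1.0, 1.0),   # assumed
                scaling: Tuple[float, float] = (1.0, 1.0)) -> Vector:
    # STPP for ONE proposal: starting = [:x1], course = [x1:x2], ending = [x2:].
    # As in forward(), only the starting and ending stages are scaled.
    return (stage_stpp(segments[:x1], parts[0], norm_num[0], scaling[0])
            + stage_stpp(segments[x1:x2], parts[1], norm_num[1], None)
            + stage_stpp(segments[x2:], parts[2], norm_num[2], scaling[1]))


# Usage: 9 segments with 2-dim features; the course stage is segments 2..6.
segs = [[float(k), 1.0] for k in range(9)]
feat = stpp_single(segs, x1=2, x2=7, scaling=(0.5, 2.0))
print(feat)  # [0.25, 0.5, 4.0, 1.0, 2.5, 1.0, 5.0, 1.0, 15.0, 2.0]
```

With a five-segment course stage, the two-part level splits it 2 + 3 (cut points `[0, 2, 5]`), so the concatenated feature has `ft_dim` times the total number of parts entries (here 2 × 5 = 10), which corresponds to the width of `stpp_ft` in `forward`.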