def forward(self, ft, scaling, seg_split):
    """Pool per-segment features into a three-stage temporal pyramid feature.

    ``ft`` is regrouped to ``(n_sample, n_seg, ft_dim)`` and cut along the
    segment axis into three stages: ``[:x1]`` (starting), ``[x1:x2]``
    (course) and ``[x2:]`` (ending). For every ``n_part`` listed in the
    matching ``self.parts`` entry, the stage is divided into ``n_part``
    consecutive parts; each part is averaged over its segments and divided
    by the matching ``self.norm_num`` entry. Parts of the starting and
    ending stages are additionally multiplied by a per-sample scaling
    factor. All part features are concatenated along dimension 1.

    Args:
        ft: Tensor whose dimension 1 is the feature size; it is viewed as
            ``(-1, n_seg, ft_dim)``. (Presumably one row per segment --
            confirm against the caller.)
        scaling: Tensor viewable as ``(-1, 2)``. Column 0 scales the
            starting stage, column 1 the ending stage; the course stage is
            not scaled.
        seg_split: Indexable holding ``(x1, x2, n_seg)``: the two stage
            boundaries and the number of segments per sample.

    Returns:
        ``(course_ft, stpp_ft)`` when ``self.sc`` is truthy, where
        ``course_ft`` is the plain mean of the course-stage segments (not
        divided by ``norm_num``); otherwise ``(stpp_ft, stpp_ft)``.
    """
    x1 = seg_split[0]
    x2 = seg_split[1]
    n_seg = seg_split[2]
    ft_dim = ft.size()[1]

    src = ft.view(-1, n_seg, ft_dim)
    scaling = scaling.view(-1, 2)
    n_sample = src.size()[0]

    def get_stage_stpp(stage_ft, stage_parts, norm_num, stage_scaling):
        """Return the list of pooled part features for a single stage."""
        if stage_scaling is not None:
            # Column vector so it broadcasts over the feature dimension.
            # ``reshape`` replaces the deprecated non-in-place
            # ``Tensor.resize`` and is done once instead of once per part.
            stage_scaling = stage_scaling.reshape(n_sample, 1)

        stage_stpp = []
        stage_len = stage_ft.size(1)
        for n_part in stage_parts:
            # Part boundaries along the segment axis. The small epsilon on
            # the (exclusive) end lets ``stage_len`` itself appear as the
            # last tick.
            ticks = torch.arange(0, stage_len + 1e-5, stage_len / n_part)
            for i in range(n_part):
                start, stop = int(ticks[i]), int(ticks[i + 1])
                part_ft = stage_ft[:, start:stop, :].mean(dim=1) / norm_num
                if stage_scaling is not None:
                    part_ft = part_ft * stage_scaling
                stage_stpp.append(part_ft)
        return stage_stpp

    feature_parts = []
    feature_parts.extend(
        get_stage_stpp(src[:, :x1, :], self.parts[0], self.norm_num[0], scaling[:, 0])
    )  # starting
    feature_parts.extend(
        get_stage_stpp(src[:, x1:x2, :], self.parts[1], self.norm_num[1], None)
    )  # course
    feature_parts.extend(
        get_stage_stpp(src[:, x2:, :], self.parts[2], self.norm_num[2], scaling[:, 1])
    )  # ending
    stpp_ft = torch.cat(feature_parts, dim=1)

    if not self.sc:
        return stpp_ft, stpp_ft

    course_ft = src[:, x1:x2, :].mean(dim=1)
    return course_ft, stpp_ft
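# Residue from the web page this snippet was copied from (not code):
# "Comment list" / "Article table of contents".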