def __call__(self, image, pose, visibility):
    _, height, width = image.size()
    shape = (width, height)
    # Select only the coordinates of visible joints to compute the crop bounds.
    visible_pose = torch.masked_select(pose, visibility.bool()).view(-1, 2)
    p_min = visible_pose.min(0)[0].squeeze()
    p_max = visible_pose.max(0)[0].squeeze()
    p_c = (p_min + p_max)/2
    crop_shape = [0, 0, 0, 0]
    # crop on a joint center
    for i in range(2):
        if self.data_augmentation:
            # Random crop origin, constrained so every visible joint stays inside the window.
            crop_shape[2*i] = random.randint(0, int(min(p_min[i], shape[i] - self.crop_size)))
        else:
            # Deterministic crop centered on the visible-joint bounding box.
            crop_shape[2*i] = max(0, int(p_c[i] - float(self.crop_size)/2))
        crop_shape[2*i + 1] = min(shape[i], crop_shape[2*i] + self.crop_size)
        # If the window was clipped at the image border, shift the origin back
        # so the crop is exactly crop_size pixels along this axis.
        crop_shape[2*i] -= self.crop_size - (crop_shape[2*i + 1] - crop_shape[2*i])
    transformed_image = image[:, crop_shape[2]:crop_shape[3], crop_shape[0]:crop_shape[1]]
    # Translate the pose coordinates into the cropped image's frame.
    p_0 = torch.Tensor((crop_shape[0], crop_shape[2])).view(1, 2).expand_as(pose)
    transformed_pose = pose - p_0
    return transformed_image, transformed_pose, visibility
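
For context, a minimal sketch of the enclosing transform class and a usage example; the class name Crop, the constructor arguments, and the tensor shapes are assumptions inferred from the attributes used in __call__ above (self.crop_size, self.data_augmentation), not confirmed by the source.

# Hypothetical wrapper and usage, assuming the method above belongs to a Crop transform.
import random

import torch


class Crop(object):
    def __init__(self, crop_size=224, data_augmentation=True):
        self.crop_size = crop_size                   # side length of the square crop
        self.data_augmentation = data_augmentation   # random crop vs. centered crop

    # __call__ is the method shown above.


# Example: a 3x256x320 image with 14 joints, all marked visible.
crop = Crop(crop_size=224, data_augmentation=False)
image = torch.rand(3, 256, 320)
pose = torch.rand(14, 2) * torch.tensor([320.0, 256.0])
visibility = torch.ones(14, 2)
cropped_image, cropped_pose, visibility = crop(image, pose, visibility)
print(cropped_image.size())  # torch.Size([3, 224, 224])

Because the origin is shifted back whenever the window hits the image border, the output is always exactly crop_size x crop_size as long as both image dimensions are at least crop_size.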