transforms.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:DeepPoseComparison 作者: ynaka81 项目源码 文件源码
def __call__(self, image, pose, visibility):
        """Crop ``image`` to a window around the visible joints and shift ``pose``.

        The window is ``self.crop_size`` pixels wide and high wherever the image
        is large enough; along an axis where the image is smaller than
        ``self.crop_size`` the whole axis is kept.  With
        ``self.data_augmentation`` set, the window origin is drawn at random
        between 0 and the smallest visible joint coordinate (limited so the
        window stays inside the image); otherwise the window is centred on the
        midpoint of the visible joints' bounding box and pushed back inside
        the image when it would cross a border.

        Args:
            image: 3-D tensor whose last two dimensions are (height, width);
                the leading dimension (presumably channels) is kept whole.
            pose: tensor of joint coordinates whose last dimension is
                ``(x, y)``.
            visibility: tensor matching ``pose``; entries that are non-zero
                after ``.byte()`` conversion mark visible coordinates.

        Returns:
            Tuple ``(cropped_image, shifted_pose, visibility)`` where
            ``shifted_pose`` is ``pose`` expressed relative to the top-left
            corner of the crop and ``visibility`` is returned unchanged.

        NOTE(review): at least one joint must be visible -- the min/max
        reduction over an empty selection is expected to fail; confirm that
        callers guarantee this.
        """
        _, height, width = image.size()
        shape = (width, height)
        # Comparing with zero yields a BoolTensor on current PyTorch (which no
        # longer accepts uint8 masks) and a ByteTensor on old releases, with
        # the same truthiness as the original ``visibility.byte()`` mask.
        mask = visibility.byte() != 0
        visible_pose = torch.masked_select(pose, mask).view(-1, 2)
        p_min = visible_pose.min(0)[0].squeeze()
        p_max = visible_pose.max(0)[0].squeeze()
        p_c = (p_min + p_max)/2
        crop_shape = [0, 0, 0, 0]
        # crop on a joint center; axis 0 is x (width), axis 1 is y (height)
        for i in range(2):
            # Largest origin that still leaves room for a full-size window;
            # 0 when the image is smaller than the crop along this axis.
            max_start = max(0, shape[i] - self.crop_size)
            if self.data_augmentation:
                # Clamp the upper bound at 0 so randint never receives an
                # empty range (small image or negative joint coordinate).
                upper = max(0, int(min(p_min[i], max_start)))
                start = random.randint(0, upper)
            else:
                centered = max(0, int(p_c[i] - float(self.crop_size)/2))
                # Slide the window back inside the image instead of letting
                # the origin go negative.
                start = min(centered, max_start)
            crop_shape[2*i] = start
            crop_shape[2*i + 1] = min(shape[i], start + self.crop_size)
        transformed_image = image[:, crop_shape[2]:crop_shape[3], crop_shape[0]:crop_shape[1]]
        # Offset of the crop's top-left corner, broadcast over every joint.
        p_0 = torch.Tensor((crop_shape[0], crop_shape[2])).view(1, 2).expand_as(pose)
        transformed_pose = pose - p_0
        return transformed_image, transformed_pose, visibility
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号