def proj_image(self, img, eps=1e-12):
    """Project an image batch with ``self.imgproj`` and L2-normalize it.

    Each slice along dim 1 of the projection is scaled to unit Euclidean
    length. Unlike a bare ``x / sqrt(sum(x ** 2))``, the norm is clamped
    from below by ``eps``. As a result, an all-zero projection yields zeros
    instead of NaN, and gradients stay finite. This is the same formula
    ``torch.nn.functional.normalize`` uses.

    Args:
        img: Input passed unchanged to ``self.imgproj``.
        eps: Lower bound applied to the norm before dividing (default
            1e-12, same as ``torch.nn.functional.normalize``).

    Returns:
        Tensor with the same shape as ``self.imgproj(img)``, normalized
        along dim 1. Presumably ``(bsize, projdim)``; TODO confirm against
        ``imgproj``.
    """
    output = self.imgproj(img)
    # keepdim=True keeps the reduced dim as size 1, so the norm broadcasts
    # against ``output`` directly; no explicit expand_as is needed.
    # clamp_min guards the zero-vector case, which previously produced NaN.
    norm = output.norm(p=2, dim=1, keepdim=True).clamp_min(eps)
    return output / norm