def proj_sentence(self, sent, eps=1e-12):
    """Project sentence features with ``self.sentproj`` and L2-normalize along dim 1.

    Args:
        sent: Input passed straight to ``self.sentproj``. Presumably a batch of
            sentence representations; TODO confirm against callers.
        eps: Lower bound applied to each norm before dividing. An all-zero
            projection then yields a zero vector instead of NaN (0 / 0).
            Norms at or above ``eps`` are unaffected.

    Returns:
        ``self.sentproj(sent)`` with every slice along dim 1 scaled to unit L2
        norm. Shape is that of the projection, noted by the original code as
        (bsize, projdim).
    """
    output = self.sentproj(sent)
    # keepdim=True leaves a size-1 dim that broadcasts against `output`,
    # so no explicit expand_as is needed. Clamping avoids division by zero.
    norm = output.norm(p=2, dim=1, keepdim=True).clamp(min=eps)
    return output / norm  # (bsize, projdim)