def _new_idx(self, input): if torch.typename(input) == 'torch.cuda.FloatTensor': return torch.cuda.ByteTensor() else: return torch.ByteTensor()