def is_floating(t):
    """Return True if *t* is a torch tensor with a floating-point dtype.

    The check is based on the tensor's dtype rather than its Python type. In
    modern PyTorch, every tensor's ``type()`` is ``torch.Tensor``, so comparing
    against legacy classes like ``torch.FloatTensor`` never matches. The dtype
    check also covers tensor subclasses (e.g. ``nn.Parameter``), CUDA tensors
    and all floating dtypes, including float16 and bfloat16.

    Args:
        t: Any object. Non-tensor inputs (numbers, lists, arrays) are accepted.

    Returns:
        True when ``t`` is a tensor whose dtype is floating point, else False.
    """
    # Guard first: torch.is_floating_point raises TypeError on non-tensors,
    # whereas this helper has always answered False for them.
    return torch.is_tensor(t) and torch.is_floating_point(t)