def th_repeat(a, repeats, axis=0):
    """Torch version of ``np.repeat`` for 1-D tensors.

    Each element of ``a`` is repeated ``repeats`` times consecutively,
    e.g. ``[1, 2]`` with ``repeats=2`` -> ``[1, 1, 2, 2]``.

    Args:
        a: 1-D tensor to repeat.
        repeats: Number of times each element is repeated (int).
        axis: Accepted for signature compatibility with ``np.repeat`` but
            ignored; a 1-D tensor only has axis 0.

    Returns:
        1-D tensor of length ``a.size(0) * repeats``.

    Raises:
        ValueError: if ``a`` is not 1-D.
    """
    if a.dim() != 1:
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and a multi-dimensional input would then be tiled
        # into a silently wrong result rather than rejected.
        raise ValueError(
            "th_repeat expects a 1-D tensor, got shape %s" % (tuple(a.size()),)
        )
    # (n,) -> (repeats, n) -> (n, repeats): row i holds element i repeated,
    # so reading the result row-major gives consecutive repeats.
    tiled = a.repeat(repeats, 1).t()
    # The transpose is a non-contiguous view; make it contiguous before view().
    return tiled.contiguous().view(-1)