def checkOneHot(self):
v = torch.LongTensor([1, 2, 1, 2, 0])
v_length = torch.LongTensor([2, 3])
v_onehot = utils.oneHot(v, v_length, 4)
target = torch.FloatTensor([[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
[[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]])
assert target.equal(v_onehot)
评论列表
文章目录