def test_batch_left_interp_on_a_vector():
    """Compare batched ``left_interp`` on a 1-D vector against a dense matmul reference.

    A random length-6 vector is interpolated with the batched interpolation
    indices/values. The result must approximately match multiplying the dense
    ``batch_interp_matrix`` by the same vector.

    NOTE(review): ``batch_interp_matrix``, ``batch_interp_indices`` and
    ``batch_interp_values`` are defined elsewhere in this file. The matrix
    presumably has shape (batch, n, 6) so it can multiply a length-6 vector.
    Confirm against their definitions.
    """
    rhs = torch.randn(6)

    # (6,) -> (6, 1) -> (1, 6, 1): a column vector with a leading size-1 batch
    # dim, so torch.matmul can broadcast it against the batched matrix.
    rhs_column = rhs.unsqueeze(-1).unsqueeze(0)
    # squeeze(0) only drops the leading dim when its size is 1.
    expected = torch.matmul(batch_interp_matrix, rhs_column).squeeze(0)

    # Legacy autograd API: wrap the input in Variable and read back the raw
    # tensor via .data before comparing.
    result = left_interp(batch_interp_indices, batch_interp_values, Variable(rhs)).data
    assert approx_equal(result, expected)
评论列表
文章目录