def test_eye(self):
    """Verify that ``init.eye`` writes an identity pattern into a 2-D tensor.

    The check runs twice: once with ``as_variable=True`` and once with
    ``as_variable=False`` passed to the tensor-creation helper. After
    ``init.eye`` is applied, every element whose row index equals its column
    index must equal 1, and every other element must equal 0.
    """
    for as_variable in (True, False):
        # NOTE(review): presumably returns a random 2-D tensor with each
        # dimension between 1 and 5, wrapped as a Variable when as_variable
        # is True -- confirm against _create_random_nd_tensor.
        tensor = self._create_random_nd_tensor(
            2, size_min=1, size_max=5, as_variable=as_variable)
        init.eye(tensor)
        # When a Variable was requested, compare against its ``.data``.
        tensor = tensor.data if as_variable else tensor

        n_rows = tensor.size(0)
        n_cols = tensor.size(1)
        # Exhaustive element-wise check; the two loops walk each dimension
        # independently, so no square shape is assumed.
        for row in range(n_rows):
            for col in range(n_cols):
                assert tensor[row][col] == (1 if row == col else 0)
评论列表
文章目录