def test_orthogonal(self):
    """Check that ``init.orthogonal`` yields a (scaled) orthogonal matrix.

    Every combination of Variable/plain-tensor input and explicit/default
    gain is exercised over several shapes.  Each initialised tensor is
    viewed as a 2-D matrix of shape (size[0], product of remaining sizes),
    and its Gram matrix along the smaller side is compared against
    ``gain ** 2`` times the identity.
    """
    shapes = [[3, 4], [4, 3], [20, 2, 3, 4], [2, 3, 4, 5]]

    for wrap_in_variable in [True, False]:
        for explicit_gain in [True, False]:
            for shape in shapes:
                tensor = torch.zeros(shape)
                if wrap_in_variable:
                    tensor = Variable(tensor)

                # The initialiser's return value is ignored; the test reads
                # the result back from the tensor that was passed in.
                if explicit_gain:
                    gain = self._random_float(0.1, 2)
                    init.orthogonal(tensor, gain=gain)
                else:
                    # No gain passed: the expected scale is 1.0.
                    gain = 1.0
                    init.orthogonal(tensor)

                if wrap_in_variable:
                    # Unwrap so the checks below run on the raw tensor.
                    tensor = tensor.data

                n_rows = shape[0]
                n_cols = reduce(mul, shape[1:])
                matrix = tensor.view(n_rows, n_cols)

                # Form the Gram matrix on the smaller dimension: columns
                # when the matrix is tall, rows otherwise.
                if n_rows > n_cols:
                    gram = torch.mm(matrix.t(), matrix)
                    side = n_cols
                else:
                    gram = torch.mm(matrix, matrix.t())
                    side = n_rows

                self.assertEqual(gram, torch.eye(side) * gain ** 2, prec=1e-6)
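# [web-page residue, not part of the test] Comment list
# [web-page residue, not part of the test] Article table of contents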