def forward(self, add_matrix, vector1, vector2):
    """Run the forward computation by delegating to ``torch.addr``.

    Both vectors are recorded with ``save_for_backward`` first.  The
    destination tensor is obtained from ``self._get_output(add_matrix)``
    and handed to ``torch.addr`` as its leading argument, followed by
    ``self.alpha``, ``add_matrix``, ``self.beta`` and the two vectors.

    NOTE(review): this uses the legacy positional ``torch.addr`` overload
    (result tensor first, scalars inline).  Presumably ``self.alpha``
    scales ``add_matrix`` and ``self.beta`` scales the outer product of
    the vectors -- confirm against the torch version in use.

    Returns:
        Whatever ``torch.addr`` returns for these arguments.
    """
    self.save_for_backward(vector1, vector2)
    destination = self._get_output(add_matrix)
    blas_args = (self.alpha, add_matrix, self.beta, vector1, vector2)
    return torch.addr(destination, *blas_args)
python类addr()的实例源码
def forward(ctx, add_matrix, vector1, vector2, alpha=1, beta=1, inplace=False):
    """Run the forward computation by delegating to ``torch.addr``.

    The two scalars are stored on ``ctx`` and both vectors are recorded
    with ``ctx.save_for_backward``.  The destination tensor comes from
    ``_get_output(ctx, add_matrix, inplace=inplace)`` and is passed to
    ``torch.addr`` through ``out=``.

    NOTE(review): the scalars are passed positionally as
    ``(alpha, add_matrix, beta, vector1, vector2)``; presumably ``alpha``
    scales ``add_matrix`` and ``beta`` scales the outer product of the
    vectors -- confirm against the torch version in use.

    Returns:
        Whatever ``torch.addr`` returns for these arguments.
    """
    ctx.alpha, ctx.beta = alpha, beta
    ctx.save_for_backward(vector1, vector2)
    destination = _get_output(ctx, add_matrix, inplace=inplace)
    return torch.addr(alpha, add_matrix, beta, vector1, vector2, out=destination)
def test_functional_blas(self):
    """Check that torch BLAS functions agree on Variables and raw tensors.

    Every function is called once with its arguments as given and once with
    each ``Variable`` argument replaced by its ``.data``; the ``.data`` of the
    first result must compare equal to the second result.
    """

    def compare(fn, *args):
        # Non-Variable arguments (the scalar coefficients) pass through as-is.
        raw_args = tuple(a.data if isinstance(a, Variable) else a for a in args)
        self.assertEqual(fn(*args).data, fn(*raw_args))

    def test_blas(fn, x, y):
        compare(fn, x, y)

    def test_blas_add(fn, x, y, z):
        # Checks all signatures: no scalar, one leading scalar, and a scalar
        # in front of each of the first two tensor operands.
        for call_args in ((x, y, z), (0.5, x, y, z), (0.5, x, 0.25, y, z)):
            compare(fn, *call_args)

    # (checker, torch function, shapes of the random operands), in run order.
    cases = (
        (test_blas, torch.mm, ((2, 10), (10, 4))),
        (test_blas_add, torch.addmm, ((2, 4), (2, 10), (10, 4))),
        (test_blas, torch.bmm, ((4, 2, 10), (4, 10, 4))),
        (test_blas_add, torch.addbmm, ((2, 4), (4, 2, 10), (4, 10, 4))),
        (test_blas_add, torch.baddbmm, ((4, 2, 4), (4, 2, 10), (4, 10, 4))),
        (test_blas, torch.mv, ((2, 10), (10,))),
        (test_blas_add, torch.addmv, ((2,), (2, 10), (10,))),
        (test_blas, torch.ger, ((5,), (6,))),
        (test_blas_add, torch.addr, ((5, 6), (5,), (6,))),
    )
    for check, fn, shapes in cases:
        # Operands are drawn right before each check, left to right.
        operands = [Variable(torch.randn(*shape)) for shape in shapes]
        check(fn, *operands)
def forward(ctx, add_matrix, vector1, vector2, alpha=1, beta=1, inplace=False):
    """Forward pass: stash state on ``ctx`` and call ``torch.addr``.

    Steps, in order:
      1. ``alpha`` and ``beta`` are stored as attributes of ``ctx``.
      2. ``vector1`` and ``vector2`` are recorded via
         ``ctx.save_for_backward``.
      3. The result tensor is requested from ``_get_output`` (which also
         receives the ``inplace`` flag) and supplied to ``torch.addr`` as
         ``out``.

    NOTE(review): ``torch.addr`` is called with the scalars inline
    (``alpha, add_matrix, beta, vector1, vector2``); which scalar scales
    which term depends on that legacy overload -- TODO confirm.

    Returns:
        The value returned by ``torch.addr``.
    """
    ctx.alpha = alpha
    ctx.beta = beta
    ctx.save_for_backward(vector1, vector2)
    result_buffer = _get_output(ctx, add_matrix, inplace=inplace)
    positional = (alpha, add_matrix, beta, vector1, vector2)
    return torch.addr(*positional, out=result_buffer)
def test_functional_blas(self):
    """Verify BLAS-style torch functions give equal results on Variables.

    For each function the result computed from ``Variable`` inputs (taking
    its ``.data``) is compared, via ``self.assertEqual``, with the result
    computed from the inputs' ``.data`` directly.
    """

    def var(*shape):
        # Fresh random operand of the requested shape, wrapped in a Variable.
        return Variable(torch.randn(*shape))

    def compare(fn, *args):
        # Only Variables are unwrapped; plain scalars are forwarded unchanged.
        plain_args = tuple(
            item.data if isinstance(item, Variable) else item
            for item in args
        )
        self.assertEqual(fn(*args).data, fn(*plain_args))

    def test_blas(fn, x, y):
        compare(fn, x, y)

    def test_blas_add(fn, x, y, z):
        # Checks all signatures
        compare(fn, x, y, z)
        compare(fn, 0.5, x, y, z)
        compare(fn, 0.5, x, 0.25, y, z)

    test_blas(torch.mm, var(2, 10), var(10, 4))
    test_blas_add(torch.addmm, var(2, 4), var(2, 10), var(10, 4))
    test_blas(torch.bmm, var(4, 2, 10), var(4, 10, 4))
    test_blas_add(torch.addbmm, var(2, 4), var(4, 2, 10), var(4, 10, 4))
    test_blas_add(torch.baddbmm, var(4, 2, 4), var(4, 2, 10), var(4, 10, 4))
    test_blas(torch.mv, var(2, 10), var(10))
    test_blas_add(torch.addmv, var(2), var(2, 10), var(10))
    test_blas(torch.ger, var(5), var(6))
    test_blas_add(torch.addr, var(5, 6), var(5), var(6))
def forward(ctx, add_matrix, vector1, vector2, alpha=1, beta=1, inplace=False):
    """Forward pass: record state on ``ctx`` and delegate to ``torch.addr``.

    Stored on ``ctx`` before the computation:
      * ``alpha`` and ``beta`` as given,
      * ``add_matrix_size`` -- the result of ``add_matrix.size()``,
      * ``vector1`` and ``vector2`` through ``ctx.save_for_backward``.

    The destination tensor is obtained from
    ``_get_output(ctx, add_matrix, inplace=inplace)`` and passed to
    ``torch.addr`` as ``out``.

    NOTE(review): the size is presumably kept so a later backward step can
    restore the gradient to ``add_matrix``'s shape; that code is not
    visible here -- confirm.  The scalar order
    ``(alpha, add_matrix, beta, ...)`` relies on the legacy positional
    ``torch.addr`` overload.

    Returns:
        The value returned by ``torch.addr``.
    """
    ctx.alpha, ctx.beta = alpha, beta
    ctx.add_matrix_size = add_matrix.size()
    ctx.save_for_backward(vector1, vector2)
    destination = _get_output(ctx, add_matrix, inplace=inplace)
    return torch.addr(alpha, add_matrix, beta, vector1, vector2, out=destination)
def test_addr(self):
    """Compare in-place ``torch.addr`` with an element-by-element reference.

    For each listed tensor type and several ``(h, w)`` shapes, a matrix is
    updated through ``torch.addr(m, v1, v2, out=m)`` and checked against a
    copy to which ``v1[i] * v2[j]`` was added entry by entry.  Every shape
    is also run on a transposed matrix and with ``v1`` built by expanding a
    single-element tensor.
    """
    # NOTE: the tolerance values are not consulted below; only the type
    # names (the keys) drive the loop.
    types = {
        'torch.DoubleTensor': 1e-8,
        'torch.FloatTensor': 1e-4,
    }

    def identity(x):
        return x

    def transposed(x):
        return x.transpose(0, 1)

    def run_test(m, v1, v2, m_transform=identity):
        # Operate on a (possibly transformed) clone so the caller's matrix
        # is not modified and can be reused by the next call.
        actual = m_transform(m.clone())
        expected = actual.clone()
        torch.addr(actual, v1, v2, out=actual)
        n_rows, n_cols = actual.size(0), actual.size(1)
        for i in range(n_rows):
            for j in range(n_cols):
                expected[i, j] += v1[i] * v2[j]
        self.assertEqual(actual, expected)

    for tname in types:
        for h, w in ((100, 110), (1, 20), (200, 2)):
            m = torch.randn(h, w).type(tname)
            v1 = torch.randn(h).type(tname)
            v2 = torch.randn(w).type(tname)
            run_test(m, v1, v2)
            # test transpose
            run_test(m, v2, v1, transposed)
            # test 0 strided
            v1 = torch.randn(1).type(tname).expand(h)
            run_test(m, v1, v2)
            run_test(m, v2, v1, transposed)
def _test_broadcast_fused_matmul(self, cast):
    """Check fused add-matmul functions accept a smaller first operand.

    For each of ``baddbmm``, ``addbmm``, ``addmm``, ``addmv`` and ``addr``,
    random dimensions in ``[1, 8]`` are drawn and the function is called
    twice: once with a first operand shaped by
    ``self._select_broadcastable_dims`` and once with that same operand
    expanded to the full shape.  Both results must compare equal.

    Args:
        cast: Callable applied to every tensor before it is used.
            NOTE(review): presumably converts the tensor type/device;
            its definition is not visible here -- confirm.
    """
    for fn in ("baddbmm", "addbmm", "addmm", "addmv", "addr"):
        batch_dim = random.randint(1, 8)
        n_dim = random.randint(1, 8)
        m_dim = random.randint(1, 8)
        p_dim = random.randint(1, 8)

        # (full shape of the first operand, shape of t1, shape of t2)
        dims_by_fn = {
            "baddbmm": ([batch_dim, n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim]),
            "addbmm": ([n_dim, p_dim],
                       [batch_dim, n_dim, m_dim],
                       [batch_dim, m_dim, p_dim]),
            "addmm": ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim]),
            "addmv": ([n_dim], [n_dim, m_dim], [m_dim]),
            "addr": ([n_dim, m_dim], [n_dim], [m_dim]),
        }
        if fn not in dims_by_fn:
            raise AssertionError("unknown function")
        t0_dims_full, t1_dims, t2_dims = dims_by_fn[fn]

        # Only the first element of the helper's 3-tuple is needed here.
        t0_dims_small, _, _ = self._select_broadcastable_dims(t0_dims_full)

        t0_small = cast(torch.randn(*t0_dims_small).float())
        t1 = cast(torch.randn(*t1_dims).float())
        t2 = cast(torch.randn(*t2_dims).float())
        t0_full = cast(t0_small.expand(*t0_dims_full))

        fused_op = getattr(torch, fn)
        from_small = fused_op(t0_small, t1, t2)
        from_full = fused_op(t0_full, t1, t2)
        self.assertEqual(from_small, from_full)
def test_functional_blas(self):
    """Check that torch BLAS functions agree on Variables and raw tensors.

    Every function is called once with each ``Variable`` argument replaced
    by its ``.data`` and once with the arguments as given; the ``.data`` of
    the Variable-based result must compare equal to the tensor-based result.
    When the tensor-based call does not return a tensor, the Variable-based
    result must be one-dimensional with a single element, and that element
    is compared instead.
    """

    def compare(fn, *args):
        # Non-Variable arguments (the scalar coefficients) pass through as-is.
        unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
                              for arg in args)
        unpacked_result = fn(*unpacked_args)
        packed_result = fn(*args).data
        # if non-Variable torch function returns a scalar, compare to scalar
        if not torch.is_tensor(unpacked_result):
            # Use test-case assertions rather than bare ``assert`` so these
            # shape checks still run when Python is started with -O.
            self.assertEqual(packed_result.dim(), 1)
            self.assertEqual(packed_result.nelement(), 1)
            packed_result = packed_result[0]
        self.assertEqual(packed_result, unpacked_result)

    def test_blas(fn, x, y):
        compare(fn, x, y)

    def test_blas_add(fn, x, y, z):
        # Checks all signatures: no scalar, one leading scalar, and a scalar
        # in front of each of the first two tensor operands.
        compare(fn, x, y, z)
        compare(fn, 0.5, x, y, z)
        compare(fn, 0.5, x, 0.25, y, z)

    # (checker, torch function, shapes of the random operands), in run order.
    cases = (
        (test_blas, torch.mm, ((2, 10), (10, 4))),
        (test_blas_add, torch.addmm, ((2, 4), (2, 10), (10, 4))),
        (test_blas, torch.bmm, ((4, 2, 10), (4, 10, 4))),
        (test_blas_add, torch.addbmm, ((2, 4), (4, 2, 10), (4, 10, 4))),
        (test_blas_add, torch.baddbmm, ((4, 2, 4), (4, 2, 10), (4, 10, 4))),
        (test_blas, torch.mv, ((2, 10), (10,))),
        (test_blas_add, torch.addmv, ((2,), (2, 10), (10,))),
        (test_blas, torch.ger, ((5,), (6,))),
        (test_blas_add, torch.addr, ((5, 6), (5,), (6,))),
        (test_blas, torch.matmul, ((6,), (6,))),
        (test_blas, torch.matmul, ((10, 4), (4,))),
        (test_blas, torch.matmul, ((5,), (5, 6))),
        (test_blas, torch.matmul, ((2, 10), (10, 4))),
        (test_blas, torch.matmul, ((5, 2, 10), (5, 10, 4))),
        (test_blas, torch.matmul, ((3, 5, 2, 10), (3, 5, 10, 4))),
        (test_blas, torch.matmul, ((3, 5, 2, 10), (10,))),
        (test_blas, torch.matmul, ((10,), (3, 5, 10, 4))),
    )
    for check, fn, shapes in cases:
        # Operands are drawn right before each check, left to right.
        operands = [Variable(torch.randn(*shape)) for shape in shapes]
        check(fn, *operands)
def test_functional_blas(self):
    """Verify BLAS-style torch functions give equal results on Variables.

    For each function, the result computed from the inputs' ``.data`` is
    compared with the ``.data`` of the result computed from the ``Variable``
    inputs.  If the tensor-based call returns a non-tensor value, the
    Variable-based result is required to be one-dimensional with exactly
    one element, and that element is compared against the returned value.
    """

    def var(*shape):
        # Fresh random operand of the requested shape, wrapped in a Variable.
        return Variable(torch.randn(*shape))

    def compare(fn, *args):
        # Only Variables are unwrapped; plain scalars are forwarded unchanged.
        unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
                              for arg in args)
        unpacked_result = fn(*unpacked_args)
        packed_result = fn(*args).data
        # if non-Variable torch function returns a scalar, compare to scalar
        if not torch.is_tensor(unpacked_result):
            # Test-case assertions are used instead of bare ``assert`` so the
            # shape checks are not stripped when Python runs with -O.
            self.assertEqual(packed_result.dim(), 1)
            self.assertEqual(packed_result.nelement(), 1)
            packed_result = packed_result[0]
        self.assertEqual(packed_result, unpacked_result)

    def test_blas(fn, x, y):
        compare(fn, x, y)

    def test_blas_add(fn, x, y, z):
        # Checks all signatures
        compare(fn, x, y, z)
        compare(fn, 0.5, x, y, z)
        compare(fn, 0.5, x, 0.25, y, z)

    test_blas(torch.mm, var(2, 10), var(10, 4))
    test_blas_add(torch.addmm, var(2, 4), var(2, 10), var(10, 4))
    test_blas(torch.bmm, var(4, 2, 10), var(4, 10, 4))
    test_blas_add(torch.addbmm, var(2, 4), var(4, 2, 10), var(4, 10, 4))
    test_blas_add(torch.baddbmm, var(4, 2, 4), var(4, 2, 10), var(4, 10, 4))
    test_blas(torch.mv, var(2, 10), var(10))
    test_blas_add(torch.addmv, var(2), var(2, 10), var(10))
    test_blas(torch.ger, var(5), var(6))
    test_blas_add(torch.addr, var(5, 6), var(5), var(6))
    test_blas(torch.matmul, var(6), var(6))
    test_blas(torch.matmul, var(10, 4), var(4))
    test_blas(torch.matmul, var(5), var(5, 6))
    test_blas(torch.matmul, var(2, 10), var(10, 4))
    test_blas(torch.matmul, var(5, 2, 10), var(5, 10, 4))
    test_blas(torch.matmul, var(3, 5, 2, 10), var(3, 5, 10, 4))
    test_blas(torch.matmul, var(3, 5, 2, 10), var(10))
    test_blas(torch.matmul, var(10), var(3, 5, 10, 4))