def test_diag(self):
    """Check which ops ``diag`` and ``extract_diag`` put in the graph.

    Covers three cases: a vector passed to ``diag``, a matrix passed to
    ``extract_diag``, and a ``tensor3`` variable passed to ``extract_diag``
    (which is expected to be rejected with ``TypeError``).
    """
    # Vector input: diag() is expected to produce an AllocDiag op.
    vec = theano.tensor.vector()
    built = diag(vec)
    assert built.owner.op.__class__ == AllocDiag

    # Matrix input: extract_diag() is expected to produce an ExtractDiag op.
    mat = theano.tensor.matrix()
    extracted = extract_diag(mat)
    assert extracted.owner.op.__class__ == ExtractDiag

    # Other input types should raise an error; a tensor3 variable is the
    # example used here. Any exception other than TypeError propagates.
    cube = theano.tensor.tensor3()
    try:
        extract_diag(cube)
    except TypeError:
        rejected = True
    else:
        rejected = False
    assert rejected

    # not testing the view=True case since it is not used anywhere.
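# Comment list  (appears to be leftover web-page navigation text)
# Article table of contents  (appears to be leftover web-page navigation text)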