def test_matrix_col(self):
    """Check the 'stabilize' rewrite of a transposed matrix-column product.

    The graph ``dot(b, a_col).T`` (with ``a_col`` being ``a`` dimshuffled to
    pattern ``(0, 'x')``) is optimized. The result is expected to print as a
    product of ``a`` dimshuffled to ``{x,0}`` and ``b`` dimshuffled to
    ``{1,0}``, i.e. ``(b . a_col).T`` rewritten as ``a_row . b.T``.
    Stack-trace information must survive the optimization.
    """
    vec = vector('a')
    mat = matrix('b')

    # Reshape the vector with pattern (0, 'x') -- presumably an (n, 1)
    # column, as the test name suggests -- then take the transpose of the
    # matrix-column product.
    col_view = vec.dimshuffle(0, 'x')
    transposed = tensor.dot(mat, col_view).T

    fgraph = FunctionGraph([vec, mat], [transposed])
    g = optimize(fgraph, level='stabilize')

    expected = '[dot(InplaceDimShuffle{x,0}(a), InplaceDimShuffle{1,0}(b))]'
    actual = str(g)
    assert actual == expected, (actual, expected)

    # The optimizer must carry the original stack traces over to every new
    # node it introduced.
    self.assertTrue(check_stack_trace(g, ops_to_check='all'))
评论列表
文章目录