def test_local_add_specialize():
    """Exercise ``local_add_specialize.transform`` on additions involving zeros.

    Three cases are checked:

    * ``tensor.add`` applied to ``zeros_like`` of a vector variable -- the
      transform must return a truthy result.
    * the same with a scalar (0-d) variable.
    * an ``int64`` zero constant added to an ``int32`` constant -- the
      transform must return a truthy result whose first element has the
      same type as the original sum.
    """
    # test of non-zero dimension
    a = tensor.vector()
    s = tensor.add(tensor.zeros_like(a))
    assert local_add_specialize.transform(s.owner)

    # test of 0-d
    a = tensor.scalar()
    s = tensor.add(tensor.zeros_like(a))
    assert local_add_specialize.transform(s.owner)

    # Test when the 0 input is forcing upcasting.
    # The zero operand is the one with the wider dtype (int64 vs. int32), so
    # presumably just dropping it would change the result dtype -- the
    # replacement is required to keep the type of the original expression.
    a = tensor.constant(0, dtype='int64')
    b = tensor.constant(1, dtype='int32')
    s = a + b
    transformed = local_add_specialize.transform(s.owner)
    assert transformed
    assert transformed[0].type == s.type
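# 评论列表 (comment list) -- web-page navigation residue, not part of the test
# 文章目录 (table of contents) -- web-page navigation residue, not part of the test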