def test_flatten_broadcastable():
    """Check that flatten() yields a broadcastable pattern coherent with its input.

    For each case, a symbolic float64 input with the given broadcastable
    pattern is flattened to ``outdim`` dimensions. The broadcastable pattern
    of the result is then compared against the expected one.
    """
    # Each row is (input broadcastable pattern, outdim, expected output pattern).
    # The expected values encode the rule being tested:
    # - the first outdim-1 input dims keep their flags;
    # - the final output dim (all remaining input dims merged together) is
    #   broadcastable only if every merged input dim is broadcastable.
    cases = [
        ((False, False, False, False), 2, (False, False)),
        ((False, False, False, True), 2, (False, False)),
        ((False, True, False, True), 2, (False, False)),
        ((False, True, True, True), 2, (False, True)),
        ((True, False, True, True), 3, (True, False, True)),
    ]
    for in_bcast, outdim, expected in cases:
        inp = TensorType('float64', in_bcast)()
        out = flatten(inp, outdim=outdim)
        # Report the failing case explicitly rather than a bare AssertionError.
        assert out.broadcastable == expected, (
            "flatten(%s, outdim=%d) gave broadcastable %s, expected %s"
            % (in_bcast, outdim, out.broadcastable, expected))
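# 评论列表 ("comment list") and 文章目录 ("table of contents"):
# web-page navigation residue captured along with this snippet; not code.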