def local_reshape_chain(op):
    """Create a local optimizer that fuses two consecutive applications of `op`.

    The produced optimizer rewrites a graph fragment of the form
    ``op(op(x, shape1), shape2)`` into ``op(x, shape2)``, i.e.
    ``Reshape(Reshape(shape1), shape2) -> Reshape(shape2)``.

    Parameters
    ----------
    op
        The op the optimizer is registered for (passed to
        ``gof.local_optimizer``). It is used as both the outer and the
        inner op of the chain checked by ``opt.check_chain``.

    Returns
    -------
    The function ``f`` after decoration by ``gof.local_optimizer([op])``.
    """

    @gof.local_optimizer([op])
    def f(node):
        """Fuse ``Reshape(Reshape(x, shape1), shape2)`` into ``Reshape(x, shape2)``.

        Returns ``[replacement]`` when the rewrite applies, and ``False``
        when the op -> op chain is absent or when the replacement's
        broadcastable pattern differs from that of the node's output.
        """
        # Bail out unless `node` and the producer of its first input form
        # an `op` -> `op` chain (presumably what check_chain verifies).
        if not opt.check_chain(node, op, op):
            return False

        # The input of the lower reshape, and the shape requested by the
        # upper one.
        source = node.inputs[0].owner.inputs[0]
        target_shape = node.inputs[1]

        # TODO: this can permit a failing program to run by eliminating
        # the lower reshape (its shape1 is never applied any more).
        collapsed = node.op(source, target_shape)

        # The node's expected output may carry a broadcastable pattern
        # that 'collapsed' does not reproduce. This happens when a
        # dimension of the reshape was originally known to be one, but a
        # later transformation replaced the shape with one from which this
        # can no longer be deduced. The root cause (losing that constant
        # information) should be investigated; until then, refuse the
        # rewrite rather than change the broadcastable pattern.
        if collapsed.broadcastable != node.outputs[0].broadcastable:
            return False
        return [collapsed]

    return f
# Comment list
# Article contents