def is_same_dynamic_shape(x, y):
    """
    Check whether `x` and `y` have the same dynamic (run-time) shape.

    :param x: A Tensor.
    :param y: A Tensor.
    :return: A scalar Tensor of `bool`: True iff `x` and `y` have equal
        rank and equal size along every dimension.
    """
    def _ranks_differ():
        # Tensors of different rank can never share a shape.
        return tf.convert_to_tensor(False, tf.bool)

    def _compare_dims():
        # There is a TensorFlow bug where static shape inference is not
        # done right in nested tf.cond()'s, so tf.shape(x) and tf.shape(y)
        # are not compared directly. Both concatenations below have length
        # rank(x) + rank(y), so the element-wise comparison is always
        # well-formed. On this branch the ranks are equal, so the first
        # halves agree iff the shapes agree (and then so do the second
        # halves).
        shape_x = tf.shape(x)
        shape_y = tf.shape(y)
        x_then_y = tf.concat([shape_x, shape_y], 0)
        y_then_x = tf.concat([shape_y, shape_x], 0)
        return tf.reduce_all(tf.equal(x_then_y, y_then_x))

    same_rank = tf.equal(tf.rank(x), tf.rank(y))
    return tf.cond(same_rank, _compare_dims, _ranks_differ)
评论列表
文章目录