def assert_broadcastable(low_tensor, high_tensor):
    """Build an assertion that ``low_tensor``'s shape is a leading prefix of ``high_tensor``'s.

    The check compares ``tf.shape(low_tensor)`` against the first
    ``rank(low_tensor)`` entries of ``tf.shape(high_tensor)``. For example,
    shapes ``(B, T)`` vs ``(B, T, D)`` pass, while ``(T, D)`` vs ``(B, T, D)``
    fail. Note that this aligns *leading* dimensions, unlike NumPy-style
    broadcasting, which aligns trailing ones.

    Args:
        low_tensor: Tensor whose full shape must match the leading dimensions
            of ``high_tensor``.
        high_tensor: Tensor whose shape is expected to start with
            ``low_tensor``'s shape.

    Returns:
        The assertion produced by ``tf.assert_equal`` (named
        ``"assert_shape_prefix"``). In graph mode it must be run, or used as a
        control dependency, for the check to take effect. It transitively
        depends on the rank check below, so running it runs both checks.
    """
    low_shape = tf.shape(low_tensor)
    high_shape = tf.shape(high_tensor)
    low_rank = tf.rank(low_tensor)
    high_rank = tf.rank(high_tensor)

    # Guard the slice below: if low_tensor has more dimensions than
    # high_tensor, tf.slice would request more elements than high_shape has
    # and fail with an opaque out-of-range error instead of a clear assertion.
    assert_rank = tf.debugging.assert_less_equal(
        low_rank, high_rank, name="assert_rank_prefix"
    )
    with tf.control_dependencies([assert_rank]):
        # Leading `low_rank` entries of high_tensor's shape.
        high_shape_prefix = tf.slice(high_shape, [0], [low_rank])

    # Assert that the shapes are compatible (prefix match).
    assert_op = tf.assert_equal(high_shape_prefix, low_shape, name="assert_shape_prefix")
    return assert_op
评论列表
文章目录