def _flip_vector_to_matrix_dynamic(vec, batch_shape):
  """Dynamic-shape implementation of `flip_vector_to_matrix`.

  Treats `vec` as having shape `[M1,...,Mm] + [N1,...,Nn] + [k]`, where
  `batch_shape` supplies `[N1,...,Nn]`. It moves the leading `M` axes to
  the back and collapses them into one axis. The result is reshaped to
  `batch_shape + [k, M1*...*Mm]`. All ranks and sizes are computed as graph
  tensors, so static shape information is not required.

  Args:
    vec: Tensor-convertible value. Its trailing dimensions are expected to be
      `batch_shape + [k]`; only the total element count is actually checked,
      by the final reshape.
    batch_shape: 1-D shape tensor `[N1,...,Nn]`. It is concatenated with
      entries of `array_ops.shape(vec)`, so presumably int32 — confirm.

  Returns:
    Tensor of shape `batch_shape + [k, M1*...*Mm]`. When `vec` has no leading
    `M` dimensions, the last axis has size 1.
  """
  # Number of batch dimensions, i.e. n in [N1,...,Nn].
  num_batch_dims = array_ops.size(batch_shape)

  # Dynamic shape information for `vec`.
  vec = ops.convert_to_tensor(vec, name="vec")
  full_shape = array_ops.shape(vec)
  full_rank = array_ops.rank(vec)
  # Count of leading "M" axes: the rank minus the trailing k axis and the
  # batch axes.
  # NOTE(review): nothing guards against this being negative (batch_shape
  # longer than vec's batch dims); confirm callers guarantee it is >= 0.
  num_leading = (full_rank - 1) - num_batch_dims

  # [M1,...,Mm], or [] when there are no leading axes.
  leading_shape = array_ops.slice(full_shape, [0], [num_leading])
  # reduce_prod([]) == 1, so this is [1] with no leading axes and
  # [M1*...*Mm] otherwise.
  collapsed = [math_ops.reduce_prod(leading_shape)]
  last_dim = array_ops.gather(full_shape, full_rank - 1)
  target_shape = array_ops.concat(0, (batch_shape, [last_dim], collapsed))

  def _move_leading_dims_last():
    # Axis order [N1,...,Nn] + [k] + [M1,...,Mm].
    order = array_ops.concat(
        0, (math_ops.range(num_leading, full_rank),
            math_ops.range(0, num_leading)))
    return array_ops.transpose(vec, perm=order)

  def _append_unit_dim():
    # Nothing to move; add a trailing size-1 axis so the reshape to
    # batch_shape + [k, 1] lines up.
    return array_ops.expand_dims(vec, -1)

  rearranged = control_flow_ops.cond(
      math_ops.less(0, num_leading),
      _move_leading_dims_last,
      _append_unit_dim)
  return array_ops.reshape(rearranged, target_shape)
评论列表
文章目录