def unravel_index(indices, shape):
    """Convert flat row-major (C-order) indices into per-dimension coordinates.

    TensorFlow counterpart of ``numpy.unravel_index`` for C-order layouts,
    except that the coordinates come back stacked as rows of one tensor
    (shape ``[N, D]``) rather than as a tuple of ``D`` arrays.

    Args:
        indices: A flat index or a batch of flat indices (a tensor or
            anything ``tf.convert_to_tensor`` accepts). Assumed to be an
            integer scalar or rank-1 value -- higher ranks are not supported
            by the final ``(1, 0)`` transpose.
        shape: Integer vector giving the size of each dimension, e.g. the
            result of ``tf.shape(x)``. It is cast to the dtype of ``indices``
            so that mixed int32/int64 inputs work.

    Returns:
        Integer tensor of shape ``[N, D]`` with the dtype of ``indices``, where
        ``N`` is the number of indices (1 for a scalar index) and ``D`` is
        ``len(shape)``. Row ``i`` holds the coordinates of ``indices[i]``.

    Note:
        Indices are not range-checked. The coordinate along the first
        dimension is reduced modulo ``shape[0]``, so an out-of-range flat
        index silently wraps instead of raising.
    """
    with tf.name_scope('unravel_index'):
        indices = tf.convert_to_tensor(indices)
        # The floor-division and modulo below need both operands to share a
        # dtype. Without this cast, e.g. int64 indices (tf.argmax's default)
        # combined with an int32 shape (tf.shape's default) fail with a dtype
        # mismatch. Casting is a no-op when the dtypes already agree.
        shape = tf.cast(shape, indices.dtype)
        # [1, N] against [D, 1] so the arithmetic below broadcasts to [D, N].
        indices = tf.expand_dims(indices, 0)
        shape = tf.expand_dims(shape, 1)
        # Row-major strides: stride[k] = prod(shape[k+1:]), last stride = 1
        # (exclusive, reversed cumulative product).
        strides_shifted = tf.cumprod(shape, exclusive=True, reverse=True)
        res = (indices // strides_shifted) % shape
        # [D, N] -> [N, D]: one row of coordinates per flat index.
        return tf.transpose(res, (1, 0))
# TODO: get rid of this when TF fixes the NaN bugs in tf.svd:
# https://github.com/tensorflow/tensorflow/issues/8905
评论列表
文章目录