def to_unpacked_coordinates(ix, l, bound):
    """Convert packed indices back into (start, end) coordinate pairs.

    The packed layout implied by the lookup table built below is grouped
    by length: the first `l` indices have length 0, the next `l - 1` have
    length 1, and so on up to length `bound - 1`. Within a group the
    packed index increases with the start position.

    Args:
        ix: Packed indices; cast to int32 before use. Presumably a 1-D
            tensor, since the result is stacked along axis 1 -- confirm
            against callers.
        l: Size of the length-0 group (looks like the sequence length --
            confirm against callers).
        bound: Python int (it is passed to `range`); the number of
            distinct lengths, so lengths run over 0 .. bound - 1.

    Returns:
        A tensor stacking `start` and `start + length` along axis 1, so
        an index from the length-0 group yields start == end.
    """
    packed = tf.cast(ix, tf.int32)

    # The length could be recovered in closed form:
    #   floor(0.5 * (-sqrt(4 * l^2 + 4 * l - 8 * ix + 1) + 2 * l + 1))
    # but floating-point rounding could pick the wrong length, so an explicit
    # lookup table (length value `k` repeated `l - k` times) is used instead.
    length_table = tf.concat(
        [tf.fill((l - length,), length) for length in range(bound)],
        axis=0,
    )
    span_len = tf.gather(length_table, packed)

    # Groups with a smaller length occupy
    # l * span_len - span_len * (span_len - 1) / 2 packed slots in total;
    # subtracting that offset leaves the position within the group (the start).
    start = packed - l * span_len + span_len * (span_len - 1) // 2
    return tf.stack([start, start + span_len], axis=1)
评论列表
文章目录