def row_index(shape):
    """
    Generate an X index for the given tensor shape.

    Every row of the result holds the column positions ``0 .. width-1``:

    .. code-block:: python

        [
          [ 0, 1, 2, ... width-1 ],
          [ 0, 1, 2, ... width-1 ],
          ... (x height)
        ]

    :param list[int] shape: Shape whose first two entries are read as
        height and width, in that order. Any further entries are ignored.
    :return: Tensor of dtype ``tf.int32`` with shape ``[height, width]``.
    """
    height = shape[0]
    width = shape[1]

    # tf.range produces 0..width-1 directly, replacing the exclusive
    # cumulative sum over a vector of ones.
    x_positions = tf.range(width, dtype=tf.int32)

    # Add a leading axis of size 1, then repeat that single row `height`
    # times; the result is already [height, width], so no reshape is needed.
    return tf.tile(tf.expand_dims(x_positions, 0), [height, 1])
# Web-page residue, not code: "评论列表" (comment list), "文章目录" (table of contents).