def Identity(n, dtype=None, name=None):
    """Identity matrix of size n.

    The matrix is built as ``tf.diag`` of a length-``n`` vector of ones.

    Args:
        n: Either the matrix size, or an object with a ``shape`` attribute
            (e.g. a Tensor). In the latter case the first two entries of its
            shape must be equal, and that common value is used as the size.
        dtype: Element dtype. ``None`` selects the module-level
            ``default_dtype``.
        name: Optional op name, forwarded to ``tf.diag``.

    Returns:
        The result of ``tf.diag(tf.ones((n,), dtype=dtype), name=name)``.

    Raises:
        ValueError: If ``n`` has a shape with fewer than two entries, or
            whose first two entries differ.
    """
    if hasattr(n, "shape"):  # got a Tensor (or anything shaped like one)
        # fix_shape is defined elsewhere in the module; presumably it turns
        # the shape into an indexable sequence of sizes -- TODO confirm.
        dims = fix_shape(n.shape)
        # Explicit checks instead of `assert`, which is stripped under -O.
        try:
            rows, cols = dims[0], dims[1]
        except IndexError:
            raise ValueError(
                "Identity needs a shape with at least 2 dimensions, got %r"
                % (dims,)
            ) from None
        if rows != cols:
            raise ValueError(
                "Identity needs a square shape, got %r" % (dims,)
            )
        n = rows
    # Compare against None rather than using truthiness: numpy.dtype
    # instances of non-structured types are falsy (their truth value falls
    # back to len(), the field count), so `not dtype` would silently
    # discard an explicitly requested dtype.
    if dtype is None:
        dtype = default_dtype
    return tf.diag(tf.ones((n,), dtype=dtype), name=name)
评论列表
文章目录