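import tensorflow as tf

def ensureRank2(input):
    """Ensure `input` has rank 2, reshaping a rank-3 tensor to [-1, dim 1]."""
    # .ndims is None for an unknown rank, whereas len() on the shape would raise.
    if input.get_shape().ndims == 3:
        # Keeps dimension 1 as the column count; -1 absorbs all remaining elements.
        input = tf.reshape(input, [-1, int(input.get_shape()[1])])
    # In graph mode the assert op only runs if something depends on it.
    with tf.control_dependencies([tf.assert_rank(input, 2, message="Tensor is not rank 2")]):
        return tf.identity(input)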
Source file: Utilities.py (Python)
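The rank-3 branch keeps dimension 1 as the column count and lets `-1` absorb everything else. Because `tf.reshape` reads elements in row-major order, this is equivalent to dropping the last axis only when that axis has size 1; otherwise the new rows straddle rows of the original tensor. The sketch below illustrates this with plain Python lists so it runs without TensorFlow. `flatten` and `reshape_to_cols` are hypothetical helpers written for this illustration, not part of Utilities.py or the TensorFlow API.

```python
def flatten(nested):
    """Flatten a nested list in row-major (C) order, the order tf.reshape reads."""
    if isinstance(nested, list):
        return [x for item in nested for x in flatten(item)]
    return [nested]


def reshape_to_cols(nested, cols):
    """Mimic tf.reshape(t, [-1, cols]): the row count is inferred from the size."""
    flat = flatten(nested)
    if len(flat) % cols != 0:
        raise ValueError("element count is not divisible by cols")
    return [flat[i:i + cols] for i in range(0, len(flat), cols)]


# Shape [2, 3, 1]: the trailing axis has size 1, so the result equals
# dropping that axis, giving shape [2, 3].
squeezable = [[[0], [1], [2]], [[3], [4], [5]]]
print(reshape_to_cols(squeezable, 3))  # [[0, 1, 2], [3, 4, 5]]

# Shape [1, 2, 3]: dimension 1 is 2, so the result has shape [3, 2] and
# rows such as [2, 3] straddle rows of the original tensor.
mixed = [[[0, 1, 2], [3, 4, 5]]]
print(reshape_to_cols(mixed, 2))  # [[0, 1], [2, 3], [4, 5]]
```

If callers may pass rank-3 tensors whose last dimension is larger than 1, flattening the last two axes (a `[-1, dim1 * dim2]` target) may be what is actually intended; the snippet alone does not say.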