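import tensorflow as tf

def ensureRank3(input):
    """Ensure `input` has rank 3: a rank-2 tensor gains a trailing axis of size 1; any other rank fails the assertion."""
    # ndims is None for an unknown rank, whereas len() on such a shape raises.
    if input.get_shape().ndims == 2:
        input = tf.expand_dims(input, 2)
    # Raises at build time if the static rank is known and not 3; the control
    # dependency makes the runtime check actually execute in TF1 graph mode.
    with tf.control_dependencies([tf.assert_rank(input, 3, message="Tensor is not rank 3")]):
        return tf.identity(input)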
Source file: Utilities.py (Python)
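The shape rule that `ensureRank3` enforces can be sketched without TensorFlow by operating on plain shape tuples. This is a minimal illustration under stated assumptions: `ensure_rank3_shape` is a hypothetical helper, not part of Utilities.py or of TensorFlow, and `None` inside a tuple stands for a dimension whose size is unknown.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]  # None marks a dimension of unknown size


def ensure_rank3_shape(shape: Shape) -> Shape:
    """Mirror ensureRank3 on a shape tuple (hypothetical helper).

    A rank-2 shape gains a trailing axis of size 1, as tf.expand_dims(x, 2)
    would add; any shape that is still not rank 3 is rejected.
    """
    if len(shape) == 2:
        shape = shape + (1,)
    if len(shape) != 3:
        raise ValueError(f"Tensor is not rank 3: got rank {len(shape)}")
    return shape


if __name__ == "__main__":
    print(ensure_rank3_shape((32, 10)))     # (32, 10, 1)
    print(ensure_rank3_shape((32, 10, 4)))  # (32, 10, 4)
    try:
        ensure_rank3_shape((32,))
    except ValueError as err:
        print(err)                          # Tensor is not rank 3: got rank 1
```

Unlike the TensorFlow version, this sketch only sees static shapes. A tensor whose rank is unknown until runtime still needs the deferred `tf.assert_rank` check.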