def _unstack_tuple(self, inputs, tensor_sizes):
    """Split ``inputs`` along axis 0 into consecutive, variable-length slices.

    The i-th output is ``inputs[o_i : o_i + tensor_sizes[i], ...]`` where
    ``o_i = sum(tensor_sizes[:i])``; every trailing dimension is taken whole
    (``-1`` in the ``tf.slice`` size). Rows past ``sum(tensor_sizes)`` are
    not checked for and are simply left out of the result.

    Args:
        self: Unused; present so the function can be bound as a method.
        inputs: Tensor to split along its first dimension.
        tensor_sizes: Sequence of per-slice lengths along axis 0. The running
            offset starts as ``tf.constant(0)`` (int32), so presumably these
            are Python ints or scalar int32 tensors -- confirm with callers.

    Returns:
        The slices, grouped with ``tf.tuple``.
    """
    rank = inputs.get_shape().ndims
    offset = tf.constant(0)

    if rank is not None:
        # Known static rank: begin/size are plain Python lists.
        begin_tail = [0] * (rank - 1)
        size_tail = [-1] * (rank - 1)
    else:
        # Unknown static rank (graph mode): the number of trailing dims is
        # not available in Python, so build those parts at run time instead
        # of failing in TensorShape.as_list().
        trailing = tf.rank(inputs) - 1
        begin_tail = tf.fill([trailing], 0)
        size_tail = tf.fill([trailing], -1)

    slices = []
    for length in tensor_sizes:
        if rank is not None:
            begin = [offset, *begin_tail]
            size = [length, *size_tail]
        else:
            begin = tf.concat([tf.reshape(offset, [1]), begin_tail], axis=0)
            size = tf.concat([tf.reshape(length, [1]), size_tail], axis=0)
        slices.append(tf.slice(inputs, begin=begin, size=size))
        # Next slice starts where this one ended.
        offset = offset + length
    return tf.tuple(slices)
评论列表
文章目录