def shuffle_jointly(*args):
    """Shuffle several tensors using one shared permutation of their rows.

    The arguments are concatenated along axis 1, the combined tensor is
    shuffled with ``tf.random_shuffle`` (which permutes along the first
    dimension), and the result is sliced back into one tensor per argument.
    Because everything is shuffled as a single tensor, row ``i`` of every
    returned tensor originates from the same input row.

    Args:
        *args: Tensors that can be concatenated along axis 1 (so rank >= 2,
            same dtype, and matching sizes on every other axis). The size
            of axis 1 must be statically known, since it is used as a
            Python int to compute the slice boundaries.

    Returns:
        A list with one shuffled tensor per argument, in argument order.
        An empty list is returned when no arguments are given.
    """
    if not args:
        # Nothing to concatenate; tf.concat would reject an empty input.
        return []

    # Width of each argument along the concatenation axis. This has to be
    # axis 1 (not the last axis): both the concat and the slicing below
    # operate on axis 1, and the two only coincide for rank-2 tensors.
    widths = [a.get_shape().as_list()[1] for a in args]

    shuffled = tf.random_shuffle(tf.concat(values=args, axis=1))

    # Carve the shuffled tensor back into per-argument pieces.
    splits = []
    start = 0
    for width in widths:
        splits.append(shuffled[:, start:start + width])
        start += width
    return splits
# Comment list
# Table of contents