def random_column(columns):
    """Keep a single, randomly chosen column of a fractal block.

    Used for rounds that apply global drop path: every column except one
    is zeroed out.

    Args:
        columns: the columns of a fractal block to choose from. Assumed to
            have the column axis first, as reported by `tensor_shape` —
            TODO confirm against `tensor_shape`.

    Returns:
        The result of `apply_mask` for a randomly shuffled one-hot mask,
        multiplied by the number of columns.
    """
    # Static column count; must be an integer for the list arithmetic below.
    column_count = tensor_shape(columns)[0]
    # Exactly one True entry; shuffling it decides which column survives.
    one_hot = [True] + [False] * (column_count - 1)
    keep_mask = tf.random_shuffle(one_hot)
    selected = apply_mask(keep_mask, columns)
    # NOTE(review): presumably rescales because apply_mask averages over
    # the kept columns — confirm against apply_mask.
    return selected * column_count
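# NOTE(review): web-page navigation residue, not part of the code:
# "评论列表" (Comment list) / "文章目录" (Table of contents)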