def shuffle_columns(x, srng):
    '''Reorder the axis-0 entries of each axis-1 slice of `x` independently.

    For every index ``j`` along axis 1, the slice ``x[:, j, :]`` is indexed
    along its first axis by a permutation drawn from `srng`, so each slice
    gets its own ordering.

    Args:
        x (T.tensor): tensor that is transposed with three axes, i.e. it is
            treated as 3-dimensional.
        srng (sharedRandomstream): random stream whose `permutation` method
            supplies the index orderings.

    Returns:
        The shuffled tensor, transposed back to the axis order of `x`.
    '''
    n_rows = x.shape[0]
    n_cols = x.shape[1]

    def reorder_slice(block, order):
        # `block` is one axis-1 slice of x, `order` the indices applied to it.
        return block[order]

    # Presumably one permutation of length n_rows per column, stacked along
    # the leading axis so scan can walk it in step with the slices -- confirm
    # against the random stream's `permutation` documentation.
    orders = srng.permutation(n=n_rows, size=(n_cols,))

    # Put axis 1 first so that scan iterates over the axis-1 slices.
    column_major = x.transpose(1, 0, 2)

    # The updates returned by scan are not used here.
    shuffled, _ = scan(
        reorder_slice,
        [column_major, orders],
        [None],
        [],
        n_cols,
        name='shuffle',
        strict=False)

    # Undo the transpose to restore the original axis order.
    return shuffled.transpose(1, 0, 2)
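# "Comment list" (navigation label left over from the source web page; not code)
# "Table of contents" (navigation label left over from the source web page; not code)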