def dihedral_transform_batch(x, *, group=None, rng=None):
    """Apply an independently, randomly chosen symmetry to each image in a batch.

    Each sample ``x[i]`` is remapped by one 2x2 coordinate transform drawn
    uniformly at random from ``group``. The transform acts on pixel
    coordinates measured from the image centre, so integer-valued matrices
    map the pixel grid onto itself. Such matrices include rotations by
    multiples of 90 degrees, flips and transposes.

    Args:
        x: 4-D array indexed as ``x[sample, c, row, col]``. The axis ``c`` is
            presumably channels; it is carried through unchanged. The two
            spatial axes must have equal length, because a 90-degree
            rotation or transpose of a non-square image changes its shape.
        group: Optional array of shape ``(k, 2, 2)`` holding the coordinate
            transforms to sample from. When omitted, the module-level ``D4``
            is used and one of its first 8 entries is drawn per sample,
            exactly as before this parameter existed.
        rng: Optional ``numpy.random.Generator``. When omitted, the global
            ``np.random`` state is used (``np.random.randint``), matching
            the original behaviour.

    Returns:
        A new array with the same shape and dtype as ``x``.

    Raises:
        ValueError: if the last two axes of ``x`` have different lengths.
    """
    h, w = x.shape[-2:]
    if h != w:
        raise ValueError(
            f"dihedral_transform_batch requires square images, got {h}x{w}")

    if group is None:
        group, n_choices = D4, 8
    else:
        n_choices = len(group)

    n = x.shape[0]
    if rng is None:
        choice = np.random.randint(low=0, high=n_choices, size=n)
    else:
        choice = rng.integers(0, n_choices, size=n)

    half_h = (h - 1) / 2.
    half_w = (w - 1) / 2.
    # Centred pixel coordinates. With meshgrid's default 'xy' indexing,
    # col_c[a, b] == b - half_h and row_c[a, b] == a - half_w.
    col_c, row_c = np.meshgrid(np.linspace(-half_h, half_h, h),
                               np.linspace(-half_w, half_w, w))
    coords = np.stack([col_c, row_c])
    # Apply every transform to every coordinate pair -> (k, 2, rows, cols),
    # then shift back so indices start at 0.
    mapped = np.einsum('...ij,jkl->...ikl', group, coords)
    mapped[:, 0] += half_h
    mapped[:, 1] += half_w
    # Round instead of truncating: a float matrix entry such as cos(pi/2)
    # leaves values like 0.9999999999999999, which astype(int) would
    # floor to 0 and so select the wrong source pixel.
    src = np.rint(mapped).astype(int)

    out = np.empty_like(x)
    for i, k in enumerate(choice):
        cols, rows = src[k]
        # Gather: out[i, :, a, b] = x[i, :, rows[a, b], cols[a, b]].
        out[i] = x[i][:, rows, cols]
    return out
评论列表
文章目录