def forward(self, x):
    """Expand rows of diagonal entries into a stack of diagonal matrices.

    ``x`` is flattened into rows of length ``self.dim``; each row becomes
    the main diagonal of one ``self.dim x self.dim`` matrix whose
    off-diagonal entries are zero.

    Args:
        x: Array-like whose total number of elements is a multiple of
            ``self.dim``. Any leading axes are collapsed into a single
            batch axis.

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(N, self.dim, self.dim)``
        with ``N = x.size // self.dim``.

    Raises:
        ValueError: If the number of elements in ``x`` cannot be split
            into rows of length ``self.dim``.
    """
    dim = self.dim
    # asarray is a no-op for ndarrays and lets lists/scalars through too.
    diag_rows = np.asarray(x).reshape(-1, dim)

    # np.zeros defaults to float64, so the result dtype does not follow x.
    matrices = np.zeros((diag_rows.shape[0], dim, dim))

    # Pairing the same index array on the last two axes addresses the main
    # diagonal of every matrix in the stack in one vectorised assignment.
    idx = np.arange(dim)
    matrices[:, idx, idx] = diag_rows
    return matrices