def gram_matrix(x):
    """Return the per-sample Gram matrix of a 4-D feature tensor.

    The spatial axes are flattened so each channel becomes a row vector.
    The result is the batched product of that matrix with its own transpose,
    divided by ``ch * h * w``.

    Args:
        x: A 4-D input whose ``.data.shape`` unpacks into four dimensions.
            Judging by the local names, the layout is assumed to be
            (batch, channels, height, width); confirm against callers.

    Returns:
        A tensor of shape (b, ch, ch) holding the Gram matrices, where b and
        ch are the first two dimensions of ``x``.
    """
    (batch, channels, height, width) = x.data.shape
    # One row per channel, holding that channel's flattened spatial values.
    features = F.reshape(x, (batch, channels, height * width))
    # features @ features^T for every sample in the batch.
    gram = F.batch_matmul(features, features, transb=True)
    # Divide by a float32 scalar so the division stays in single precision.
    scale = np.float32(channels * height * width)
    return gram / scale