def excg_rgb_bgr(batch, dim=1):
    """Reverse the order of the three channel groups of *batch* (RGB <-> BGR).

    The tensor is split into three chunks along ``dim`` with ``torch.chunk``
    and reassembled in reverse order ``(b, g, r)``. With exactly 3 channels
    along ``dim`` this converts RGB to BGR and vice versa, so applying the
    function twice restores the original values.

    Args:
        batch: Tensor whose ``dim`` axis holds the colour channels. The default
            ``dim=1`` reproduces the original behaviour (channel axis second,
            presumably an (N, C, H, W) batch -- confirm against callers).
        dim: Axis to split and re-concatenate along. Defaults to 1.

    Returns:
        A new contiguous tensor with the channel groups ordered ``(b, g, r)``.

    Raises:
        ValueError: If ``torch.chunk`` does not produce exactly three chunks
            along ``dim`` (e.g. a channel size of 2 or 4).
        IndexError: If ``dim`` is out of range for ``batch``.
    """
    # Work directly on the channel axis instead of transposing it to dim 0 and
    # back; this avoids two view ops and yields a contiguous result.
    chunks = torch.chunk(batch, 3, dim=dim)
    if len(chunks) != 3:
        raise ValueError(
            f"expected the size of dim {dim} to split into 3 channel groups, "
            f"got {len(chunks)} chunk(s) from shape {tuple(batch.shape)}"
        )
    r, g, b = chunks
    return torch.cat((b, g, r), dim=dim)
# Save model