def _split(self, concat, n_filter, x): if concat or type(x) != tuple: x1 = x[:, :, :, :n_filter // 2] x2 = x[:, :, :, n_filter // 2:] else: x1, x2 = x return x1, x2