def flatten(self, x_tensor):
    """Collapse every axis after the first into a single axis.

    The result is ``tf.reshape(x_tensor, [-1, k])``, where ``k`` is the product
    of all non-leading dimensions. The leading axis size is inferred by the
    ``-1`` in the reshape. (Presumably the leading axis is the batch axis;
    confirm against callers.)

    When every non-leading dimension is statically known, ``k`` is a plain
    Python int, exactly as before. If any non-leading dimension is statically
    unknown (``None``), ``k`` is computed at runtime from ``tf.shape`` instead
    of failing on ``int(None)``.

    Args:
        x_tensor: Tensor (or tensor-convertible value with a ``.shape``) to
            flatten. A rank-0 or rank-1 input yields ``k == 1``.

    Returns:
        A rank-2 tensor of shape ``[-1, k]``.
    """
    flat_dim = 1
    for dim in x_tensor.shape[1:]:
        # TF1 yields ``Dimension`` objects (size in ``.value``); TF2 and numpy
        # yield ints or ``None`` directly.
        size = getattr(dim, "value", dim)
        if size is None:
            # Static size unknown: fall back to the dynamic (runtime) product.
            flat_dim = tf.reduce_prod(tf.shape(x_tensor)[1:])
            break
        flat_dim *= int(size)
    return tf.reshape(x_tensor, [-1, flat_dim])