def _get_batch_shape(self):
    """Return the static batch shape implied by ``self.logits``.

    The batch shape is the static shape of ``self.logits`` with its final
    dimension removed. If that static shape is falsy, a fully-unknown
    ``tf.TensorShape(None)`` is returned instead. For a ``TensorShape``,
    falsy presumably means "unknown rank" (confirm for the TF version in use).
    """
    # NOTE(review): assumes the trailing logits axis is the event/class axis,
    # so only the leading axes describe the batch; confirm against callers.
    logits_shape = self.logits.get_shape()
    if not logits_shape:
        # Nothing is known statically, so report an unknown-rank shape.
        return tf.TensorShape(None)
    # Keep every leading dimension and drop the last one.
    return logits_shape[:-1]