def _mode(self):
    """Return the index of the largest logit for each batch entry.

    Takes the argmax of ``self.logits`` along axis ``self._batch_rank``.
    That axis is presumably the class axis that follows the batch dimensions,
    which needs confirming against the constructor. The indices are cast to
    ``self.dtype``, and the result's static shape is pinned to
    ``self.get_batch_shape()``.

    Returns:
      A tensor of dtype ``self.dtype`` whose static shape is
      ``self.get_batch_shape()``.
    """
    # NOTE(review): newer TensorFlow releases deprecate `dimension=` in favour
    # of `axis=`. It is kept here to match the older API this file uses
    # (e.g. `get_batch_shape()`). Confirm before migrating.
    best_index = math_ops.argmax(self.logits, dimension=self._batch_rank)
    mode_value = math_ops.cast(best_index, self.dtype)
    # The argmax may lose static shape information; restore it from the
    # distribution's batch shape.
    mode_value.set_shape(self.get_batch_shape())
    return mode_value