def zero_state(self, batch_size, dtype=tf.float32):
    """Build an all-zeros initial state for this cell.

    Args:
      batch_size: Leading dimension of the zero tensor.
      dtype: Element type of the zero tensor; defaults to ``tf.float32``.

    Returns:
      An ``LSTMStateTuple`` whose two entries are the same zeros tensor of
      shape ``(batch_size, self._num_cells)``.
    """
    state_shape = (batch_size, self._num_cells)
    # One tensor is deliberately reused for both tuple slots; TensorFlow
    # tensors are immutable, so sharing it cannot cause aliasing bugs.
    zero_tensor = tf.zeros(state_shape, dtype=dtype)
    return LSTMStateTuple(zero_tensor, zero_tensor)