def _create(self):
    """Build an all-zeros initial decoder state.

    Walks ``self.decoder_state_size`` with ``nest.map_structure`` and replaces
    every leaf ``width`` with ``tf.zeros([self.batch_size, width],
    dtype=tf.float32)``. The returned object therefore has the same nesting
    as ``self.decoder_state_size``.

    NOTE(review): assumes each leaf of ``decoder_state_size`` is something
    usable as a dimension inside a Python list (e.g. an int) - confirm
    against how the state size is defined.

    Returns:
        A structure matching ``self.decoder_state_size`` whose leaves are
        float32 zero tensors.
    """

    def _leaf_zeros(width):
        # self.batch_size is read per leaf (not hoisted) to keep attribute
        # access identical to the per-call lambda this replaces.
        return tf.zeros([self.batch_size, width], dtype=tf.float32)

    state_spec = self.decoder_state_size
    initial_state = nest.map_structure(_leaf_zeros, state_spec)
    return initial_state