def get_output_for(self, inputs, **kwargs):
    """Advance the wrapped GRU layer by a single step.

    Parameters
    ----------
    inputs : sequence of two items
        ``(x, hprev)``. ``x`` is the current input. Its leading axis is
        treated as the batch axis, and all remaining axes are flattened
        into one. ``hprev`` is passed through unchanged to the GRU step,
        presumably as the previous hidden state (TODO confirm against
        ``_gru_layer.step``).
    **kwargs
        Accepted for interface compatibility; not used here.

    Returns
    -------
    Whatever ``self._gru_layer.step(flat_x, hprev)`` returns.
    """
    step_input, prev_state = inputs
    # Collapse every non-batch axis so step() receives a 2-D (batch, features) input.
    batch_size = step_input.shape[0]
    flat_input = step_input.reshape((batch_size, -1))
    gru = self._gru_layer
    return gru.step(flat_input, prev_state)