def call(self, input, state):
    """Run the wrapped cell on ``input`` prefixed with the flattened parent state.

    ``self._flat_parent_state`` is concatenated in front of ``input`` along
    axis 1. The result is passed to ``self._wrapped.call``, together with the
    unmodified ``state``.

    Args:
        input: Per-step input tensor. The name shadows the ``input`` builtin,
            but it is kept so that keyword callers continue to work.
        state: Recurrent state, forwarded to the wrapped object unchanged.

    Returns:
        Whatever ``self._wrapped.call`` returns.
    """
    # NOTE(review): assumes both tensors have rank >= 2 and agree on every
    # dimension except axis 1 (presumably axis 0 is batch) — confirm.
    parent = self._flat_parent_state
    # The parent state goes first; the order matters for the wrapped cell's
    # input layout.
    augmented = tf.concat(values=[parent, input], axis=1)
    # NOTE(review): this invokes ``.call`` directly rather than ``__call__``.
    # If ``_wrapped`` is a Keras layer, that bypasses ``__call__``'s extra
    # machinery (e.g. build) — confirm this is intended.
    return self._wrapped.call(augmented, state)