def _sample_step(self, session, inputs, update_state=True):
    """Feeds batch inputs to the model and returns the batch output ids.

    Args:
        session (tf.Session): The TF session to run the operations in.
        inputs (np.ndarray): A batch of inputs. Must have the shape
            (batch_size, num_timesteps) and contain only integers. The batch
            size and number of timesteps are determined dynamically, so the
            shape of inputs can vary between calls of this function.
        update_state (bool): If True, ``self._update_state_op`` is run together
            with the logits, so the LSTM's memory state is updated after this
            batch and the next feed continues from it. If False, only the
            logits are fetched and the state is left untouched.
            If this function gets called during training, make sure to call it
            between on_pause_training and will_resume_training, so that the
            training's memory state is frozen before and unfrozen after this
            function call.

    Returns:
        np.ndarray: The argmax over axis 2 of the model logits, i.e. one output
        id per input position. Assuming ``self._logits`` has shape
        (batch_size, num_timesteps, num_classes) (TODO confirm against the
        graph definition), the result has the same shape as ``inputs``. Its
        dtype is the integer index dtype produced by ``np.argmax``, not
        necessarily the dtype of ``inputs``.
    """
    feed_dict = {self._inputs: inputs}
    if update_state:
        # Run the state update in the same session call so it sees the same batch.
        logits, _ = session.run(
            [self._logits, self._update_state_op], feed_dict=feed_dict
        )
    else:
        # Fetch only the logits. Building a `tf.no_op()` here would add a new
        # node to the graph on every call (unbounded growth) and would fail
        # outright on a finalized graph.
        logits = session.run(self._logits, feed_dict=feed_dict)
    return np.argmax(logits, axis=2)
评论列表
文章目录