model.py source code

python

Project: tensorlm    Author: batzner
import numpy as np  # module-level import used by np.argmax below


def _sample_step(self, session, inputs, update_state=True):
    """Feeds batch inputs to the model and returns the batch output ids.

    Args:
        session (tf.Session): The TF session to run the operations in.
        inputs (np.ndarray): A batch of inputs. Must have the shape (batch_size, num_timesteps)
            and contain only integers. The batch size and number of timesteps are determined
            dynamically, so the shape of inputs can vary between calls of this function.
        update_state (bool): If True, the LSTM's memory state will be updated after feeding the
            batch inputs, so that the LSTM will use this state for the next feed of inputs.
            If this function gets called during training, make sure to call it between
            on_pause_training and will_resume_training. This freezes the training's memory
            state before this call and unfreezes it afterwards.

    Returns:
        np.ndarray: A batch of integer output ids (the most likely id at each timestep)
            with the same shape as the inputs parameter.
    """
    # Feed the input ids into the model's input placeholder
    feed_dict = {self._inputs: inputs}

    # Only fetch the state-update op when the state should be updated. Creating a
    # tf.no_op() here instead would add a new node to the graph on every call.
    fetches = [self._logits]
    if update_state:
        fetches.append(self._update_state_op)

    # Run the graph; the logits have shape (batch_size, num_timesteps, vocab_size)
    logits = session.run(fetches, feed_dict=feed_dict)[0]

    # Pick the most likely id at every timestep
    return np.argmax(logits, axis=2)
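To make the two behaviours described in the docstring concrete (greedy argmax decoding over the last axis, and the optional state update), here is a framework-free toy sketch in plain Python. It does not use TensorFlow or tensorlm: `argmax_last_axis`, `ToyStatefulModel`, and its single integer "state" are hypothetical stand-ins, and the save/restore in `on_pause_training`/`will_resume_training` is an assumption about what "freezing" the training state means, not tensorlm's actual implementation.

```python
from typing import List, Optional

# Shapes follow the docstring: inputs/ids are (batch_size, num_timesteps),
# logits are (batch_size, num_timesteps, vocab_size).
Ids = List[List[int]]
Logits = List[List[List[float]]]


def argmax_last_axis(logits: Logits) -> Ids:
    """Pure-Python equivalent of np.argmax(logits, axis=2).

    Like NumPy, ties resolve to the first (lowest) index.
    """
    return [
        [max(range(len(step)), key=step.__getitem__) for step in sequence]
        for sequence in logits
    ]


class ToyStatefulModel:
    """Hypothetical stand-in for the TF model; not part of tensorlm.

    The "memory state" is a single integer shared by the whole batch (a
    deliberate simplification of the LSTM state), and the model predicts
    id (state + input_id) % vocab_size at every timestep.
    """

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size
        self.state = 0
        self._frozen_state: Optional[int] = None

    def _compute_logits(self, inputs: Ids) -> Logits:
        # Plays the role of self._logits: reads the state but never changes it.
        return [
            [
                [1.0 if v == (self.state + x) % self.vocab_size else 0.0
                 for v in range(self.vocab_size)]
                for x in sequence
            ]
            for sequence in inputs
        ]

    def sample_step(self, inputs: Ids, update_state: bool = True) -> Ids:
        logits = self._compute_logits(inputs)
        if update_state:
            # Plays the role of self._update_state_op: fold the fed ids into the state.
            self.state += sum(sum(sequence) for sequence in inputs)
        return argmax_last_axis(logits)

    # Hypothetical pause/resume pair: save the training state, then restore it,
    # mirroring the freeze/unfreeze behaviour the docstring describes.
    def on_pause_training(self) -> None:
        self._frozen_state = self.state

    def will_resume_training(self) -> None:
        if self._frozen_state is not None:
            self.state = self._frozen_state
            self._frozen_state = None


model = ToyStatefulModel(vocab_size=5)
print(model.sample_step([[1, 2]], update_state=False))  # [[1, 2]], state stays 0
print(model.sample_step([[1, 2]]))                      # [[1, 2]], state becomes 3
print(model.sample_step([[1, 2]]))                      # [[4, 0]], state becomes 6
```

The third call differs from the second only because the second call updated the state; with update_state=False the same inputs keep producing [[1, 2]]. As in the real method, the returned ids always have the same (batch_size, num_timesteps) shape as the inputs.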