model.py source code

python

Project: DialogueBreakdownDetection2016  Author: icoxfog417
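import tensorflow as tf
from tensorflow.python.ops import variable_scope as vs

def build(self, session, predict=True, projection=True):
    # Build one sub-graph per bucket; a bucket is (encoder length, decoder length).
    # The first bucket creates the model variables and every later bucket
    # reuses them, so all buckets share a single set of weights.
    for j, bucket in enumerate(self.buckets):
        with vs.variable_scope(vs.get_variable_scope(), reuse=True if j > 0 else None):
            o, d_s, e_s = self.model.forward(
                self.encoder_inputs[:bucket[0]], self.decoder_inputs[:bucket[1]],
                predict=predict, projection=projection
            )
            self._outputs.append(o)
            self._encoder_state.append(e_s)
            self._decoder_state.append(d_s)

    # tf.global_variables / tf.global_variables_initializer replace the
    # deprecated tf.all_variables / tf.initialize_all_variables.
    self.saver = tf.train.Saver(tf.global_variables())
    session.run(tf.global_variables_initializer())

    # If a checkpoint exists, overwrite the freshly initialized values with it.
    # checkpoint_exists handles both V1 files and V2 prefixes; tf.gfile.Exists
    # returns False for a V2 prefix and would silently skip the restore.
    if self.model_path:
        saved = tf.train.get_checkpoint_state(self.model_path)
        if saved and tf.train.checkpoint_exists(saved.model_checkpoint_path):
            self.saver.restore(session, saved.model_checkpoint_path)
    self._graph_builded = True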
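The `build` method above constructs one sub-graph per bucket. Each bucket slices the input placeholders to `(encoder length, decoder length)`, and `reuse=True` on every bucket after the first makes all of them share one set of weights. The snippet does not show how training pairs are assigned to buckets. The pure-Python sketch below therefore assumes a common scheme: pick the smallest bucket that fits, then pad up to its size. It imitates the create-then-reuse behaviour with a toy store rather than TensorFlow. `BUCKETS`, `PAD_ID`, `pick_bucket`, `pad_to_bucket`, `ToyVariableStore`, `build_buckets`, and the variable name `proj/W` are all hypothetical and are not part of the project.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical bucket sizes: (max encoder length, max decoder length).
BUCKETS: List[Tuple[int, int]] = [(5, 10), (10, 15), (20, 25), (40, 50)]
PAD_ID = 0  # hypothetical padding token id


def pick_bucket(enc_len: int, dec_len: int,
                buckets: Sequence[Tuple[int, int]] = BUCKETS) -> int:
    """Return the index of the smallest bucket that fits both sequences."""
    for i, (enc_max, dec_max) in enumerate(buckets):
        if enc_len <= enc_max and dec_len <= dec_max:
            return i
    raise ValueError(f"pair ({enc_len}, {dec_len}) exceeds the largest bucket")


def pad_to_bucket(enc: Sequence[int], dec: Sequence[int],
                  buckets: Sequence[Tuple[int, int]] = BUCKETS):
    """Pad a token-id pair with PAD_ID up to its bucket's lengths."""
    b = pick_bucket(len(enc), len(dec), buckets)
    enc_max, dec_max = buckets[b]
    enc_padded = list(enc) + [PAD_ID] * (enc_max - len(enc))
    dec_padded = list(dec) + [PAD_ID] * (dec_max - len(dec))
    return b, enc_padded, dec_padded


class ToyVariableStore:
    """Toy stand-in for TF1 variable scopes: create once, then reuse by name."""

    def __init__(self) -> None:
        self._vars: Dict[str, List[float]] = {}

    def get_variable(self, name: str, size: int, reuse: bool) -> List[float]:
        if reuse:
            # Reusing requires that the variable was already created.
            if name not in self._vars:
                raise ValueError(f"{name} does not exist; cannot reuse")
            return self._vars[name]
        # Creating requires that the name is still free.
        if name in self._vars:
            raise ValueError(f"{name} already exists; set reuse=True")
        self._vars[name] = [0.0] * size
        return self._vars[name]


def build_buckets(store: ToyVariableStore,
                  buckets: Sequence[Tuple[int, int]] = BUCKETS) -> List[List[float]]:
    """Mirror the loop in build(): bucket 0 creates, later buckets reuse."""
    weights = []
    for j, _bucket in enumerate(buckets):
        weights.append(store.get_variable("proj/W", size=4, reuse=j > 0))
    return weights


if __name__ == "__main__":
    b, enc, dec = pad_to_bucket([4, 8, 15], [16, 23, 42, 7, 9, 1])
    print(b, len(enc), len(dec))  # -> 0 5 10
    ws = build_buckets(ToyVariableStore())
    print(all(w is ws[0] for w in ws))  # -> True
```

Bucketing costs some padding but keeps the number of distinct graph shapes small and fixed. Every bucket reads the same variables, so the single `Saver` in `build` saves and restores them once for all buckets.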