models.py source code (Python)

Project: punctuator2, author: ottokart
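def save(self, file_path, gsums=None, learning_rate=None, validation_ppl_history=None, best_validation_ppl=None, epoch=None, random_state=None):
    # cPickle exists only on Python 2; on Python 3 the standard pickle module is already C-accelerated.
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    # Everything needed to rebuild the network and resume training from this point.
    state = {
        "type":                     self.__class__.__name__,  # model class name, used to pick the class when loading
        "n_hidden":                 self.n_hidden,
        "x_vocabulary":             self.x_vocabulary,
        "y_vocabulary":             self.y_vocabulary,
        "stage1_model_file_name":   getattr(self, "stage1_model_file_name", None),  # only some model types have it
        "params":                   [p.get_value(borrow=True) for p in self.params],  # Theano shared variables -> raw arrays
        "gsums":                    [s.get_value(borrow=True) for s in gsums] if gsums else None,  # optimizer gradient sums, if provided
        "learning_rate":            learning_rate,
        "validation_ppl_history":   validation_ppl_history,
        "best_validation_ppl":      best_validation_ppl,  # best validation perplexity seen so far
        "epoch":                    epoch,
        "random_state":             random_state
    }

    with open(file_path, 'wb') as f:
        pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)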
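The listing covers only the saving side. The sketch below is a minimal, self-contained illustration of the same checkpoint pattern, round-tripped through a matching loader. It does not reproduce punctuator2's own loading code, and several pieces are assumptions: `SharedValue` is a hypothetical stand-in for a Theano shared variable (only `get_value`/`set_value` are mimicked), `TinyModel` and `load_checkpoint` are hypothetical names, only a subset of the state keys is kept, plain lists replace weight arrays, and the standard-library `random` state replaces whatever RNG state the real training script passes in.

```python
import os
import pickle
import random
import tempfile


class SharedValue:
    """Hypothetical stand-in for a Theano shared variable (get_value/set_value only)."""

    def __init__(self, value):
        self._value = value

    def get_value(self, borrow=False):
        # borrow is accepted only to mirror the Theano call signature; it has no effect here.
        return self._value

    def set_value(self, value, borrow=False):
        self._value = value


class TinyModel:
    """Hypothetical model exposing the attributes that save() reads."""

    def __init__(self, n_hidden, x_vocabulary, y_vocabulary):
        self.n_hidden = n_hidden
        self.x_vocabulary = x_vocabulary
        self.y_vocabulary = y_vocabulary
        # Plain lists stand in for weight matrices and bias vectors.
        self.params = [SharedValue([0.0] * n_hidden), SharedValue([0.0])]

    def save(self, file_path, learning_rate=None, epoch=None, random_state=None):
        # Same idea as the listing: gather architecture, weights and training progress into one dict.
        state = {
            "type": self.__class__.__name__,
            "n_hidden": self.n_hidden,
            "x_vocabulary": self.x_vocabulary,
            "y_vocabulary": self.y_vocabulary,
            "params": [p.get_value(borrow=True) for p in self.params],
            "learning_rate": learning_rate,
            "epoch": epoch,
            "random_state": random_state,
        }
        with open(file_path, "wb") as f:
            pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_checkpoint(file_path):
    """Hypothetical loader: rebuild the model from the saved hyperparameters, then copy weights back in."""
    with open(file_path, "rb") as f:
        state = pickle.load(f)
    model = TinyModel(state["n_hidden"], state["x_vocabulary"], state["y_vocabulary"])
    for param, value in zip(model.params, state["params"]):
        param.set_value(value, borrow=True)
    return model, state


# Usage: save a toy model with some "trained" weights, then reload it.
model = TinyModel(3, {"hello": 0, "world": 1}, {"<none>": 0, ",": 1})
model.params[0].set_value([0.1, 0.2, 0.3])

rng_state = random.getstate()   # snapshot the RNG before drawing
first_draw = random.random()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pcl")
    model.save(path, learning_rate=0.02, epoch=5, random_state=rng_state)
    restored, state = load_checkpoint(path)

print(restored.params[0].get_value())  # [0.1, 0.2, 0.3]

# Restoring the saved RNG state replays the same random sequence.
random.setstate(state["random_state"])
replayed = random.random()
print(replayed == first_draw)  # True
```

Saving the class name under `"type"` lets a loader choose which class to instantiate before copying weights in, and storing the RNG state alongside the weights is what lets a resumed run reproduce the same random sequence it would have seen.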