extensions.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:dl4mt-multi 作者: nyu-dl 项目源码 文件源码
def _load_accumulators(self, main_loop):
        """Restore saved accumulator values into each model's algorithm.

        Nasty method, use carefully: values are written positionally into
        ``algo.step_rule_updates`` using the flat index
        ``pidx * col + aidx``, so the result is only meaningful when the
        updates are grouped per parameter, in the iteration order of
        ``algo.steps``, with the same number of accumulators for every
        parameter.  NOTE(review): that layout is assumed here, not
        checked -- confirm against the code that saves the accumulators.

        For every ``cg_name`` in ``main_loop.models`` the archive at
        ``self.path_to_accumulators.format(cg_name)`` is read with
        ``numpy.load``.  Entry names have ``"-"`` replaced by ``"/"`` before
        being matched against the names returned by ``model.get_params()``.

        Parameters
        ----------
        main_loop : object
            Must expose ``models`` (mapping of name -> model with a
            ``get_params()`` method) and ``algorithm.algorithms`` (mapping
            of name -> algorithm with ``steps`` and ``step_rule_updates``).

        Raises
        ------
        IndexError
            If a key of ``algo.steps`` is not among the model parameters.
        KeyError
            If the archive has no entry for a parameter's name.
        """
        for cg_name, model in main_loop.models.items():
            source = numpy.load(self.path_to_accumulators.format(cg_name))
            try:
                # Materialise every array before the archive is closed.
                accums_dict = {name.replace("-", "/"): value
                               for name, value in source.items()}
            finally:
                # Release the file handle even if reading an entry fails.
                source.close()
            algo = main_loop.algorithm.algorithms[cg_name]
            model_params = model.get_params()

            for pidx, (p, _) in enumerate(algo.steps.items()):
                # Get parameter name and its accumulators
                name = [k for k, v in model_params.items() if v == p][0]
                accums = accums_dict[name]

                # This is num_accums_per_param
                col = len(accums)
                for aidx, accum in enumerate(accums):
                    target = algo.step_rule_updates[pidx * col + aidx][0]
                    target.set_value(accum)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号