create_checkpoint.py source code

python

Project: stn-ocr    Author: Bartzi
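import logging

import mxnet as mx


def get_create_checkpoint_callback(iteration, model_prefix):
    """Return a batch-end callback that saves a checkpoint every `iteration` batches."""

    def create_checkpoint(execution_params):
        # execution_params is the BatchEndParam passed to batch_end_callback;
        # its `locals` field holds the training loop's local variables.
        if execution_params.nbatch % iteration == 0:
            original_executor = execution_params.locals['executor_manager']
            # Copy the first device's replica of each parameter to the CPU and
            # prefix the names with 'arg:' / 'aux:', as MXNet checkpoints do.
            save_dict = {('arg:%s' % k): v[0].as_in_context(mx.cpu())
                         for k, v in zip(original_executor.param_names, original_executor.param_arrays)}
            save_dict.update({('aux:%s' % k): v[0].as_in_context(mx.cpu())
                              for k, v in zip(original_executor.aux_names, original_executor.aux_arrays)})

            # Save the network definition next to the weights.
            symbol = execution_params.locals['symbol']
            symbol.save('{}-symbol.json'.format(model_prefix))

            # File name: prefix, zero-padded epoch (4 digits), zero-padded batch (5 digits).
            model_name = "{}-{:0>4}-{:0>5}".format(model_prefix, execution_params.epoch, execution_params.nbatch)
            mx.nd.save(model_name, save_dict)
            logging.info('Saved checkpoint to "{}"'.format(model_name))

    return create_checkpoint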
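The files written above have no `.params` extension and include the batch number, so `mx.model.load_checkpoint` (which looks for `<prefix>-<epoch>.params`) will not find them directly. A minimal sketch of reading one back, assuming the dict returned by `mx.nd.load(model_name)` keeps the `arg:` / `aux:` keys it was saved with. The helper names `checkpoint_name` and `split_params` are hypothetical, and plain floats stand in for NDArrays so the sketch runs without MXNet.

```python
def checkpoint_name(model_prefix, epoch, nbatch):
    # Hypothetical helper: same pattern as create_checkpoint,
    # 4-digit zero-padded epoch and 5-digit zero-padded batch number.
    return "{}-{:0>4}-{:0>5}".format(model_prefix, epoch, nbatch)


def split_params(saved):
    """Split a dict saved with 'arg:'/'aux:' key prefixes into two dicts.

    Hypothetical helper. `saved` is assumed to be what mx.nd.load returns
    for a file written by create_checkpoint; any mapping works here.
    """
    arg_params, aux_params = {}, {}
    for key, value in saved.items():
        kind, _, name = key.partition(':')
        if kind == 'arg':
            arg_params[name] = value
        elif kind == 'aux':
            aux_params[name] = value
        else:
            raise ValueError('unexpected key {!r}'.format(key))
    return arg_params, aux_params


if __name__ == '__main__':
    # Floats stand in for NDArrays in this illustration.
    saved = {'arg:fc1_weight': 0.5, 'arg:fc1_bias': 0.1, 'aux:bn_moving_mean': 0.0}
    args, auxs = split_params(saved)
    print(checkpoint_name('model', 3, 200))  # model-0003-00200
    print(sorted(args), sorted(auxs))        # ['fc1_bias', 'fc1_weight'] ['bn_moving_mean']
```

With real files, the two dicts could then be passed as `arg_params` / `aux_params`, for example to `mx.mod.Module.set_params`, together with the network loaded via `mx.sym.load('<prefix>-symbol.json')`.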