ewc_mnist.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:chainer-EWC 作者: okdshin 项目源码 文件源码
def train_tasks_continuosly(
        args, model, train, test, train2, test2, enable_ewc):
    # Train Task A or load trained model
    if os.path.exists("mlp_taskA.model") or args.skip_taskA:
        print("load taskA model")
        serializers.load_npz("./model50/mlp_taskA.model", model)
    else:
        print("train taskA")
        train_task(args, "train_task_a"+("_with_ewc" if enable_ewc else ""),
                   model, args.epoch, train,
                   {"TaskA": test}, args.batchsize)
        print("save the model")
        serializers.save_npz("mlp_taskA.model", model)

    if enable_ewc:
        print("enable EWC")
        model.compute_fisher(train)
        model.store_variables()

    # Train Task B
    print("train taskB")
    train_task(args, "train_task_ab"+("_with_ewc" if enable_ewc else ""),
               model, args.epoch, train2,
               {"TaskA": test, "TaskB": test2}, args.batchsize)
    print("save the model")
    serializers.save_npz(
            "mlp_taskAB"+("_with_ewc" if enable_ewc else "")+".model", model)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号