def train_tasks_continuosly(
        args, model, train, test, train2, test2, enable_ewc):
    """Train ``model`` on task A, then on task B, saving after each stage.

    Stage A is skipped when a task A checkpoint already exists on disk or
    when ``args.skip_taskA`` is set; the checkpoint is loaded into ``model``
    instead. When ``enable_ewc`` is true, ``model.compute_fisher(train)`` and
    ``model.store_variables()`` are called before stage B (presumably to set
    up the EWC penalty from task A -- confirm against the model class).

    Args:
        args: Parsed options; ``epoch``, ``batchsize`` and ``skip_taskA``
            are read here, and the whole object is forwarded to
            ``train_task``.
        model: Model to train; loaded from / saved to ``.model`` npz files.
        train: Task A training data.
        test: Task A evaluation data (evaluated in both stages).
        train2: Task B training data.
        test2: Task B evaluation data (evaluated in stage B only).
        enable_ewc: Whether to enable EWC before training on task B. Also
            selects the ``"_with_ewc"`` suffix on run names and on the final
            checkpoint file name.
    """
    ewc_suffix = "_with_ewc" if enable_ewc else ""
    # Where this function saves the task A checkpoint.
    task_a_path = "mlp_taskA.model"
    # Location the original code always loaded from; kept as a fallback so
    # that --skip_taskA with a model stored only there keeps working.
    fallback_task_a_path = "./model50/mlp_taskA.model"

    # Train Task A or load trained model
    has_checkpoint = os.path.exists(task_a_path)
    if has_checkpoint or args.skip_taskA:
        print("load taskA model")
        # Load the file whose existence was checked (and which a previous
        # run of this function saved), not an unrelated path.
        load_path = task_a_path if has_checkpoint else fallback_task_a_path
        serializers.load_npz(load_path, model)
    else:
        print("train taskA")
        train_task(args, "train_task_a" + ewc_suffix,
                   model, args.epoch, train,
                   {"TaskA": test}, args.batchsize)
        print("save the model")
        serializers.save_npz(task_a_path, model)

    if enable_ewc:
        print("enable EWC")
        model.compute_fisher(train)
        model.store_variables()

    # Train Task B
    print("train taskB")
    train_task(args, "train_task_ab" + ewc_suffix,
               model, args.epoch, train2,
               {"TaskA": test, "TaskB": test2}, args.batchsize)
    print("save the model")
    serializers.save_npz("mlp_taskAB" + ewc_suffix + ".model", model)
# (Web-page residue, not part of the program: "Comment list" / "Table of contents".)