def train_target_cnn(source, target, source_cnn, target_cnn, args, epochs=10000):
print(":: training encoder with target domain")
discriminator = Discriminator()
if args.device >= 0:
source_cnn.to_gpu()
target_cnn.to_gpu()
discriminator.to_gpu()
# target_optimizer = chainer.optimizers.Adam(alpha=1.0E-5, beta1=0.5)
target_optimizer = chainer.optimizers.RMSprop(lr=args.lr)
# target_optimizer = chainer.optimizers.MomentumSGD(lr=1.0E-4, momentum=0.99)
target_optimizer.setup(target_cnn.encoder)
target_optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))
# discriminator_optimizer = chainer.optimizers.Adam(alpha=1.0E-5, beta1=0.5)
discriminator_optimizer = chainer.optimizers.RMSprop(lr=args.lr)
# discriminator_optimizer = chainer.optimizers.MomentumSGD(lr=1.0E-4, momentum=0.99)
discriminator_optimizer.setup(discriminator)
discriminator_optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))
source_train_iterator, source_test_iterator = data2iterator(source, args.batchsize, multiprocess=False)
target_train_iterator, target_test_iterator = data2iterator(target, args.batchsize, multiprocess=False)
updater = ADDAUpdater(source_train_iterator, target_train_iterator, source_cnn, target_optimizer, discriminator_optimizer, args)
trainer = chainer.training.Trainer(updater, (epochs, 'epoch'), out=args.output)
trainer.extend(extensions.Evaluator(target_test_iterator, target_cnn, device=args.device))
# trainer.extend(extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'), trigger=(10, "epoch"))
trainer.extend(extensions.snapshot_object(target_cnn, "target_model_epoch_{.updater.epoch}"), trigger=(epochs, "epoch"))
trainer.extend(extensions.ProgressBar(update_interval=10))
trainer.extend(extensions.LogReport(trigger=(1, "epoch")))
trainer.extend(extensions.PrintReport(
["epoch", "loss/discrim", "loss/encoder",
"validation/main/loss", "validation/main/accuracy", "elapsed_time"]))
trainer.run()
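For completeness, here is a minimal, hypothetical sketch of how train_target_cnn might be invoked; it is not part of the original script. The args fields mirror the ones the function actually reads (device, lr, weight_decay, batchsize, output), the hyperparameter values and MNIST datasets are placeholders, and the construction of source_cnn / target_cnn is left elided because those models (and data2iterator's expected input format, assumed here to be a (train, test) pair) are defined elsewhere in the code.

from types import SimpleNamespace
import chainer

# Hypothetical argument namespace; field names match what train_target_cnn reads.
args = SimpleNamespace(device=0, lr=2.0E-4, weight_decay=2.5E-5,
                       batchsize=128, output="result")

# Assumed (train, test) dataset pairs for each domain; MNIST is only a stand-in.
source = chainer.datasets.get_mnist(ndim=3)
target = chainer.datasets.get_mnist(ndim=3)  # replace with the real target-domain data

source_cnn = ...  # pretrained source model (encoder + classifier), built elsewhere
target_cnn = ...  # copy of the source model whose encoder will be adapted

train_target_cnn(source, target, source_cnn, target_cnn, args, epochs=100)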