train.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:chainer-ADDA 作者: pfnet-research 项目源码 文件源码
def _make_rmsprop(link, args):
    """Return an RMSprop optimizer (lr=args.lr) set up on *link* with a WeightDecay hook.

    Other optimizers (Adam with alpha=1e-5/beta1=0.5, MomentumSGD with
    lr=1e-4/momentum=0.99) were tried at some point; RMSprop is the one in use.
    """
    optimizer = chainer.optimizers.RMSprop(lr=args.lr)
    optimizer.setup(link)
    optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))
    return optimizer


def train_target_cnn(source, target, source_cnn, target_cnn, args, epochs=10000):
    """Adversarially train ``target_cnn.encoder`` on the target domain.

    A fresh ``Discriminator`` is created. An ``ADDAUpdater`` drives training with
    one optimizer for ``target_cnn.encoder`` and one for the discriminator, both
    RMSprop with weight decay. ``target_cnn`` is evaluated on the target test
    split (Evaluator's default trigger: every epoch). A snapshot of it is written
    once, at the final epoch.

    Args:
        source: Source-domain dataset, passed to ``data2iterator``.
        target: Target-domain dataset, passed to ``data2iterator``.
        source_cnn: Source model handed to ``ADDAUpdater``; no optimizer is
            created for it here.
        target_cnn: Target model; only its ``encoder`` is optimized. The whole
            model is evaluated and snapshotted.
        args: Namespace read here for ``device``, ``lr``, ``weight_decay``,
            ``batchsize`` and ``output``. It is also passed to ``ADDAUpdater``.
        epochs: Total number of training epochs.
    """
    print(":: training encoder with target domain")
    discriminator = Discriminator()

    if args.device >= 0:
        # Pass the id explicitly: Link.to_gpu() without an argument uses the
        # *current* CUDA device, which need not be args.device (the device the
        # Evaluator below sends batches to).
        for link in (source_cnn, target_cnn, discriminator):
            link.to_gpu(args.device)

    target_optimizer = _make_rmsprop(target_cnn.encoder, args)
    discriminator_optimizer = _make_rmsprop(discriminator, args)

    # Only the source training split is needed; its test split is discarded.
    source_train_iterator, _ = data2iterator(source, args.batchsize, multiprocess=False)
    target_train_iterator, target_test_iterator = data2iterator(target, args.batchsize, multiprocess=False)

    updater = ADDAUpdater(source_train_iterator, target_train_iterator, source_cnn,
                          target_optimizer, discriminator_optimizer, args)

    trainer = chainer.training.Trainer(updater, (epochs, 'epoch'), out=args.output)

    trainer.extend(extensions.Evaluator(target_test_iterator, target_cnn, device=args.device))
    # Trigger equals the total epoch count, so the target model is saved once, at the end.
    trainer.extend(extensions.snapshot_object(target_cnn, "target_model_epoch_{.updater.epoch}"),
                   trigger=(epochs, "epoch"))

    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(extensions.LogReport(trigger=(1, "epoch")))
    trainer.extend(extensions.PrintReport(
        ["epoch", "loss/discrim", "loss/encoder",
         "validation/main/loss", "validation/main/accuracy", "elapsed_time"]))

    trainer.run()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号