train.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:chainer-ADDA 作者: pfnet-research 项目源码 文件源码
def pretrain_source_cnn(data, args, epochs=1000):
    """Train the source encoder with a standard Chainer trainer loop.

    Args:
        data: Dataset handed unchanged to ``data2iterator``, which supplies
            the train and test iterators.
        args: Options object; this function reads ``device``, ``batchsize``
            and ``output`` from it.
        epochs: Number of epochs to train for. The same value is used as
            the interval of the model snapshot trigger.

    Returns:
        The trained ``Loss`` instance.
    """
    print(":: pretraining source encoder")

    device = args.device
    source_cnn = Loss(num_classes=10)
    if device >= 0:
        # A non-negative device id means the model is moved to the GPU.
        source_cnn.to_gpu()

    train_iterator, test_iterator = data2iterator(
        data, args.batchsize, multiprocess=False)

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(source_cnn)

    stop_trigger = (epochs, 'epoch')
    updater = chainer.training.StandardUpdater(
        iterator=train_iterator, optimizer=optimizer, device=device)
    trainer = chainer.training.Trainer(updater, stop_trigger, out=args.output)

    # NOTE: no learning-rate decay and no full-trainer snapshot are
    # registered here; only the model object itself is snapshotted.
    evaluator = extensions.Evaluator(test_iterator, source_cnn, device=device)
    model_snapshot = extensions.snapshot_object(
        optimizer.target, "source_model_epoch_{.updater.epoch}")
    report_keys = [
        'epoch',
        'main/loss',
        'validation/main/loss',
        'main/accuracy',
        'validation/main/accuracy',
        'elapsed_time',
    ]

    # (extension, extra keyword arguments for trainer.extend), registered in
    # this exact order.
    registrations = (
        (evaluator, {}),
        (model_snapshot, {"trigger": (epochs, "epoch")}),
        (extensions.ProgressBar(update_interval=10), {}),
        (extensions.LogReport(trigger=(1, "epoch")), {}),
        (extensions.PrintReport(report_keys), {}),
    )
    for extension, extend_kwargs in registrations:
        trainer.extend(extension, **extend_kwargs)

    trainer.run()

    return source_cnn
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号