train.py source code

python

Project: chainer-ADDA  Author: pfnet-research
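import chainer
from chainer.functions.evaluation import accuracy

# data2iterator is a project helper that is not shown in this excerpt.


def test_pretrained_on_target(source_cnn, target, args):
    """Evaluate the source-trained CNN on the target domain's test iterator."""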
    print(":: testing pretrained source CNN on target domain")

    if args.device >= 0:
        source_cnn.to_gpu()

    with chainer.using_config('train', False):
        _, target_test_iterator = data2iterator(target, args.batchsize, multiprocess=False)

        mean_accuracy = 0.0
        n_batches = 0

        for batch in target_test_iterator:
            batch, labels = chainer.dataset.concat_examples(batch, device=args.device)
            encode = source_cnn.encoder(batch)
            classify = source_cnn.classifier(encode)
            acc = accuracy.accuracy(classify, labels)
            mean_accuracy += acc.data
            n_batches += 1
        mean_accuracy /= n_batches

        # accuracy() returns a fraction in [0, 1]; scale it to a percentage for display.
        print(":: classifier trained on only source, evaluated on target: accuracy {:.2f}%".format(float(mean_accuracy) * 100))
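
The loop above averages per-batch accuracies, so every batch counts equally even when the last batch is smaller than the rest. The following is a minimal, framework-free sketch of that difference; the function names and the toy data are hypothetical and are not part of chainer-ADDA.

```python
def batch_mean_accuracy(batches):
    """Mean of per-batch accuracies, mirroring the loop in the function above."""
    per_batch = []
    for preds, labels in batches:
        correct = sum(p == t for p, t in zip(preds, labels))
        per_batch.append(correct / len(labels))
    return sum(per_batch) / len(per_batch)


def sample_weighted_accuracy(batches):
    """Accuracy over all samples, so a small final batch is not over-weighted."""
    correct = 0
    total = 0
    for preds, labels in batches:
        correct += sum(p == t for p, t in zip(preds, labels))
        total += len(labels)
    return correct / total


# Toy data (hypothetical): a batch of 4 with 3 correct, then a batch of 1 with 0 correct.
batches = [([0, 1, 2, 3], [0, 1, 2, 0]), ([1], [0])]

print(batch_mean_accuracy(batches))       # 0.375
print(sample_weighted_accuracy(batches))  # 0.6
```

When every batch has the same size the two values coincide, so the distinction only matters if the test set size is not a multiple of the batch size.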