def test_pretrained_on_target(source_cnn, target, args):
    """Evaluate a source-trained CNN on the target domain's test split.

    Runs ``source_cnn.encoder`` followed by ``source_cnn.classifier`` over
    every batch of the target test iterator and prints the resulting
    classification accuracy as a percentage.

    Args:
        source_cnn: Model exposing ``encoder`` and ``classifier`` callables
            (and ``to_gpu`` for GPU execution).
        target: Target-domain dataset, passed to ``data2iterator``.
        args: Namespace providing ``device`` (GPU id; a negative value means
            the model is left where it is) and ``batchsize``.

    Returns:
        float: Accuracy over all evaluated target samples, as a fraction in
        [0, 1]. Callers that ignore the return value are unaffected.

    Raises:
        ValueError: If the target test iterator yields no samples.
    """
    print(":: testing pretrained source CNN on target domain")
    if args.device >= 0:
        # Place the model on the same GPU that concat_examples sends the
        # batches to; a bare to_gpu() uses the *current* device, which can
        # differ from args.device.
        source_cnn.to_gpu(args.device)
    # Evaluation only: train=False switches train-dependent layers to
    # inference behaviour, and no_backprop_mode skips building a graph.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        _, target_test_iterator = data2iterator(target, args.batchsize, multiprocess=False)
        # NOTE(review): assumes data2iterator returns a non-repeating test
        # iterator; a repeating one would make this loop never terminate.
        n_correct = 0.0
        n_samples = 0
        for batch in target_test_iterator:
            batch, labels = chainer.dataset.concat_examples(batch, device=args.device)
            encode = source_cnn.encoder(batch)
            classify = source_cnn.classifier(encode)
            # presumably chainer's accuracy function: fraction of correct
            # predictions in this batch — TODO confirm the `accuracy` import.
            acc = accuracy.accuracy(classify, labels)
            # Weight each batch by its size so a smaller final batch does not
            # skew the mean; float() also copies a GPU scalar to the host.
            batch_size = len(labels)
            n_correct += float(acc.data) * batch_size
            n_samples += batch_size
    if n_samples == 0:
        raise ValueError("target test iterator yielded no samples")
    mean_accuracy = n_correct / n_samples
    # mean_accuracy is a fraction; scale it so the printed '%' is correct.
    print(":: classifier trained on only source, evaluated on target: accuracy {:.2f}%".format(mean_accuracy * 100))
    return mean_accuracy
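# Navigation residue from the web page this code was copied from (not code):
# 评论列表 (Comment list)
# 文章目录 (Table of contents)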