def main(args):
    """Run the SVHN -> MNIST domain-adaptation pipeline end to end.

    SVHN is the source domain. MNIST, loaded as 3-channel images and resized
    to 32x32, is the target domain. A source CNN is obtained: it is loaded
    from ``args.output/args.pretrained_source`` when that file exists, and
    pretrained otherwise. That model is then evaluated on the target domain.
    Finally, a fresh target CNN is initialised from the source weights and
    trained.

    Args:
        args: parsed options. ``args.output`` and ``args.pretrained_source``
            are read here, and the whole object is forwarded to the
            pretraining, evaluation and training helpers.
    """
    # Source domain: SVHN, kept as a (train, test) pair.
    svhn_train, svhn_test = chainer.datasets.get_svhn()
    source = (svhn_train, svhn_test)

    # Target domain: MNIST with an explicit channel axis in RGB layout.
    mnist_train, mnist_test = chainer.datasets.get_mnist(ndim=3, rgb_format=True)

    def to_32x32(example):
        # Upscale each digit to 32x32 (presumably to match SVHN's image size).
        image, label = example
        return resize(image, (32, 32)), label

    target = (
        TransformDataset(mnist_train, to_32x32),
        TransformDataset(mnist_test, to_32x32),
    )

    # Reuse a saved source-model snapshot when present; otherwise pretrain one.
    snapshot_path = os.path.join(args.output, args.pretrained_source)
    if os.path.isfile(snapshot_path):
        source_cnn = Loss(num_classes=10)
        serializers.load_npz(snapshot_path, source_cnn)
    else:
        source_cnn = pretrain_source_cnn(source, args)

    # Baseline: how well does the source-only model do on the target domain?
    test_pretrained_on_target(source_cnn, target, args)

    # Build an independent target model instead of calling source_cnn.copy().
    # NOTE(review): this is likely because Chainer's Link.copy() shares
    # parameter arrays by default; confirm against the Chainer version in use.
    target_cnn = Loss(num_classes=10)
    # Start the target model from the source model's weights.
    target_cnn.copyparams(source_cnn)

    train_target_cnn(source, target, source_cnn, target_cnn, args)
评论列表
文章目录