train.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:chainer-ADDA 作者: pfnet-research 项目源码 文件源码
def main(args):
    """Drive the whole run: source pretraining, target check, target adaptation.

    SVHN is used as the source domain and MNIST (loaded as 3-channel images
    and resized to 32x32) as the target domain. A source network is either
    restored from a saved snapshot or pretrained, evaluated on the target
    data, and then used to initialise the target network that is handed to
    ``train_target_cnn``.

    Args:
        args: Run configuration. This function reads ``args.output`` and
            ``args.pretrained_source`` to locate the saved source model; the
            object is also forwarded unchanged to the training/evaluation
            helpers.
    """
    def _to_32x32(example):
        # Resize only the image; the label passes through untouched.
        image, label = example
        return resize(image, (32, 32)), label

    # Source domain: SVHN, kept as a (train, test) pair.
    svhn_train, svhn_test = chainer.datasets.get_svhn()
    # Target domain: MNIST as 3-channel images.
    mnist_train, mnist_test = chainer.datasets.get_mnist(
        ndim=3, rgb_format=True)
    source = (svhn_train, svhn_test)

    # Wrap both MNIST splits so every example is resized to 32x32 on access.
    target = (
        TransformDataset(mnist_train, _to_32x32),
        TransformDataset(mnist_test, _to_32x32),
    )

    # Reuse a saved source model when one exists, otherwise pretrain one.
    snapshot_path = os.path.join(args.output, args.pretrained_source)
    if os.path.isfile(snapshot_path):
        source_cnn = Loss(num_classes=10)
        serializers.load_npz(snapshot_path, source_cnn)
    else:
        source_cnn = pretrain_source_cnn(source, args)

    # Baseline: how does the source-only model do on the target domain?
    test_pretrained_on_target(source_cnn, target, args)

    # Build a fresh target network (deliberately not source_cnn.copy) and
    # seed it with the source network's parameters.
    target_cnn = Loss(num_classes=10)
    target_cnn.copyparams(source_cnn)

    train_target_cnn(source, target, source_cnn, target_cnn, args)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号