train.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目: chainer-LSGAN 作者: pfnet-research 项目源码 文件源码
def main(args):
    """Train an LSGAN generator/discriminator pair on MNIST or CIFAR-10.

    Args:
        args: Parsed command-line options. This function reads
            ``device_id`` (GPU id; a negative value means CPU mode),
            ``mnist`` (True -> MNIST with ``GeneratorMNIST``, otherwise
            CIFAR-10 with ``GeneratorCIFAR``), ``batchsize``,
            ``learning_rate`` (RMSprop learning rate used for both networks),
            ``num_z`` (passed to the noise generator), ``output`` (output
            directory, created if missing) and ``num_epochs`` (number of
            passes over the training set). ``args`` is also forwarded
            unchanged to ``training_step``.
    """
    use_gpu = args.device_id >= 0

    # If GPU mode is enabled, make the selected device the current one.
    if use_gpu:
        chainer.cuda.get_device(args.device_id).use()

    # Load the dataset; only the training split is used, so the test split
    # is discarded.
    if args.mnist:
        train, _ = chainer.datasets.get_mnist(withlabel=False, scale=2, ndim=3)
        generator = GeneratorMNIST()
    else:
        train, _ = chainer.datasets.get_cifar10(withlabel=False, scale=2, ndim=3)
        generator = GeneratorCIFAR()

    # scale=2 above puts pixels in [0, 2]; subtracting 1 shifts them to [-1, 1].
    train -= 1.0

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    discriminator = Discriminator()

    # Move the models to the GPU *before* attaching the optimizers. Chainer
    # versions that initialize optimizer state inside setup() would
    # otherwise keep that state on the CPU while the parameters live on the GPU.
    if use_gpu:
        generator.to_gpu()
        discriminator.to_gpu()

    opt_generator = chainer.optimizers.RMSprop(lr=args.learning_rate)
    opt_discriminator = chainer.optimizers.RMSprop(lr=args.learning_rate)
    opt_generator.setup(generator)
    opt_discriminator.setup(discriminator)

    # Random noise iterator: uniform noise between -1 and 1, batched by
    # args.batchsize (presumably num_z is the latent dimension — confirm
    # against UniformNoiseGenerator).
    noise_iter = RandomNoiseIterator(UniformNoiseGenerator(-1, 1, args.num_z), args.batchsize)

    # exist_ok=True already tolerates an existing directory, so no separate
    # existence check is needed.
    os.makedirs(args.output, exist_ok=True)

    print("Starting training loop...")

    while train_iter.epoch < args.num_epochs:
        training_step(args, train_iter, noise_iter, opt_generator, opt_discriminator)

    print("Finished training.")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号