train.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:chainer-LSGAN 作者: pfnet-research 项目源码 文件源码
def training_step(args, train_iter, noise_iter, opt_generator, opt_discriminator):
    """Run one least-squares GAN iteration: a discriminator update, then a generator update.

    Discriminator loss (real target 1, fake target 0)::

        Dloss = 0.5 * (sum((D(x) - 1)^2) + sum(D(G(z))^2)) / args.batchsize

    Generator loss (wants D to output 1 on fakes)::

        Gloss = 0.5 * sum((D(G(z')) - 1)^2) / args.batchsize

    A fresh noise batch z' is drawn for the generator step.

    Args:
        args: parsed options; this function reads ``device_id``, ``batchsize``
            and ``output``.
        train_iter: iterator over real training samples; its ``is_new_epoch``
            and ``epoch`` attributes drive the end-of-epoch logging.
        noise_iter: iterator yielding generator input noise.
        opt_generator: optimizer whose ``target`` is the generator link.
        opt_discriminator: optimizer whose ``target`` is the discriminator link.
    """
    # ---- discriminator step ----
    noise_samples = get_batch(noise_iter, args.device_id)

    # Take the raw array of the generator output rather than the Variable.
    # This cuts the computational graph here, so Dloss.backward() stops at the
    # discriminator input instead of back-propagating through the whole
    # generator. The discriminator update never needs generator gradients,
    # and this also lets the generator's graph be freed right away.
    generated = opt_generator.target(noise_samples).data

    train_samples = get_batch(train_iter, args.device_id)

    Dreal = opt_discriminator.target(train_samples)
    Dgen = opt_discriminator.target(generated)

    # Least-squares objective: push D(real) toward 1 and D(fake) toward 0.
    Dloss = 0.5 * (F.sum((Dreal - 1.0) ** 2) + F.sum(Dgen ** 2)) / args.batchsize
    update_model(opt_discriminator, Dloss)

    # ---- generator step ----
    # New noise; here the graph through the generator must be kept so that
    # Gloss can back-propagate into the generator's parameters.
    noise_samples = get_batch(noise_iter, args.device_id)
    generated = opt_generator.target(noise_samples)
    Gloss = 0.5 * F.sum((opt_discriminator.target(generated) - 1.0) ** 2) / args.batchsize
    update_model(opt_generator, Gloss)

    if train_iter.is_new_epoch:
        print("[{}] Discriminator loss: {} Generator loss: {}".format(train_iter.epoch, Dloss.data, Gloss.data))
        # NOTE(review): print_sample is defined elsewhere; presumably it renders
        # generator outputs for these noise samples to the given PNG path.
        # The noise is the generator-step batch, so it differs each epoch.
        print_sample(os.path.join(args.output, "epoch_{}.png".format(train_iter.epoch)), noise_samples, opt_generator)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号