particleGAN.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:crayimage 作者: yandexdataschool 项目源码 文件源码
def _train_procedures(self):
    """Compile the update functions for the generator and the discriminator.

    Sets the following attributes on ``self``:

    * ``learning_rate`` -- symbolic float scalar fed to both SGD steps.
    * ``grads_generator`` / ``grads_generator_clipped`` -- gradients of
      ``loss_generator`` w.r.t. ``params_generator``, raw and after
      ``updates.total_norm_constraint`` with ``max_norm=grad_clip_norm``.
    * ``train_generator`` -- compiled function taking
      ``(X_geant_raw, learning_rate)``, returning ``loss_pseudo`` and
      applying the generator SGD updates.
    * ``grads_discriminator`` / ``grads_discriminator_clipped`` -- the same
      pair for ``loss_discriminator`` and ``params_discriminator``.
    * ``train_discriminator`` -- compiled function taking
      ``(X_geant_raw, X_real_raw, learning_rate)``, returning
      ``[loss_pseudo, loss_real]`` and applying the discriminator SGD updates.
    * ``anneal_discriminator`` -- whatever ``nn.updates.sa`` returns for the
      discriminator loss and parameters, configured by ``annealing_args``.

    NOTE(review): ``updates`` looks like ``lasagne.updates`` and ``sa`` is
    presumably a simulated-annealing procedure -- confirm against the
    module imports, which are not visible here.
    """
    lr = T.fscalar('learning rate')
    self.learning_rate = lr

    # ---- generator: gradient -> global-norm clipping -> plain SGD ----
    gen_params = self.params_generator
    gen_grads = theano.grad(self.loss_generator, gen_params)
    self.grads_generator = gen_grads

    gen_grads_clipped = updates.total_norm_constraint(
        gen_grads, max_norm=self.grad_clip_norm
    )
    self.grads_generator_clipped = gen_grads_clipped

    gen_step = updates.sgd(gen_grads_clipped, gen_params, learning_rate=lr)

    gen_inputs = [self.X_geant_raw, lr]
    self.train_generator = theano.function(
        gen_inputs, self.loss_pseudo, updates=gen_step
    )

    # ---- discriminator: same recipe on its own loss and parameters ----
    disc_params = self.params_discriminator
    disc_grads = theano.grad(self.loss_discriminator, disc_params)
    self.grads_discriminator = disc_grads

    disc_grads_clipped = updates.total_norm_constraint(
        disc_grads, max_norm=self.grad_clip_norm
    )
    self.grads_discriminator_clipped = disc_grads_clipped

    disc_step = updates.sgd(disc_grads_clipped, disc_params, learning_rate=lr)

    disc_inputs = [self.X_geant_raw, self.X_real_raw, lr]
    disc_outputs = [self.loss_pseudo, self.loss_real]
    self.train_discriminator = theano.function(
        disc_inputs, disc_outputs, updates=disc_step
    )

    # ---- alternative discriminator procedure; takes no learning rate ----
    anneal_inputs = [self.X_geant_raw, self.X_real_raw]
    self.anneal_discriminator = nn.updates.sa(
        anneal_inputs, self.loss_discriminator,
        params=self.params_discriminator,
        **self.annealing_args
    )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号