def _create(self):
    """Build and return the (generator, discriminator) training ops.

    NOTE(review): `gan`, `config`, `np`, `tf`, `inspect`, `capped_optimizer`,
    `d_vars` and `g_vars` are free names here — presumably module-level
    globals or injected by the surrounding framework; confirm against the
    rest of the file.

    Returns:
        tuple: ``(g_optimizer, d_optimizer)`` — either capped-gradient
        apply ops or plain ``minimize`` ops, depending on
        ``config.clipped_gradients``.
    """
    d_loss = gan.graph.d_loss
    g_loss = gan.graph.g_loss
    # Learning rates cast to float32 for the TF optimizers.
    g_lr = np.float32(config.g_learn_rate)
    d_lr = np.float32(config.d_learn_rate)
    gan.graph.d_vars = d_vars
    # Forward any "g_*" / "d_*" config entries that match keyword arguments
    # of the corresponding trainer constructor (e.g. "g_beta1" -> beta1).
    # BUGFIX: the prefixes were swapped — g_defk filtered "d_"-prefixed keys
    # against g_trainer's argspec (and d_defk "g_" against d_trainer), so the
    # intersection was empty/wrong and trainer kwargs were never passed.
    # NOTE(review): inspect.getargspec is deprecated (removed in Python 3.11);
    # consider inspect.getfullargspec when the project drops old interpreters.
    g_defk = {k[2:]: v for k, v in config.items()
              if k.startswith("g_") and k[2:] in inspect.getargspec(config.g_trainer).args}
    d_defk = {k[2:]: v for k, v in config.items()
              if k.startswith("d_") and k[2:] in inspect.getargspec(config.d_trainer).args}
    g_optimizer = config.g_trainer(g_lr, **g_defk)
    d_optimizer = config.d_trainer(d_lr, **d_defk)
    if config.clipped_gradients:
        g_optimizer = capped_optimizer(g_optimizer, config.clipped_gradients, g_loss, g_vars)
        d_optimizer = capped_optimizer(d_optimizer, config.clipped_gradients, d_loss, d_vars)
    else:
        g_optimizer = g_optimizer.minimize(g_loss, var_list=g_vars)
        d_optimizer = d_optimizer.minimize(d_loss, var_list=d_vars)
    # WGAN-style discriminator weight clipping. ROBUSTNESS: guard against a
    # missing/None setting, which previously crashed on the unary minus in
    # `-config.d_clipped_weights`.
    if config.d_clipped_weights:
        gan.graph.clip = [
            tf.assign(d, tf.clip_by_value(d,
                                          -config.d_clipped_weights,
                                          config.d_clipped_weights))
            for d in d_vars
        ]
    return g_optimizer, d_optimizer
# (page-scrape artifacts removed from code path: "评论列表" = comment list,
#  "文章目录" = article table of contents)