model.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目: GAN-Sentence 作者: huseinzol05 项目源码 文件源码
def __init__(self, num_layers, size_layer, dimension_input, len_noise, sequence_size, learning_rate):
        """Build the TF1 graph for the sentence GAN.

        Creates three float32 rank-3 placeholders (``noise``, ``fake_input``,
        ``true_sentence``). It then wires the generator (``generator_encode``
        feeding ``generator_sentence``) and applies ``discriminator`` to both
        the generated output and the real sentences. Finally it defines sigmoid
        cross-entropy GAN losses and one Adam training op per network.

        Args:
            num_layers: Forwarded to the generator and discriminator builders.
            size_layer: Forwarded to the generator and discriminator builders.
            dimension_input: Last dimension of ``fake_input`` and
                ``true_sentence``; also forwarded to the builders.
            len_noise: Last dimension of the ``noise`` placeholder; also
                forwarded to ``generator_encode``.
            sequence_size: Accepted but not used by this constructor.
            learning_rate: Learning rate shared by both Adam optimizers
                (both use ``beta1=0.5``).
        """
        # NOTE: TF1 graph construction is order-sensitive (variable creation,
        # the reuse=True discriminator call, auto-generated op names), so the
        # statements below create ops in exactly the original order.

        # --- Inputs: all placeholders are [None, None, <last_dim>] float32.
        self.noise = tf.placeholder(tf.float32, shape=[None, None, len_noise])
        self.fake_input = tf.placeholder(tf.float32, shape=[None, None, dimension_input])
        self.true_sentence = tf.placeholder(tf.float32, shape=[None, None, dimension_input])

        # --- Generator: encode the noise, then decode conditioned on fake_input.
        self.initial_layer = generator_encode(self.noise, num_layers, size_layer, len_noise)
        self.final_outputs = generator_sentence(
            self.fake_input, self.initial_layer, num_layers, size_layer, dimension_input)

        # --- Discriminator: the first call is on generated output. The second
        # call, on real sentences, passes reuse=True, presumably so both calls
        # share one set of weights.
        logits_generated = discriminator(self.final_outputs, num_layers, size_layer, dimension_input)
        logits_real = discriminator(
            self.true_sentence, num_layers, size_layer, dimension_input, reuse=True)

        def mean_bce(logits, make_labels):
            # Mean sigmoid cross-entropy of `logits` against an all-ones or
            # all-zeros target built with the same shape as `logits`.
            targets = make_labels(logits)
            return tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=targets))

        # --- Losses: D labels real as 1 and generated as 0. G uses the
        # "generated -> 1" objective.
        loss_on_real = mean_bce(logits_real, tf.ones_like)
        loss_on_generated = mean_bce(logits_generated, tf.zeros_like)
        self.g_loss = mean_bce(logits_generated, tf.ones_like)

        self.d_loss = loss_on_real + loss_on_generated

        # --- Variables per network, collected by scope name before either
        # optimizer is built (minimize() adds Adam's own variables to the graph).
        # NOTE(review): GLOBAL_VARIABLES also picks up any non-trainable
        # variables under these scopes; TRAINABLE_VARIABLES may be what was
        # meant — confirm against the builders.
        def variables_in(scope_name):
            return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name)

        disc_params = variables_in('discriminator')
        gen_params = variables_in('generator_encode') + variables_in('generator_sentence')

        # --- Training ops: the discriminator optimizer is created first, then
        # the generator optimizer.
        def make_adam():
            return tf.train.AdamOptimizer(learning_rate, beta1=0.5)

        self.d_train_opt = make_adam().minimize(self.d_loss, var_list=disc_params)
        self.g_train_opt = make_adam().minimize(self.g_loss, var_list=gen_params)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号