disc_model.py source code

python

Project: TextGAN  Author: AustinStoneProjects
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.gradients, tf.train)


# Method of the discriminator model class defined in disc_model.py.
def attach_cost(self, gen_model):
    # TODO: Shouldn't dynamic RNN be used here?
    # output_text, states_text = rnn.rnn(cell, inputs, initial_state=self.initial_state)

    # Real text should be classified as 1. TF >= 1.0 requires named labels/logits.
    predicted_classes_text = self.discriminate_text(self.input_data_text)
    self.loss_text = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(predicted_classes_text), logits=predicted_classes_text))

    # Generated word vectors should be classified as 0.
    generated_wv = gen_model.generate()
    predicted_classes_wv = self.discriminate_wv(generated_wv)
    self.loss_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(predicted_classes_wv), logits=predicted_classes_wv))

    # Discriminator loss: equal-weight average of the two terms.
    self.loss = 0.5 * self.loss_gen + 0.5 * self.loss_text

    # Optimize only discriminator-owned variables. Filter before clipping so the
    # global norm is not inflated by gradients w.r.t. generator variables.
    d_vars = [v for v in tf.trainable_variables() if v.name.startswith('DISC')]
    grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, d_vars),
                                      self.args.grad_clip)
    optimizer = tf.train.AdamOptimizer(self.lr)
    self.train_op = optimizer.apply_gradients(list(zip(grads, d_vars)))
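For reference, the discriminator loss in `attach_cost` is binary cross-entropy on logits, with real text labelled 1 and generated word vectors labelled 0, averaged with equal weights. A minimal pure-Python sketch of that arithmetic follows, using the numerically stable formula that the TensorFlow documentation gives for `sigmoid_cross_entropy_with_logits`. The helper names `sigmoid_xent_with_logits` and `discriminator_loss` are hypothetical (not part of TextGAN), and plain lists of floats stand in for the logit tensors.

```python
import math


# Hypothetical helpers for illustration; not part of TextGAN or TensorFlow.
def sigmoid_xent_with_logits(logit, label):
    """Sigmoid cross-entropy for one logit, in the numerically stable form
    max(x, 0) - x * z + log(1 + exp(-|x|)) documented for TensorFlow's
    sigmoid_cross_entropy_with_logits."""
    x, z = logit, label
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


def discriminator_loss(real_logits, fake_logits):
    """0.5 * loss_gen + 0.5 * loss_text, with real text labelled 1
    and generated samples labelled 0 (each term is a batch mean)."""
    loss_text = sum(sigmoid_xent_with_logits(x, 1.0) for x in real_logits) / len(real_logits)
    loss_gen = sum(sigmoid_xent_with_logits(x, 0.0) for x in fake_logits) / len(fake_logits)
    return 0.5 * loss_gen + 0.5 * loss_text


print(discriminator_loss([0.0, 0.0], [0.0, 0.0]))  # log(2), about 0.6931
```

With every logit at 0 the discriminator is maximally uncertain and the loss equals log 2; strongly positive logits on real text and strongly negative logits on generated samples drive it toward 0.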
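In `attach_cost`, gradients are clipped with `tf.clip_by_global_norm` and restricted to variables whose names start with `DISC`. The order of those two steps matters: because `self.loss` also depends on `gen_model.generate()`, its gradients with respect to generator variables are presumably non-zero, and if they are included when the global norm is computed they shrink the discriminator's update. The sketch below reimplements the documented scaling rule, `t * clip_norm / max(global_norm, clip_norm)`, in plain Python to show the effect. The function name, the list-of-floats gradient representation, and the example numbers are illustrative assumptions.

```python
import math


# Hypothetical helper for illustration; not TextGAN or TensorFlow code.
def clip_by_global_norm(grads, clip_norm):
    """Scale every gradient by clip_norm / max(global_norm, clip_norm).

    Each gradient is a flat list of floats; None entries (variables the loss
    does not depend on) are passed through unchanged. Returns the clipped
    list and the global norm computed before clipping."""
    global_norm = math.sqrt(
        sum(g_i * g_i for g in grads if g is not None for g_i in g))
    scale = clip_norm / max(global_norm, clip_norm)
    clipped = [None if g is None else [g_i * scale for g_i in g] for g in grads]
    return clipped, global_norm


disc_grad = [3.0, 4.0]  # norm 5
gen_grad = [12.0]       # norm 12

# Clip discriminator and generator gradients together, then keep the first.
clipped_all, norm_all = clip_by_global_norm([disc_grad, gen_grad], 5.0)
# Keep only the discriminator gradient, then clip.
clipped_disc, norm_disc = clip_by_global_norm([disc_grad], 5.0)

print(norm_all, clipped_all[0])    # 13.0 and roughly [1.154, 1.538]
print(norm_disc, clipped_disc[0])  # 5.0 [3.0, 4.0]
```

Clipping the combined list gives a global norm of 13 and scales the discriminator gradient by 5/13, while clipping the discriminator gradient on its own leaves it unchanged because its norm is already within the limit.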