lstm_gan.py source code

python

Project: lstm_gan  Author: vangaa
import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph API (tf.placeholder)


def build_model(self):
    batch_size, input_noise_size, seq_size, vocab_size = \
        self.batch_size, self.input_noise_size, \
        self.seq_size, self.vocab_size

    # Identity matrix: looking up token ids in it yields one-hot vectors.
    embedding = tf.diag(np.ones((vocab_size, ), dtype=np.float32))
    self.embedding = embedding

    # Noise inputs: one for a full batch, one for sampling a single sentence.
    input_noise = tf.placeholder(tf.float32, [batch_size, input_noise_size])
    input_noise_one_sent = tf.placeholder(tf.float32, [1, input_noise_size])
    self.input_noise = input_noise
    self.input_noise_one_sent = input_noise_one_sent

    # Real sentences arrive as token ids and are turned into one-hot sequences.
    real_sent = tf.placeholder(tf.int32, [batch_size, seq_size])
    input_sentence = tf.nn.embedding_lookup(embedding, real_sent)
    self.real_sent = real_sent

    # The first call (is_train=True) is used only to obtain the generator's
    # variable list; the later calls pass reuse=True to share those weights.
    _, gen_vars = self.build_generator(input_noise, is_train=True)
    generated_sent, _ = self.build_generator(input_noise, reuse=True)
    sent_generator, _ = self.build_generator(input_noise_one_sent, reuse=True)
    self.gen_vars = gen_vars
    self.generated_sent = generated_sent
    self.sent_generator = sent_generator

    # Same pattern for the discriminator: one shared set of weights scores
    # both generated and real sentences.
    _, disc_vars = self.build_discriminator(input_sentence, is_train=True)
    desc_decision_fake, _ = self.build_discriminator(generated_sent, reuse=True)
    disc_decision_real, _ = self.build_discriminator(input_sentence, reuse=True)
    self.disc_vars = disc_vars
    self.desc_decision_fake = desc_decision_fake
    self.disc_decision_real = disc_decision_real

    # Generator cost falls as D(G(z)) -> 1; discriminator cost falls as
    # D(x) -> 1 and D(G(z)) -> 0.
    self.gen_cost = 1. - desc_decision_fake
    self.disc_cost = 1. - disc_decision_real * (1. - desc_decision_fake)
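To make two ideas in `build_model` concrete without TensorFlow, here is a minimal pure-Python sketch. It assumes that indexing rows of an identity matrix behaves like `tf.nn.embedding_lookup` on `tf.diag(np.ones(...))`, and it uses scalar discriminator outputs in [0, 1] as stand-ins for the per-example tensors `disc_decision_real` and `desc_decision_fake`. The helper names `identity`, `embedding_lookup`, and `gan_costs` are hypothetical and not part of the lstm_gan project. Note that the discriminator cost in this file is a product form, not the log-based loss of the original GAN formulation.

```python
from typing import List, Tuple


def identity(n: int) -> List[List[float]]:
    """Return an n x n identity matrix (the role tf.diag(np.ones(n)) plays above)."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def embedding_lookup(table: List[List[float]], ids: List[int]) -> List[List[float]]:
    """Pick one row of `table` per token id; with an identity table this is one-hot encoding."""
    return [table[i] for i in ids]


def gan_costs(d_real: float, d_fake: float) -> Tuple[float, float]:
    """Costs as written in build_model, for scalar discriminator outputs.

    gen_cost  = 1 - D(G(z))               -> small when the fake fools D
    disc_cost = 1 - D(x) * (1 - D(G(z)))  -> small when D(x) ~ 1 and D(G(z)) ~ 0
    """
    gen_cost = 1.0 - d_fake
    disc_cost = 1.0 - d_real * (1.0 - d_fake)
    return gen_cost, disc_cost


if __name__ == "__main__":
    # Token ids 2, 0, 3 with a vocabulary of size 4.
    print(embedding_lookup(identity(4), [2, 0, 3]))
    # [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]

    # A discriminator that is confident on real data and rejects the fake.
    print(gan_costs(d_real=0.9, d_fake=0.2))
```

A perfect discriminator (`d_real=1.0`, `d_fake=0.0`) drives `disc_cost` to 0 while leaving `gen_cost` at its maximum of 1, which is the adversarial tension the two costs encode.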