image_caption.py file source (Python)


Project: Optimization_of_image_description_metrics_using_policy_gradient_methods    Author: chenxinpeng
def Monte_Carlo_and_Baseline(self):
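        """Build the sampling graph: greedy caption decoding, per-step Monte
        Carlo word samples used to estimate Q, and a learned per-step baseline
        computed from the LSTM state for policy-gradient training."""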
        images = tf.placeholder(tf.float32, [self.batch_size, self.feats_dim])
        images_embed = tf.matmul(images, self.encode_img_W) + self.encode_img_b

        state = self.lstm.zero_state(batch_size=self.batch_size, dtype=tf.float32)

        gen_sentences = []
        all_sample_sentences = []
        all_baselines = []

        with tf.variable_scope("LSTM"):
            output, state = self.lstm(images_embed, state)
            with tf.device("/cpu:0"):
                current_emb = tf.nn.embedding_lookup(self.Wemb, tf.ones([self.batch_size], dtype=tf.int64))

            for i in range(0, self.lstm_step):
                tf.get_variable_scope().reuse_variables()

                output, state = self.lstm(current_emb, state)
                logit_words = tf.matmul(output, self.embed_word_W) + self.embed_word_b
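                # greedy decoding: keep the most probable word at this step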
                max_prob_word = tf.argmax(logit_words, 1)
                with tf.device("/cpu:0"):
                    current_emb = tf.nn.embedding_lookup(self.Wemb, max_prob_word)
                    #current_emb = tf.expand_dims(current_emb, 0)
                gen_sentences.append(max_prob_word)

                # compute Q for g_t with K Monte Carlo rollouts over the remaining steps
                if i < self.lstm_step-1:
                    num_sample = self.lstm_step - 1 - i
                    sample_sentences = []
                    for idx_sample in range(num_sample):
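                        # draw 3 word indices per example from the distribution given by logit_words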
                        sample = tf.multinomial(logit_words, 3)
                        sample_sentences.append(sample)
                    all_sample_sentences.append(sample_sentences)
                # compute estimated baseline from the LSTM state with a one-layer ReLU MLP
                baseline = tf.nn.relu(tf.matmul(state[1], self.baseline_MLP_W) + self.baseline_MLP_b)
                all_baselines.append(baseline)

        return images, gen_sentences, all_sample_sentences, all_baselines
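
Below is a minimal usage sketch (not part of the original file) showing how the graph built by Monte_Carlo_and_Baseline might be driven in a TensorFlow 1.x session. The class name Image_Caption, its construction, and the random feats array are assumptions for illustration only; the real project wires these tensors into its policy-gradient training loop.

import numpy as np
import tensorflow as tf

# Hypothetical: construct the captioning model defined in this file and build its sampling graph.
model = Image_Caption()   # constructor arguments omitted; assumed to set batch_size, feats_dim, lstm_step, etc.
images, gen_sentences, all_sample_sentences, all_baselines = model.Monte_Carlo_and_Baseline()

# Stand-in CNN features; in the project these would come from a pretrained image encoder.
feats = np.random.rand(model.batch_size, model.feats_dim).astype(np.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    words, rollouts, baselines = sess.run(
        [gen_sentences, all_sample_sentences, all_baselines],
        feed_dict={images: feats})
    # words:     lstm_step arrays of greedily decoded word indices, one per time step
    # rollouts:  per-step lists of Monte Carlo word samples used to estimate Q
    # baselines: per-step baseline estimates to be subtracted from the reward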