pretrain_LSTM_D.py source file

python

Project: show-adapt-and-tell  Author: tsenghungchen
import numpy as np
import tensorflow as tf
from tqdm import tqdm

def train(self):
    # SummaryWriter, merge_all_summaries and initialize_all_variables are pre-1.0 TensorFlow names.
    self.train_op = self.optim.minimize(self.loss, global_step=self.global_step)
    self.writer = tf.train.SummaryWriter("./logs/D_pretrained", self.sess.graph)
    self.summary_op = tf.merge_all_summaries()
    tf.initialize_all_variables().run()
    self.saver = tf.train.Saver(var_list=self.D_params_dict, max_to_keep=self.max_to_keep)
    count = 0
    for idx in range(self.max_iter // 3000):
        # Save a checkpoint and evaluate before each block of 3000 training steps.
        self.save(self.checkpoint_dir, count)
        self.evaluate('test', count)
        self.evaluate('train', count)
        for k in tqdm(range(3000)):
            # "Right" batch: images and their text from the main dataset.
            right_images, right_text, _ = self.dataset.sequential_sample(self.batch_size)
            # Length of each text = number of tokens that differ from self.NOT.
            right_length = np.sum((right_text != self.NOT) + 0, 1)
            # "Fake" batch: sampled from the negative dataset.
            fake_images, fake_text, _ = self.negative_dataset.sequential_sample(self.batch_size)
            fake_length = np.sum((fake_text != self.NOT) + 0, 1)
            # "Wrong" batch: wrong text, fed together with the right images below.
            wrong_text = self.dataset.get_wrong_text(self.batch_size)
            wrong_length = np.sum((wrong_text != self.NOT) + 0, 1)
            feed_dict = {
                self.right_images: right_images, self.right_text: right_text, self.right_length: right_length,
                self.fake_images: fake_images, self.fake_text: fake_text, self.fake_length: fake_length,
                self.wrong_images: right_images, self.wrong_text: wrong_text, self.wrong_length: wrong_length,
            }
            _, loss, summary_str = self.sess.run([self.train_op, self.loss, self.summary_op], feed_dict)
            self.writer.add_summary(summary_str, count)
            count += 1
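
The `right_length`, `fake_length` and `wrong_length` values in `train` are all computed the same way: by counting, per row, the tokens that differ from `self.NOT`. The following is a minimal NumPy sketch of that one step, not the project's own code. The `sequence_lengths` helper, the id `0` used for `NOT`, and the toy batch are all assumptions made for illustration; the project defines its own value for `self.NOT`.

```python
import numpy as np

NOT = 0  # hypothetical filler token id; the project keeps its own value in self.NOT


def sequence_lengths(text, not_id=NOT):
    """Count, for each row of a (batch, max_len) id array, the tokens that are not `not_id`."""
    # (text != not_id) is a boolean mask; adding 0 casts it to integers,
    # mirroring the expression used in train().
    return np.sum((text != not_id) + 0, axis=1)


# Toy batch of three token-id sequences (made-up ids), filled with NOT on the right.
batch = np.array([
    [5, 9, 2, 0, 0],
    [7, 0, 0, 0, 0],
    [3, 4, 6, 8, 1],
])
lengths = sequence_lengths(batch)
print(lengths)  # [3 1 5]
```

Note that this expression counts every token that differs from `NOT`, wherever it sits in the row, so it equals the true sequence length only when `NOT` never occurs before the end of the text.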