def train(self):
    """Pre-train the discriminator D.

    Runs ``self.max_iter`` minimization steps of ``self.loss``, saving a
    checkpoint and evaluating on the 'test' and 'train' splits once every
    ``eval_interval`` steps. Each step feeds three batches:

      * right: real image with its matching caption,
      * fake:  sample from ``self.negative_dataset`` (generated data),
      * wrong: real image paired with a mismatched caption.

    Side effects: creates the train op, summary writer/op and saver on
    ``self``, initializes all TF variables, and writes summaries to
    ``./logs/D_pretrained``.
    """
    # Steps between checkpoint/evaluation passes. Previously the literal
    # 3000 appeared twice (outer range and inner range) and could drift
    # out of sync; keep it in one place.
    eval_interval = 3000

    self.train_op = self.optim.minimize(self.loss, global_step=self.global_step)
    # NOTE(review): pre-1.0 TensorFlow API (SummaryWriter /
    # merge_all_summaries / initialize_all_variables) — kept as-is to
    # match the rest of the project.
    self.writer = tf.train.SummaryWriter("./logs/D_pretrained", self.sess.graph)
    self.summary_op = tf.merge_all_summaries()
    tf.initialize_all_variables().run()
    self.saver = tf.train.Saver(var_list=self.D_params_dict, max_to_keep=self.max_to_keep)

    count = 0  # global step index used to tag checkpoints and summaries
    for _ in range(self.max_iter // eval_interval):
        # Snapshot and evaluate at the start of every chunk (including
        # step 0, before any training has happened).
        self.save(self.checkpoint_dir, count)
        self.evaluate('test', count)
        self.evaluate('train', count)
        for _ in tqdm(range(eval_interval)):
            # Matching (image, caption) pairs.
            right_images, right_text, _ = self.dataset.sequential_sample(self.batch_size)
            # (text != NOT) + 0 converts the bool mask to ints; the sum over
            # axis 1 is the caption length. Assumes self.NOT is the padding
            # token id — TODO confirm.
            right_length = np.sum((right_text != self.NOT) + 0, 1)
            # Generated (fake) samples.
            fake_images, fake_text, _ = self.negative_dataset.sequential_sample(self.batch_size)
            fake_length = np.sum((fake_text != self.NOT) + 0, 1)
            # Mismatched captions. The "wrong" branch deliberately pairs
            # them with the *real* images (right_images) below, i.e. real
            # image + wrong text.
            wrong_text = self.dataset.get_wrong_text(self.batch_size)
            wrong_length = np.sum((wrong_text != self.NOT) + 0, 1)
            feed_dict = {
                self.right_images: right_images, self.right_text: right_text, self.right_length: right_length,
                self.fake_images: fake_images, self.fake_text: fake_text, self.fake_length: fake_length,
                self.wrong_images: right_images, self.wrong_text: wrong_text, self.wrong_length: wrong_length,
            }
            _, loss, summary_str = self.sess.run(
                [self.train_op, self.loss, self.summary_op], feed_dict)
            self.writer.add_summary(summary_str, count)
            count += 1
# Source file: pretrain_LSTM_D.py (code-sharing page metadata removed —
# view/bookmark/like/comment counters were scrape residue, not code).