model.py source code

python

Project: CNN-LSTM-Caption-Generator   Author: mosessoh
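import time

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API (placeholders, session.run)


# Method of the model class. train_data_iterator is a batching helper
# defined elsewhere in the project (not shown here).
def run_epoch(self, session, train_op):
    # Arguments shared by both passes over the training data.
    data_args = (self.train_captions, self.train_caption_id2sentence,
                 self.train_caption_id2image_id, self.train_image_id2feature,
                 self.config)
    # First pass only counts batches so progress can be printed as step/total.
    total_steps = sum(1 for _ in train_data_iterator(*data_args))
    total_loss = []
    # Without a training op, run a no-op so the epoch only computes the loss.
    if not train_op:
        train_op = tf.no_op()
    start = time.time()

    for step, (sentences, images, targets) in enumerate(train_data_iterator(*data_args)):
        # keep_prob is fed on every call, including loss-only epochs.
        feed = {self._sent_placeholder: sentences,
                self._img_placeholder: images,
                self._targets_placeholder: targets,
                self._dropout_placeholder: self.config.keep_prob}
        loss, _ = session.run([self.loss, train_op], feed_dict=feed)
        total_loss.append(loss)

        # Report the running mean loss every 50 batches.
        if step % 50 == 0:
            print('%d/%d: loss = %.2f, time elapsed = %ds'
                  % (step, total_steps, np.mean(total_loss), time.time() - start))

    print('Total time: %ds' % (time.time() - start))
    return total_loss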
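A framework-free sketch of the same epoch pattern: count the batches, run one step per batch, keep every batch loss, and print the running mean every `log_every` steps. `batches` is a hypothetical stand-in for the project's `train_data_iterator` that simply slices the data into fixed-size batches, and `step_fn` stands in for `session.run([self.loss, train_op], ...)`, returning the batch loss whether or not it updates anything. Unlike `run_epoch`, which walks the iterator once just to count batches, this version derives the count with ceiling division; that shortcut only holds under the assumption of plain slicing, so an iterator that drops a ragged final batch would need `len(data) // batch_size` instead.

```python
import math
from typing import Callable, Iterator, List, Sequence


def batches(data: Sequence, batch_size: int) -> Iterator[Sequence]:
    """Hypothetical stand-in for train_data_iterator: yields consecutive slices."""
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]


def run_epoch(data: Sequence, batch_size: int,
              step_fn: Callable[[Sequence], float],
              log_every: int = 50) -> List[float]:
    """Run step_fn on every batch and return the per-batch losses.

    step_fn plays the role of session.run([loss, train_op]): it returns the
    batch loss and may or may not update parameters.
    """
    # With fixed-size slicing the batch count is known without a second pass.
    total_steps = math.ceil(len(data) / batch_size)
    losses: List[float] = []
    for step, batch in enumerate(batches(data, batch_size)):
        losses.append(step_fn(batch))
        if step % log_every == 0:
            running_mean = sum(losses) / len(losses)
            print('%d/%d: loss = %.2f' % (step, total_steps, running_mean))
    return losses


if __name__ == '__main__':
    # Toy "loss": the mean of the values in each batch.
    data = list(range(10))
    losses = run_epoch(data, batch_size=4,
                       step_fn=lambda b: sum(b) / len(b), log_every=1)
    print(losses)
```

With 10 items and `batch_size=4` the batches are `[0, 1, 2, 3]`, `[4, 5, 6, 7]` and `[8, 9]`, so the toy run reports `3` total steps and returns `[1.5, 5.5, 8.5]`.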
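The loop feeds `self.config.keep_prob` into `self._dropout_placeholder` on every call, including epochs where `train_op` is empty and only the loss is computed. Assuming that placeholder drives inverted dropout the way TensorFlow 1.x `tf.nn.dropout(x, keep_prob)` does (keep each value with probability `keep_prob` and scale survivors by `1 / keep_prob`), loss-only evaluation is usually run with `keep_prob = 1.0`. The pure-Python sketch below illustrates that behavior; `inverted_dropout` is a hypothetical helper, not part of the project.

```python
import random
from typing import List, Optional


def inverted_dropout(x: List[float], keep_prob: float,
                     rng: Optional[random.Random] = None) -> List[float]:
    """Keep each value with probability keep_prob, scaled by 1 / keep_prob; zero the rest.

    The scaling keeps each value's expected output equal to its input, so
    evaluation can simply use keep_prob = 1.0, which makes this the identity.
    """
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError('keep_prob must be in (0, 1]')
    if rng is None:
        rng = random.Random()
    # rng.random() lies in [0, 1), so it is below keep_prob with probability keep_prob.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in x]


if __name__ == '__main__':
    # Training-style call: some values become 0.0, the survivors are doubled.
    print(inverted_dropout([1.0, 2.0, 3.0, 4.0], keep_prob=0.5, rng=random.Random(0)))
    # Evaluation-style call: the output equals the input.
    print(inverted_dropout([1.0, 2.0, 3.0, 4.0], keep_prob=1.0))
```

Scaling at training time rather than at test time is what lets the evaluation pass reuse the same graph with only the fed `keep_prob` changed.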