model.py source code

python

Project: CNN-LSTM-Caption-Generator  Author: mosessoh
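import numpy as np


# Method of the captioning model class in model.py. `sample` is a helper defined
# elsewhere in the project (not shown) that draws a token index from a row of logits.
def generate_caption(self, session, img_feature, toSample=False):
    """Generate a caption for one image feature vector.

    Only row 0 of the batch holds the real image; the other rows are zero padding
    so the inputs match the placeholders' fixed batch size.
    """
    SOS, EOS, MAX_LEN = 3591, 3339, 50  # <SOS> id, stop token (<EOS>) id, caption length cap
    dp = 1  # dropout keep probability: no dropout at inference

    img_template = np.zeros([self.config.batch_size, self.config.img_dim])
    img_template[0, :] = img_feature

    # Every sequence starts with <SOS>; integer ids so they can index index2token
    sent_pred = np.full([self.config.batch_size, 1], SOS, dtype=np.int32)
    while sent_pred[0, -1] != EOS and (sent_pred.shape[1] - 1) < MAX_LEN:
        feed = {self._sent_placeholder: sent_pred,
                self._img_placeholder: img_template,
                self._targets_placeholder: np.ones([self.config.batch_size, 1]),  # dummy targets
                self._dropout_placeholder: dp}

        # Outputs appear to be flattened to batch_size * (T + 1) rows (T tokens plus the
        # image step), so row (i + 1) * (T + 1) - 1 is the last step of sequence i.
        idx_next_pred = np.arange(1, self.config.batch_size + 1) * (sent_pred.shape[1] + 1) - 1

        if toSample:
            # Sample the next token from the logits at each sequence's last step
            logits = session.run(self.logits, feed_dict=feed)
            next_logits = logits[idx_next_pred, :]
            raw_predicted = np.array([sample(row) for row in next_logits])
        else:
            # Use the model's own next-token predictions
            raw_predicted = session.run(self._predictions, feed_dict=feed)
            raw_predicted = raw_predicted[idx_next_pred]

        next_pred = np.reshape(raw_predicted, (self.config.batch_size, 1)).astype(sent_pred.dtype)
        sent_pred = np.concatenate([sent_pred, next_pred], 1)

    # Drop <SOS>. Drop the last token only if it is <EOS>, because the loop can also
    # stop at MAX_LEN with a real word in the final position.
    tokens = sent_pred[0, 1:]
    if tokens.size and tokens[-1] == EOS:
        tokens = tokens[:-1]
    return ' '.join(self.index2token[int(idx)] for idx in tokens)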
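
The decoding loop in generate_caption can be reproduced without TensorFlow. The following is a minimal pure-Python sketch under stated assumptions, not the project's code. `step_logits` is a hypothetical stand-in for `session.run`. `sample` is an assumed implementation of the project's helper, which draws from a softmax over the logits; the real definition is not shown and may differ. `last_step_rows` mirrors the `idx_next_pred` arithmetic, assuming the model's outputs are flattened to `batch_size * (T + 1)` rows, with one extra step per sequence for the image.

```python
import math
import random


def sample(logits, temperature=1.0, rng=random):
    """Assumed helper: draw a token index from softmax(logits / temperature)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max so exp() cannot overflow
    weights = [math.exp(x - m) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def last_step_rows(batch_size, seq_len):
    """Row of each sequence's final step in outputs flattened to batch_size * (seq_len + 1) rows."""
    return [(i + 1) * (seq_len + 1) - 1 for i in range(batch_size)]


def decode(step_logits, sos, eos, max_len=50, to_sample=False, rng=random):
    """Greedy (argmax) or sampled decoding of a single sequence.

    step_logits(tokens) returns next-token logits; it is a hypothetical
    stand-in for running the model in a session.
    """
    tokens = [sos]
    while tokens[-1] != eos and len(tokens) - 1 < max_len:
        logits = step_logits(tokens)
        if to_sample:
            nxt = sample(logits, rng=rng)
        else:
            nxt = max(range(len(logits)), key=logits.__getitem__)  # argmax
        tokens.append(nxt)
    body = tokens[1:]  # drop <SOS>
    if body and body[-1] == eos:  # drop <EOS> only if the loop actually produced it
        body = body[:-1]
    return body


def toy_step(tokens):
    # Toy "model" over a 5-token vocabulary: strongly prefers the token after the last one
    nxt = (tokens[-1] + 1) % 5
    return [5.0 if i == nxt else 0.0 for i in range(5)]


print(decode(toy_step, sos=0, eos=4))  # [1, 2, 3]
# With no <EOS>, the loop stops at max_len and keeps the final real token:
print(decode(lambda t: [0.0, 5.0, 0.0, 0.0, 0.0], sos=0, eos=4, max_len=3))  # [1, 1, 1]
print(last_step_rows(batch_size=3, seq_len=2))  # [2, 5, 8]
```

Greedy decoding always takes the highest-scoring token, so it is deterministic for a fixed model. Passing `to_sample=True` gives up that determinism in exchange for more varied captions.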