def generate_caption(self, session, img_feature, toSample=False,
                     sos_id=3591, eos_id=3339, max_len=50):
    """Decode a caption for a single image, one token per model run.

    The image feature is written into row 0 of an otherwise all-zero batch.
    The model is run repeatedly; after each run the next token for every row
    is appended to the running sequences. Decoding stops when row 0 emits
    ``eos_id`` or when ``max_len`` tokens have been generated.

    Args:
        session: object whose ``run(fetches, feed_dict=...)`` evaluates the
            model graph (presumably a TensorFlow session).
        img_feature: feature vector for one image; assigned to row 0 of a
            ``(batch_size, img_dim)`` array.
        toSample: if True, pick each next token with ``sample`` applied to the
            model's logits; otherwise take the model's ``_predictions``.
        sos_id: token id every sequence starts with (``<SOS>``).
        eos_id: token id that ends decoding (presumably ``<EOS>``).
        max_len: maximum number of tokens to generate.

    Returns:
        Space-joined ``self.index2token`` strings for the tokens generated in
        row 0, excluding the leading ``sos_id`` and a trailing ``eos_id``.
    """
    batch_size = self.config.batch_size
    dp = 1  # value fed to the dropout placeholder during decoding
    img_template = np.zeros([batch_size, self.config.img_dim])
    img_template[0, :] = img_feature  # only row 0 carries a real image
    # Integer ids, so tokens can be used directly as index2token keys/indices.
    sent_pred = np.full([batch_size, 1], sos_id, dtype=np.int64)
    dummy_targets = np.ones([batch_size, 1])  # targets are unused here

    while sent_pred[0, -1] != eos_id and sent_pred.shape[1] - 1 < max_len:
        feed = {self._sent_placeholder: sent_pred,
                self._img_placeholder: img_template,
                self._targets_placeholder: dummy_targets,
                self._dropout_placeholder: dp}
        # Flattened position of the last output step of each row. This assumes
        # the model emits sent_pred.shape[1] + 1 steps per row (e.g. the image
        # fed as an extra first step) -- TODO confirm against the model graph.
        steps_per_row = sent_pred.shape[1] + 1
        idx_next_pred = np.arange(1, batch_size + 1) * steps_per_row - 1
        if toSample:
            logits = session.run(self.logits, feed_dict=feed)
            raw_predicted = np.array(
                [sample(row) for row in logits[idx_next_pred, :]])
        else:
            raw_predicted = session.run(self._predictions, feed_dict=feed)
            raw_predicted = raw_predicted[idx_next_pred]
        next_pred = np.reshape(raw_predicted, (batch_size, 1))
        sent_pred = np.concatenate([sent_pred, next_pred], 1)

    tokens = sent_pred[0, 1:]
    # Strip the stop token only if decoding actually ended on it; when the
    # max_len cap is reached, the final token is a real word and is kept.
    if tokens.size and tokens[-1] == eos_id:
        tokens = tokens[:-1]
    return ' '.join(self.index2token[int(idx)] for idx in tokens)
# Scraped web-page residue, not code: "评论列表" ("Comment list"), "文章目录" ("Table of contents").