def _choose(self, lang_hs=None, words=None, sample=False):
    """Pick one of the currently valid choices using the model's choice logits.

    Every candidate returned by ``self.domain.generate_choices`` is scored by
    summing, over the ``self.domain.selection_length()`` positions, the logit
    that the model assigns to the candidate's item at that position. A
    softmax over those scores gives the distribution the choice is taken
    from.

    Args:
        lang_hs: hidden states to score. When ``None``, ``self.lang_hs`` is
            concatenated with ``torch.cat`` and used instead.
        words: words to score. When ``None``, ``self.words`` is concatenated
            with ``torch.cat`` and used instead.
        sample: when true, draw the choice from the distribution; otherwise
            take the candidate with the highest probability.

    Returns:
        A tuple ``(choice, logprob, p_agree)`` where ``choice`` is the first
        ``selection_length()`` entries of the selected candidate,
        ``logprob`` is the log-probability gathered at the sampled index
        (``None`` when ``sample`` is false), and ``p_agree`` is the
        probability of the selected candidate read out through ``.data[0]``.
    """
    # Enumerate every candidate the domain allows for this context.
    candidates = self.domain.generate_choices(self.context)

    # Fall back to the accumulated hidden states / words when none are given.
    if lang_hs is None:
        lang_hs = torch.cat(self.lang_hs)
    if words is None:
        words = torch.cat(self.words)

    # Per-position logits over items.
    logits = self.model.generate_choice_logits(words, lang_hs, self.ctx_h)

    def _position_scores(pos):
        # Logit of each candidate's item at position ``pos``, as a column.
        item_ids = [self.model.item_dict.get_idx(cand[pos]) for cand in candidates]
        item_ids = Variable(torch.from_numpy(np.array(item_ids)))
        item_ids = self.model.to_device(item_ids)
        return torch.gather(logits[pos], 0, item_ids).unsqueeze(1)

    # Restrict the distribution to the valid candidates only: one column per
    # position, then sum across positions to get one score per candidate.
    columns = [_position_scores(pos) for pos in range(self.domain.selection_length())]
    stacked = torch.cat(columns, 1)
    scores = torch.sum(stacked, 1, keepdim=False)

    # Shift by the maximum so the softmax is numerically more stable.
    peak = scores.max().data[0]
    scores = scores.sub(peak)
    probs = F.softmax(scores)

    if not sample:
        # Greedy: take the most probable candidate; no log-probability.
        _, picked = probs.max(0, keepdim=True)
        logprob = None
    else:
        # Stochastic: draw a candidate and keep its log-probability.
        picked = probs.multinomial().detach()
        logprob = F.log_softmax(scores).gather(0, picked)

    picked_idx = picked.data[0]
    p_agree = probs[picked_idx]

    # Keep only this agent's part of the selected candidate.
    own_choice = candidates[picked_idx][:self.domain.selection_length()]
    return own_choice, logprob, p_agree.data[0]
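# NOTE(review): the two lines here were web-page navigation residue
# ("Comment list", "Article table of contents"), not part of the code.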