def baseline_search(self, input, beam_size=None):
    """Greedy decoding of token sequences from a batch of image features.

    Initialises the hidden state by feeding the image features through
    ``self.lstm_im``, then repeatedly feeds the argmax token back into
    ``self.lstm_word`` until every sequence has emitted ``self.end`` or
    ``self.nseq`` steps have elapsed.

    Args:
        input: 2-D tensor of image features, shape (batch, feat_dim).
        beam_size: accepted for signature compatibility with a beam-search
            variant; ignored here — this is pure greedy (beam of 1) search.

    Returns:
        CPU LongTensor of generated token ids, shape (batch, steps).
        NOTE(review): the ``.squeeze()`` below collapses the batch
        dimension when batch == 1 — confirm callers handle that case.
    """
    batch_size = input.size(0)
    # Initial hidden state from the image; lstm_im expects (seq=1, batch, feat).
    hidden_feat = self.lstm_im(input.view(1, input.size(0), input.size(1)))[1]
    # Every sequence starts with the <start> token.
    x = Variable(torch.ones(1, batch_size,).type(torch.LongTensor) * self.start,
                 requires_grad=False).cuda()  # <start>
    output = []
    # flag[i] stays 1 until sequence i produces <end>; used for early exit.
    flag = torch.ones(batch_size)
    for i in range(self.nseq):
        input_x = self.encoder(x.view(1, -1))
        output_feature, hidden_feat = self.lstm_word(input_x, hidden_feat)
        output_t = self.decoder(output_feature.view(-1, output_feature.size(2)))
        # Explicit dim=1: the implicit default is deprecated and
        # version-dependent; normalisation must run over the vocab axis,
        # otherwise logprob (and potentially the argmax) would be wrong.
        output_t = F.log_softmax(output_t, dim=1)
        logprob, x = output_t.max(1)
        output.append(x)
        # Mark sequences that just emitted <end> as finished.
        flag[x.cpu().eq(self.end).data] = 0
        if flag.sum() == 0:
            break
    # (steps, batch[, 1]) -> (batch, steps), moved to CPU as a plain tensor.
    output = torch.stack(output, 0).squeeze().transpose(0, 1).cpu().data
    return output
# (Removed non-code residue from web scrape: "评论列表" / "文章目录" —
#  i.e. "comment list" / "article table of contents" from the source blog page.)