Language_Model.py source code (Python)


Project: MSDN    Author: yikang-li
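```python
import torch
import torch.nn.functional as F


def baseline_search(self, input, beam_size=None):
    # This is the simple greedy search (beam_size is accepted but unused).
    batch_size = input.size(0)
    # Run lstm_im over the input feature as a length-1 sequence and keep its
    # final (h, c) state to initialize the word LSTM.
    hidden_feat = self.lstm_im(input.view(1, batch_size, input.size(1)))[1]
    # Every sequence starts from the <start> token: shape (1, batch_size).
    x = torch.full((1, batch_size), self.start, dtype=torch.long, device=input.device)
    output = []
    # True once a sequence has emitted <end>.
    finished = torch.zeros(batch_size, dtype=torch.bool, device=input.device)
    for i in range(self.nseq):
        input_x = self.encoder(x.view(1, -1))  # embed the previous tokens
        output_feature, hidden_feat = self.lstm_word(input_x, hidden_feat)
        output_t = self.decoder(output_feature.view(-1, output_feature.size(2)))
        output_t = F.log_softmax(output_t, dim=1)  # log-probs over the vocabulary
        logprob, x = output_t.max(1)  # greedy pick; x has shape (batch_size,)
        output.append(x)
        finished |= x.eq(self.end)
        if finished.all():  # stop once every sequence has produced <end>
            break
    # (steps, batch) -> (batch, steps). No squeeze(): it would drop the batch
    # dimension when batch_size == 1 and make transpose(0, 1) fail.
    output = torch.stack(output, 0).transpose(0, 1).cpu()
    return output
```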
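The control flow of `baseline_search` (feed the argmax token back in, mark sequences that emit `<end>`, stop when all are done, then transpose to batch-major) can be checked without PyTorch. The sketch below is a hypothetical, stdlib-only illustration: `greedy_decode`, `toy_step`, and the bigram score tables are invented for this example, and the stateless `step_logits` callback stands in for `encoder` → `lstm_word` → `decoder` while deliberately dropping the LSTM hidden state. Because `log_softmax` is monotonic per row, taking the argmax of the raw scores picks the same token.

```python
from typing import Callable, Dict, List


def greedy_decode(step_logits: Callable[[List[int]], List[List[float]]],
                  batch_size: int, start: int, end: int,
                  max_len: int) -> List[List[int]]:
    """Greedy search in the style of baseline_search (hypothetical sketch)."""
    prev = [start] * batch_size       # current input token of each sequence
    finished = [False] * batch_size   # True once a sequence has emitted <end>
    steps: List[List[int]] = []       # one list of batch_size tokens per step
    for _ in range(max_len):
        logits = step_logits(prev)    # one row of vocabulary scores per sequence
        # argmax per row; log_softmax would not change which index wins
        prev = [max(range(len(row)), key=row.__getitem__) for row in logits]
        steps.append(prev)
        finished = [f or tok == end for f, tok in zip(finished, prev)]
        if all(finished):             # early stop, like flag.sum() == 0
            break
    # transpose (steps, batch) -> (batch, steps)
    return [list(seq) for seq in zip(*steps)]


# Hypothetical toy vocabulary: 0 = <start>, 1 = <end>, 2 and 3 = words.
# Each table maps the previous token to next-token scores.
BIGRAM_A: Dict[int, List[float]] = {
    0: [0.0, 0.1, 2.0, 0.5],   # after <start>, prefer 2
    2: [0.0, 0.2, 0.1, 1.5],   # after 2, prefer 3
    3: [0.0, 3.0, 0.4, 0.2],   # after 3, prefer <end>
    1: [0.0, 5.0, 0.0, 0.0],   # after <end>, keep emitting <end>
}
BIGRAM_B: Dict[int, List[float]] = {**BIGRAM_A, 0: [0.0, 0.1, 0.5, 2.0]}


def toy_step(prev: List[int]) -> List[List[float]]:
    # Sequence 0 uses table A, every other sequence uses table B.
    return [(BIGRAM_A if b == 0 else BIGRAM_B)[tok] for b, tok in enumerate(prev)]


print(greedy_decode(toy_step, batch_size=2, start=0, end=1, max_len=10))
# [[2, 3, 1], [3, 1, 1]]
```

The second sequence reaches `<end>` one step earlier but keeps being decoded until the first one finishes, which matches the original loop: the early exit only fires when every sequence in the batch is done.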