Beam.py 文件源码

python
阅读 41 收藏 0 点赞 0 评论 0

项目:alpha-dimt-icmlws 作者: sotetsuk 项目源码 文件源码
def __init__(self, size, cuda=False):

        self.size = size
        self.done = False

        self.tt = torch.cuda if cuda else torch

        # The score for each translation on the beam.
        self.scores = self.tt.FloatTensor(size).zero_()

        # The backpointers at each time-step.
        self.prevKs = []

        # The outputs at each time-step.
        self.nextYs = [self.tt.LongTensor(size).fill_(onmt.Constants.PAD)]
        self.nextYs[0][0] = onmt.Constants.BOS

        # The attentions (matrix) for each time.
        self.attn = []

    # Get the outputs for the current timestep.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号