spinn.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:PyTorchDemystified 作者: hhsecond 项目源码 文件源码
def forward(self, buffers, transitions):
        """Run the batched shift/reduce loop and return the final stack tops.

        ``buffers`` is split along dim 1 into one chunk per example, and each
        chunk is split along dim 0 into a Python list of single-row slices.
        Items are consumed from the *end* of each list via ``pop()``.
        (Presumably ``buffers`` is laid out as (seq_len, batch, hidden) --
        TODO confirm against the caller.)

        ``transitions`` is indexed by step along dim 0, and each step is
        zipped against the per-example buffers; the value 3 means shift and
        2 means reduce. It may be ``None`` only when the module has a
        ``tracker`` attribute, in which case the tracker's predicted
        transitions are used.

        Returns element ``[0]`` of ``bundle(...)`` applied to the item popped
        from the top of every example's stack. ``bundle`` is defined outside
        this block; its exact return structure is not visible here.
        """
        SHIFT, REDUCE = 3, 2

        # One list of per-position slices for every example in the batch.
        buffers = [list(torch.split(column.squeeze(1), 1, 0))
                   for column in torch.split(buffers, 1, 1)]
        # Each stack starts with two copies of the buffer's first entry.
        stacks = [[buf[0], buf[0]] for buf in buffers]

        has_tracker = hasattr(self, 'tracker')
        if has_tracker:
            self.tracker.reset_state()
        else:
            # Without a tracker nothing can predict transitions.
            assert transitions is not None

        if transitions is None:
            # Step count is derived from the first example's buffer length.
            num_transitions = len(buffers[0]) * 2 - 3
        else:
            num_transitions = transitions.size(0)

        for step in range(num_transitions):
            if transitions is not None:
                trans = transitions[step]
            if has_tracker:
                tracker_states, trans_hyp = self.tracker(buffers, stacks)
                if trans_hyp is not None:
                    # A tracker hypothesis overrides any supplied transition.
                    _, trans = trans_hyp.max(1)
                # NOTE(review): if transitions is None and the tracker gives
                # no hypothesis on the first step, `trans` is unbound below.
            else:
                tracker_states = itertools.repeat(None)

            lefts, rights, trackings = [], [], []
            per_example = zip(trans.data, buffers, stacks, tracker_states)
            for action, buf, stack, tracking in per_example:
                if action == SHIFT:
                    stack.append(buf.pop())
                elif action == REDUCE:
                    # The top of the stack is the right child, the next one
                    # down is the left child.
                    rights.append(stack.pop())
                    lefts.append(stack.pop())
                    trackings.append(tracking)

            if not rights:
                continue

            # Reduce all pending pairs in one call, then hand the results
            # back to the reducing examples in the same order.
            reduced = iter(self.reduce(lefts, rights, trackings))
            for action, stack in zip(trans.data, stacks):
                if action == REDUCE:
                    stack.append(next(reduced))

        return bundle([stack.pop() for stack in stacks])[0]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号