def forward(self, buffers, transitions):
    """Run a batched shift-reduce parse and return the final stack tops.

    Each example in the batch gets its own buffer (a Python list popped from
    the end) and its own stack.  At every step each example either shifts
    the last buffer entry onto its stack, reduces the top two stack entries
    via ``self.reduce``, or does nothing (any other transition code).

    Args:
        buffers: tensor whose dim 1 indexes batch examples and whose dim 0
            indexes buffer positions (established by the split/squeeze
            below).  Entries are popped from the END of dim 0, so the caller
            presumably supplies each sequence reversed -- TODO confirm.
        transitions: optional tensor indexed first by step
            (``transitions.size(0)`` steps, ``transitions[i]`` per step) and
            then by example; code 3 = shift, 2 = reduce.  May be None only
            when the module has a ``tracker`` that predicts transitions.

    Returns:
        ``bundle(...)[0]`` applied to the top entry of every stack --
        presumably the hidden-state component (``bundle`` is defined
        elsewhere in this file).

    Raises:
        ValueError: if neither ``transitions`` nor a tracker is available, or
            if at some step no transitions were given and the tracker
            produced no hypothesis.
    """
    # Transition codes, as annotated in the shift/reduce branches below.
    shift_code, reduce_code = 3, 2

    # One independent, poppable list of (1, ...) slices per example.
    buffers = [list(torch.split(b.squeeze(1), 1, 0))
               for b in torch.split(buffers, 1, 1)]
    # Seed every stack with two copies of the buffer's first entry
    # (presumably a null/padding value -- confirm with the caller) so that
    # the tracker and the first reduce always have something to read.
    stacks = [[buf[0], buf[0]] for buf in buffers]

    # Look the tracker up once; an attribute explicitly set to None is
    # treated the same as a missing one.
    tracker = getattr(self, 'tracker', None)
    if tracker is not None:
        tracker.reset_state()
    elif transitions is None:
        # Explicit raise instead of `assert`, which `python -O` strips.
        raise ValueError(
            "transitions must be provided when the model has no tracker")

    if transitions is not None:
        num_transitions = transitions.size(0)
    else:
        # Fallback step count derived from the buffer length.
        num_transitions = len(buffers[0]) * 2 - 3

    for i in range(num_transitions):
        # Reset every step so a missing prediction can never silently reuse
        # the previous step's transitions.
        trans = transitions[i] if transitions is not None else None
        if tracker is not None:
            tracker_states, trans_hyp = tracker(buffers, stacks)
            if trans_hyp is not None:
                # The tracker's argmax over dim 1 overrides any supplied
                # transitions for this step.
                trans = trans_hyp.max(1)[1]
        else:
            tracker_states = itertools.repeat(None)
        if trans is None:
            raise ValueError(
                "no transitions available at step %d: transitions is None "
                "and the tracker returned no hypothesis" % i)

        lefts, rights, trackings = [], [], []
        for transition, buf, stack, tracking in zip(
                trans.data, buffers, stacks, tracker_states):
            if transition == shift_code:
                stack.append(buf.pop())
            elif transition == reduce_code:
                rights.append(stack.pop())
                lefts.append(stack.pop())
                trackings.append(tracking)

        if rights:
            # All reductions for this step are composed in one batched call.
            # The results are handed back one per reducing example, in batch
            # order, so self.reduce is assumed to preserve input order.
            reduced = iter(self.reduce(lefts, rights, trackings))
            for transition, stack in zip(trans.data, stacks):
                if transition == reduce_code:
                    stack.append(next(reduced))

    return bundle([stack.pop() for stack in stacks])[0]
评论列表
文章目录