def __call__(self, xs):
    """Compute the joint tagging + parsing training loss for a batch.

    Args:
        xs: sequence of per-sentence tuples, either
            ``(ws, ss, ps, ls, cat_ts, dep_ts)`` or
            ``(ws, ss, ps, ls, cat_ts, dep_ts, weight)``.
            ``ws, ss, ps, ls`` are forwarded unchanged to ``self.forward``
            (presumably word ids, two further input features, and sentence
            lengths -- TODO confirm against the data loader).
            ``cat_ts`` are the tagging targets and ``dep_ts`` the parsing
            targets. When the 6-tuple form is used, every sentence gets a
            weight of 1.

    Returns:
        ``cat_loss + dep_loss``: the per-sentence cross-entropy losses,
        each scaled by the sentence weight and summed over the batch
        (not averaged). Batch-averaged accuracies are only reported via
        ``chainer.report``.
    """
    batchsize = len(xs)

    def _total(values):
        # Fold with ``+`` so that chainer Variables stay in the graph.
        return reduce(lambda acc, v: acc + v, values)

    if len(xs[0]) == 6:
        ws, ss, ps, ls, cat_ts, dep_ts = zip(*xs)
        xp = chainer.cuda.get_array_module(ws[0])
        # Unweighted data: weight 1 (float32) on the same device as the input.
        weights = [xp.array(1, 'f') for _ in xs]
    else:
        ws, ss, ps, ls, cat_ts, dep_ts, weights = zip(*xs)

    # Gold dependencies are handed to forward only while training
    # (presumably for teacher forcing -- confirm in forward()).
    cat_ys, dep_ys = self.forward(
        ws, ss, ps, ls, dep_ts if self.train else None)

    # Pass IGNORE to the loss as well as to the accuracy so that targets
    # labelled IGNORE are excluded consistently by both (the loss would
    # otherwise use softmax_cross_entropy's own default ignore label).
    cat_loss = _total([we * F.softmax_cross_entropy(y, t, ignore_label=IGNORE)
                       for y, t, we in zip(cat_ys, cat_ts, weights)])
    cat_acc = _total([F.accuracy(y, t, ignore_label=IGNORE)
                      for y, t in zip(cat_ys, cat_ts)]) / batchsize

    dep_loss = _total([we * F.softmax_cross_entropy(y, t, ignore_label=IGNORE)
                       for y, t, we in zip(dep_ys, dep_ts, weights)])
    dep_acc = _total([F.accuracy(y, t, ignore_label=IGNORE)
                      for y, t in zip(dep_ys, dep_ts)]) / batchsize

    chainer.report({
        "tagging_loss": cat_loss,
        "tagging_accuracy": cat_acc,
        "parsing_loss": dep_loss,
        "parsing_accuracy": dep_acc
    }, self)
    return cat_loss + dep_loss
评论列表
文章目录