def _calculate_loss(self, sent):
    """Sum ``self.loss_func`` over every position of a batch of sentences.

    Args:
        sent: A batch of sentences, converted with
            ``self.xp.asarray(sent, dtype=np.int32)``. The resulting array
            must be at least 2-D, since it is indexed column-wise
            (``sent_arr[:, i]`` for ``i`` in ``range(sent_arr.shape[1])``).
            Presumably (batch, seq_len) token ids -- TODO confirm against
            callers.

    Returns:
        The sum of ``self.loss_func(y, x)`` over each pair formed by zipping
        the output of ``self._contexts_rep(sent_arr)`` with the per-column
        ``chainer.Variable`` objects. Returns ``None`` when no pairs are
        formed, e.g. when ``sent_arr`` has zero columns.
    """
    # sent is a batch of sentences.
    sent_arr = self.xp.asarray(sent, dtype=np.int32)
    sent_y = self._contexts_rep(sent_arr)
    # One Variable per column of sent_arr, kept in column order.
    sent_x = [chainer.Variable(sent_arr[:, i])
              for i in range(sent_arr.shape[1])]
    accum_loss = None
    # Builtin zip (instead of the Python-2-only itertools.izip) works on both
    # Python 2 and 3; like izip, it stops at the shorter of the two inputs.
    for y, x in zip(sent_y, sent_x):
        loss = self.loss_func(y, x)
        # The first loss seeds the accumulator; later losses are added to it.
        accum_loss = loss if accum_loss is None else accum_loss + loss
    return accum_loss
评论列表
文章目录