def forward(self, inputs):
    """Gather rows of ``W`` selected by the integer IDs in ``x``.

    Args:
        inputs: A pair ``(x, W)``. ``x`` holds the IDs used to index the
            first axis of ``W``. ``W`` is the embedding matrix (presumably
            2-D ``(n_vocab, n_units)`` -- the ``mask[..., None]`` broadcast
            below relies on a single trailing axis; TODO confirm).

    Returns:
        A one-element tuple holding ``W.take(x, axis=0)``. When
        ``self.ignore_label`` is not ``None``, rows for entries of ``x``
        equal to ``ignore_label`` are filled with zeros instead.

    Raises:
        ValueError: Only when ``chainer.is_debug()`` is true, if any entry
            of ``x`` that is not ``ignore_label`` lies outside
            ``0 <= x < len(W)``.
    """
    x, W = inputs
    xp = cuda.get_array_module(*inputs)

    if chainer.is_debug():
        valid_x = xp.logical_and(0 <= x, x < len(W))
        if self.ignore_label is not None:
            # Ignored entries are never used as indices (they are remapped
            # to 0 below), so they are exempt from the range check.
            valid_x = xp.logical_or(valid_x, x == self.ignore_label)
        if not valid_x.all():
            msg = ('Each not ignored `x` value needs to satisfy '
                   '`0 <= x < len(W)`')
            raise ValueError(msg)

    if self.ignore_label is not None:
        mask = (x == self.ignore_label)
        # Replace ignored IDs with 0 so take() stays in bounds, then zero
        # out the rows gathered for those positions.
        return xp.where(
            mask[..., None], 0, W.take(xp.where(mask, 0, x), axis=0)),

    return W.take(x, axis=0),
评论列表
文章目录