embed_id.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:chainer-deconv 作者: germanRos 项目源码 文件源码
def forward(self, inputs):
        x, W = inputs

        if chainer.is_debug():
            if not ((0 <= x).all() and
                    (x < len(W)).all()):
                msg = 'Each `x` value need to satisfty `0 <= x < len(W)`'
                raise ValueError(msg)

        if self.ignore_label is not None:
            xp = cuda.get_array_module(*inputs)
            mask = (x == self.ignore_label)
            return xp.where(
                mask[..., None], 0, W.take(xp.where(mask, 0, x), axis=0)),

        return W.take(x, axis=0),
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号