select_item.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:chainer-deconv 作者: germanRos 项目源码 文件源码
def forward(self, inputs):
        x, t = inputs
        if chainer.is_debug():
            if not ((0 <= t).all() and
                    (t < x.shape[1]).all()):
                msg = 'Each label `t` need to satisfty `0 <= t < x.shape[1]`'
                raise ValueError(msg)

        xp = cuda.get_array_module(x)
        if xp is numpy:
            # This code is equivalent to `t.choose(x.T)`, but `numpy.choose`
            # does not work when `x.shape[1] > 32`.
            return x[six.moves.range(t.size), t],
        else:
            y = cuda.elementwise(
                'S t, raw T x',
                'T y',
                'int ind[] = {i, t}; y = x[ind];',
                'getitem_fwd'
            )(t, x)
            return y,
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号