basic.py source code

python

Project: Theano-Deep-learning  Author: GeekLiB
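```python
import numpy
import theano


def perform(self, node, inp, outs):
    # Method of a Theano Op: computes the max of `x` over `axes` and the
    # flat index of that max within the reduced axes.
    x, axes = inp
    # Each output storage cell is a one-element list: [max], [argmax].
    max, max_idx = outs
    if axes is None:
        # Reduce over every axis.
        axes = tuple(range(x.ndim))
    else:
        axes = tuple(int(ax) for ax in axes)
    max[0] = theano._asarray(numpy.max(x, axes),
                             dtype=node.outputs[0].dtype)
    # NumPy's argmax accepts only a single axis, so work around it:
    # move the reduced axes to the end, flatten them into one axis,
    # and take argmax over that last axis.
    keep_axes = numpy.array([i for i in range(x.ndim) if i not in axes],
                            dtype='int64')
    # Not-reduced axes in front, reduced axes at the back. Converting
    # `axes` explicitly keeps the permutation integer-typed even when
    # both parts are empty.
    transposed_x = numpy.transpose(
        x, numpy.concatenate((keep_axes, numpy.array(axes, dtype='int64'))))
    kept_shape = transposed_x.shape[:len(keep_axes)]
    reduced_shape = transposed_x.shape[len(keep_axes):]
    # dtype='int64' keeps the size an integer even when reduced_shape is
    # empty (numpy.prod(()) returns the float 1.0).
    new_shape = kept_shape + (numpy.prod(reduced_shape, dtype='int64'),)
    reshaped_x = transposed_x.reshape(new_shape)

    # The result is a flat (C-order) index into the reduced axes.
    max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1),
                                 dtype='int64')
```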
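The multi-axis argmax workaround in `perform` can be tried in plain NumPy without Theano. The sketch below is a hypothetical standalone helper, `max_and_argmax` (not part of Theano), assuming only NumPy; unlike the excerpt it also normalizes negative axes. It returns the max and the flat index of that max within the reduced axes. `numpy.unravel_index` can then turn that index back into per-axis coordinates, with the reduced axes taken in the order they appear in `axes`.

```python
import numpy as np


def max_and_argmax(x, axes=None):
    """Hypothetical helper: max of `x` over `axes`, plus the flat (C-order)
    index of that max within the reduced axes, using the same
    transpose/reshape trick as the Theano `perform` method."""
    x = np.asarray(x)
    if axes is None:
        axes = tuple(range(x.ndim))
    else:
        # Assumption beyond the original: map negative axes to positive
        # ones so the membership test below works.
        axes = tuple(int(ax) % x.ndim for ax in axes)
    keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
    # Kept axes first, reduced axes last (in the order given by `axes`).
    transposed = np.transpose(x, keep_axes + axes)
    kept_shape = transposed.shape[:len(keep_axes)]
    reduced_shape = transposed.shape[len(keep_axes):]
    # Collapse all reduced axes into one trailing axis.
    reduced_size = int(np.prod(reduced_shape, dtype=np.int64))
    flat = transposed.reshape(kept_shape + (reduced_size,))
    return flat.max(axis=-1), flat.argmax(axis=-1)


x = np.arange(24).reshape(2, 3, 4)
m, idx = max_and_argmax(x, axes=(0, 2))
print(m)    # [15 19 23]
print(idx)  # [7 7 7]
# Recover (axis-0, axis-2) coordinates from the flat index.
print(np.unravel_index(idx, (x.shape[0], x.shape[2])))
```

Here every maximum sits at `x[1, j, 3]`, so the flat index over the reduced `(2, 4)` block is `1 * 4 + 3 = 7`, and unravelling gives coordinates `(1, 3)` for each `j`.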