def perform(self, node, inp, outs):
    """Compute the max of ``x`` over ``axes`` and the matching argmax.

    ``inp`` is ``(x, axes)``. ``axes`` is either None, meaning reduce over
    every axis of ``x``, or an iterable of integer axis indices; negative
    indices are counted from the end, as in numpy.

    ``outs`` holds two one-element output cells:

    * ``outs[0][0]`` receives ``numpy.max(x, axes)`` cast to
      ``node.outputs[0].dtype``.
    * ``outs[1][0]`` receives, as int64, the position of the maximum within
      the reduced axes flattened in C order. The reduced axes are taken in
      the order they were given.
    """
    x, axes = inp
    max_out, max_idx_out = outs
    if axes is None:
        axes = tuple(range(x.ndim))
    else:
        # Normalise negative indices so the membership test and the
        # transpose permutation below agree with what numpy.max accepts.
        axes = tuple(ax + x.ndim if ax < 0 else ax for ax in map(int, axes))
    max_out[0] = theano._asarray(numpy.max(x, axes),
                                 dtype=node.outputs[0].dtype)

    # numpy.argmax only supports a single axis. Work around this by moving
    # the kept axes to the front, collapsing the reduced axes into one
    # trailing axis, and taking argmax over that axis.
    keep_axes = [i for i in range(x.ndim) if i not in axes]
    # Build the permutation from plain int lists. Concatenating an int64
    # array with an empty tuple would yield a float64 permutation.
    transposed_x = numpy.transpose(x, keep_axes + list(axes))
    kept_shape = transposed_x.shape[:len(keep_axes)]
    reduced_shape = transposed_x.shape[len(keep_axes):]
    # numpy.prod(()) is the float 1.0. Force an integer size so reshape
    # accepts it when no axis is reduced (empty axes or a 0-d input).
    reduced_size = int(numpy.prod(reduced_shape, dtype='int64'))
    reshaped_x = transposed_x.reshape(kept_shape + (reduced_size,))
    max_idx_out[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1),
                                     dtype='int64')
评论列表
文章目录