Python – 如何在NumPy数组中获得N个最大值的索引?

NumPy提出了一种获取数组最大值索引的方法np.argmax

我想要一个类似的东西,但返回N最大值的索引。

例如,如果我有一个数组,[1, 3, 2, 4, 5]function(array, n=3)将返回的索引[4, 3, 1]相对应的元素[5, 4, 3]


 [1]: import numpy as np

In [2]: arr = np.array([1, 3, 2, 4, 5])

In [3]: arr.argsort()[-3:][::-1]
Out[3]: array([4, 3, 1])

这涉及到完整的数组。我想知道是否numpy提供了一种内置的方式来进行局部排序; 到目前为止,我还没有找到一个。

如果这个解决方案太慢(特别是对于小型n),那么在Cython中编写代码可能是值得的。


较新的NumPy版本(1.8及更高版本)具有此功能argpartition。要获得四个最大元素的索引,请执行

>>> a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> a
array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> ind = np.argpartition(a, -4)[-4:]
>>> ind
array([1, 5, 8, 0])
>>> a[ind]
array([4, 9, 6, 9])

argsort此不同,此函数在最坏的情况下以线性时间运行,但返回的索引未排序,从评估结果可以看出a[ind]。如果您也需要,请在之后对其进行排序:

>>> ind[np.argsort(a[ind])]
array([1, 8, 5, 0])

以这种方式按排序顺序获取top- k元素需要O(n + k log k)时间。


更简单:

idx = (-arr).argsort()[:n]

其中n是最大值的数量。


使用:

>>> import heapq
>>> import numpy
>>> a = numpy.array([1, 3, 2, 4, 5])
>>> heapq.nlargest(3, range(len(a)), a.take)
[4, 3, 1]

对于常规Python列表:

>>> a = [1, 3, 2, 4, 5]
>>> heapq.nlargest(3, range(len(a)), a.__getitem__)
[4, 3, 1]

如果您使用Python 2,请使用xrange而不是range


如果您正在使用多维数组,那么您将需要展平并解开索引:

def largest_indices(ary, n):
    """Returns the n largest indices from a numpy array."""
    flat = ary.flatten()
    indices = np.argpartition(flat, -n)[-n:]
    indices = indices[np.argsort(-flat[indices])]
    return np.unravel_index(indices, ary.shape)

例如:

>>> xs = np.sin(np.arange(9)).reshape((3, 3))
>>> xs
array([[ 0.        ,  0.84147098,  0.90929743],
       [ 0.14112001, -0.7568025 , -0.95892427],
       [-0.2794155 ,  0.6569866 ,  0.98935825]])
>>> largest_indices(xs, 3)
(array([2, 0, 0]), array([2, 2, 1]))
>>> xs[largest_indices(xs, 3)]
array([ 0.98935825,  0.90929743,  0.84147098])

添加评论

友情链接:蝴蝶教程