检索 NumPy 数组中多个最大值的索引
NumPy 提供了一个方便的 np.argmax 函数来检索 NumPy 数组中最大值的索引一个数组。但是,如果您需要找到前 N 个最大值的索引怎么办?
解决方案
最近的 NumPy 版本(1.8 及更高版本)为此引入了 argpartition 函数目的。要获取前 N 个元素的索引,请按照以下步骤操作:
import numpy as np # Original array a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0]) # Find indices of top N elements (N = 4 in this case) ind = np.argpartition(a, -4)[-4:] # Extract top N elements top4 = a[ind] # Print indices and top N elements print("Indices:", ind) print("Top 4 elements:", top4)
说明
np.argpartition 对数组进行部分排序,将其分为两个子数组:第一个子数组包含前 N 个元素(在本例中为最大的 4 个元素),第二个子数组包含其余元素。返回的数组 ind 包含第一个子数组中元素的索引。
此示例中的输出将是:
Indices: [1 5 8 0] Top 4 elements: [4 9 6 9]
优化
如果还需要排序索引,可以单独排序:
sorted_ind = ind[np.argsort(a[ind])]
这一步需要O(k log k) 时间,其中 k 是要检索的顶部元素的数量。总的来说,这种方法的时间复杂度为 O(n k log k),对于大型数组和中等 k 值非常有效。
以上是如何高效查找 NumPy 数组中多个最大值的索引?的详细内容。更多信息请关注PHP中文网其他相关文章!