python numpy where函数对数组过滤
numpy.where()函数是NumPy中用于条件处理和数组过滤的常用函数之一。它可以根据给定的条件返回一个新的数组,其中符合条件的元素被替换为一个指定的值,而不符合条件的元素则被替换为另一个指定的值。
numpy.where()函数的一般形式为:
numpy.where(condition[, x, y])
其中condition是一个布尔型数组,x和y是可选的参数,分别表示满足条件和不满足条件时的替换值。如果x和y未提供,则numpy.where()函数将返回一个包含所有满足条件的元素的新数组。
以下是一些使用numpy.where()函数的示例:
import numpy as np # 创建一个包含随机整数的数组 arr = np.random.randint(0, 10, size=(3, 3)) print("原始数组:\n", arr) # 使用 where() 函数将大于 5 的元素替换为 -1 new_arr = np.where(arr > 5, -1, arr) print("替换后的数组:\n", new_arr) # 使用 where() 函数过滤出大于 5 的元素 filtered_arr = arr[np.where(arr > 5)] print("大于 5 的元素:\n", filtered_arr)
输出:
lua
原始数组:
[[3 2 3]
[9 9 9]
[5 5 5]]
替换后的数组:
[[ 3 2 3]
[-1 -1 -1]
[ 5 5 5]]
大于 5 的元素:
[9 9 9]
在这个示例中,我们使用numpy.random.randint()函数创建了一个包含随机整数的3x3数组。然后,我们使用numpy.where()函数将大于5的元素替换为-1,使用相同的函数过滤出大于5的元素。
相关文章