python numpy where函数对数组过滤

2023-03-09 00:00:00 函数 数组 过滤

numpy.where()函数是NumPy中用于条件处理和数组过滤的常用函数之一。它可以根据给定的条件返回一个新的数组,其中符合条件的元素被替换为一个指定的值,而不符合条件的元素则被替换为另一个指定的值。

numpy.where()函数的一般形式为:

numpy.where(condition[, x, y])

其中condition是一个布尔型数组,x和y是可选的参数,分别表示满足条件和不满足条件时的替换值。如果x和y未提供,则numpy.where()函数将返回一个包含所有满足条件的元素的新数组。

以下是一些使用numpy.where()函数的示例:

import numpy as np

# 创建一个包含随机整数的数组
arr = np.random.randint(0, 10, size=(3, 3))
print("原始数组:\n", arr)

# 使用 where() 函数将大于 5 的元素替换为 -1
new_arr = np.where(arr > 5, -1, arr)
print("替换后的数组:\n", new_arr)

# 使用 where() 函数过滤出大于 5 的元素
filtered_arr = arr[np.where(arr > 5)]
print("大于 5 的元素:\n", filtered_arr)

输出:

lua 原始数组: [[3 2 3] [9 9 9] [5 5 5]] 替换后的数组: [[ 3 2 3] [-1 -1 -1] [ 5 5 5]] 大于 5 的元素: [9 9 9]
在这个示例中,我们使用numpy.random.randint()函数创建了一个包含随机整数的3x3数组。然后,我们使用numpy.where()函数将大于5的元素替换为-1,使用相同的函数过滤出大于5的元素。

相关文章