Are for-loops in pandas really bad? When should I care?
Problem description
Are for loops really "bad"? If not, in what situation(s) would they be better than using a more conventional "vectorized" approach? [1]
I am familiar with the concept of "vectorization", and how pandas employs vectorized techniques to speed up computation. Vectorized functions broadcast operations over the entire series or DataFrame to achieve speedups much greater than conventionally iterating over the data.
However, I am quite surprised to see a lot of code (including from answers on Stack Overflow) offering solutions to problems that involve looping through data using for loops and list comprehensions. The documentation and API say that loops are "bad", and that one should "never" iterate over arrays, series, or DataFrames. So, how come I sometimes see users suggesting loop-based solutions?
[1] While it is true that the question sounds somewhat broad, the truth is that there are very specific situations in which for loops are usually better than the conventional vectorized approach. This post aims to capture this for posterity.
TL;DR: No, for loops are not blanket "bad", at least, not always. It is probably more accurate to say that some vectorized operations are slower than iterating than to say that iteration is faster than some vectorized operations. Knowing when and why is key to getting the most performance out of your code. In a nutshell, these are the situations where it is worth considering an alternative to vectorized pandas functions:
- When your data is small (...depending on what you're doing),
- When dealing with object/mixed dtypes
- When using the str/regex accessor functions
Let's examine these situations individually.
Iteration v/s Vectorization on Small Data
Pandas follows a "Convention Over Configuration" approach in its API design. This means that the same API has been fitted to cater to a broad range of data and use cases.
When a pandas function is called, the following things (among others) must be handled internally by the function to ensure it works correctly:
- Index/axis alignment
- Handling mixed datatypes
- Handling missing data
Almost every function will have to deal with these to varying extents, and this presents an overhead. The overhead is less for numeric functions (for example, Series.add), while it is more pronounced for string functions (for example, Series.str.replace).
for loops, on the other hand, are faster than you think. What's even better is that list comprehensions (which create lists through for loops) are even faster, as they are optimized iterative mechanisms for list creation.
List comprehensions follow the pattern
[f(x) for x in seq]
where seq is a pandas Series or DataFrame column. Or, when operating over multiple columns,
[f(x, y) for x, y in zip(seq1, seq2)]
where seq1 and seq2 are columns.
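As a minimal sketch of the two patterns above (the toy DataFrame and its column names are made up for illustration and are not part of the benchmarks in this post):

```python
import pandas as pd

# Hypothetical toy data; the column names are illustrative only.
df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})

# Single-column pattern: [f(x) for x in seq]
squares = [x ** 2 for x in df["A"]]

# Multi-column pattern: [f(x, y) for x, y in zip(seq1, seq2)]
sums = [x + y for x, y in zip(df["A"], df["B"])]

# squares == [1, 4, 9] and sums == [11, 22, 33]
```

Both results are plain Python lists, so they can be assigned back to a column or used as a boolean mask as long as their length matches the DataFrame.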
Numeric Comparison
Consider a simple boolean indexing operation. The list comprehension method has been timed against Series.ne (!=) and query. Here are the functions:
# Boolean indexing with Numeric value comparison.
df[df.A != df.B] # vectorized !=
df.query('A != B') # query (numexpr)
df[[x != y for x, y in zip(df.A, df.B)]] # list comp
For simplicity, I have used the perfplot package to run all the timeit tests in this post. The timings for the operations above are below:
The list comprehension outperforms query for moderately sized N, and even outperforms the vectorized not equals comparison for tiny N. Unfortunately, the list comprehension scales linearly, so it does not offer much performance gain for larger N.
Note
It is worth mentioning that much of the benefit of the list comprehension comes from not having to worry about index alignment, but this means that if your code depends on index alignment, this will break. In some cases, vectorised operations over the underlying NumPy arrays can be considered as bringing in the "best of both worlds", allowing for vectorisation without all the unneeded overhead of the pandas functions. This means that you can rewrite the operation above as
df[df.A.values != df.B.values]
which outperforms both the pandas and list comprehension equivalents.
NumPy vectorization is out of the scope of this post, but it is definitely worth considering, if performance matters.
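To make the index alignment caveat from the Note concrete, here is a small sketch under stated assumptions: the two Series are toy data, and addition is used instead of != only because it makes the alignment visible in the result.

```python
import pandas as pd

# Two toy Series that hold the same labels in a different order.
s1 = pd.Series([1, 2, 3], index=[0, 1, 2])
s2 = pd.Series([10, 20, 30], index=[2, 1, 0])

# pandas arithmetic aligns on index labels before adding.
aligned = (s1 + s2).to_dict()        # {0: 31, 1: 22, 2: 13}

# The raw NumPy arrays are combined purely by position.
positional = (s1.values + s2.values).tolist()   # [11, 22, 33]

# A list comprehension over zip is positional as well.
zipped = [x + y for x, y in zip(s1, s2)]        # [11, 22, 33]
```

If the row order of your columns is guaranteed to match (as it is for two columns of the same DataFrame), the positional versions are safe; otherwise the label-aligned pandas operation is the one that gives the intended pairing.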
Value Counts
Taking another example - this time, with another vanilla python construct that is faster than a for loop - collections.Counter. A common requirement is to compute the value counts and return the result as a dictionary. This is done with value_counts, np.unique, and Counter:
# Value Counts comparison.
ser.value_counts(sort=False).to_dict() # value_counts
dict(zip(*np.unique(ser, return_counts=True))) # np.unique
Counter(ser) # Counter
The results are more pronounced: Counter wins out over both vectorized methods for a larger range of small N (~3500).
Note
More trivia (courtesy @user2357112). The Counter is implemented with a C accelerator, so while it still has to work with python objects instead of the underlying C datatypes, it is still faster than a for loop. Python power!
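As a quick sanity check that the three approaches really compute the same thing, here is a sketch on a tiny, made-up Series (the data is an assumption for illustration; it says nothing about speed):

```python
from collections import Counter

import numpy as np
import pandas as pd

ser = pd.Series([3, 1, 3, 2, 1, 3])  # toy data

via_value_counts = ser.value_counts(sort=False).to_dict()
via_unique = dict(zip(*np.unique(ser, return_counts=True)))
via_counter = dict(Counter(ser))

# All three map each distinct value to its number of occurrences.
expected = {3: 3, 1: 2, 2: 1}
```

The key and value types can differ between the approaches (NumPy integers versus Python ints), but they compare and hash equal, so the dictionaries are interchangeable for lookups.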
Of course, the take away from here is that the performance depends on your data and use case. The point of these examples is to convince you not to rule out these solutions as legitimate options. If these still don't give you the performance you need, there is always cython and numba. Let's add this test into the mix.
from numba import njit, prange
@njit(parallel=True)
def get_mask(x, y):
    result = [False] * len(x)
    for i in prange(len(x)):
        result[i] = x[i] != y[i]
    return np.array(result)
df[get_mask(df.A.values, df.B.values)] # numba
Numba offers JIT compilation of loopy python code to very powerful vectorized code. Understanding how to make numba work involves a learning curve.
Operations with Mixed/object dtypes
String-based Comparison
Revisiting the filtering example from the first section, what if the columns being compared are strings? Consider the same 3 functions above, but with the input DataFrame cast to string.
# Boolean indexing with string value comparison.
df[df.A != df.B] # vectorized !=
df.query('A != B') # query (numexpr)
df[[x != y for x, y in zip(df.A, df.B)]] # list comp
So, what changed? The thing to note here is that string operations are inherently difficult to vectorize. Pandas treats strings as objects, and all operations on objects fall back to a slow, loopy implementation.
Now, because this loopy implementation is surrounded by all the overhead mentioned above, there is a constant magnitude difference between these solutions, even though they scale the same.
When it comes to operations on mutable/complex objects, there is no comparison. List comprehension outperforms all operations involving dicts and lists.
Accessing Dictionary Value(s) by Key
Here are timings for two operations that extract a value from a column of dictionaries: map and the list comprehension. The setup is in the Appendix, under the heading "Code Snippets".
# Dictionary value extraction.
ser.map(operator.itemgetter('value')) # map
pd.Series([x.get('value') for x in ser]) # list comprehension
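One behavioural difference between the two is worth keeping in mind when choosing. The sketch below uses a made-up Series in which the second dictionary has no 'value' key (an assumption for illustration):

```python
import operator

import pandas as pd

# Toy data: the second dict deliberately lacks the 'value' key.
ser = pd.Series([{"key": "abc", "value": 123}, {"key": "xyz"}])

# dict.get returns None when the key is missing.
extracted = [d.get("value") for d in ser]

# operator.itemgetter raises KeyError when the key is missing.
try:
    ser.map(operator.itemgetter("value"))
    itemgetter_raised = False
except KeyError:
    itemgetter_raised = True
```

So the list comprehension with dict.get tolerates missing keys, while the itemgetter version only works when every dictionary has the key.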
Positional List Indexing
Timings for operations that extract the 0th element from a column of lists (handling exceptions), using map, the str.get accessor method, and the list comprehension:
# List positional indexing.
def get_0th(lst):
    try:
        return lst[0]
    # Handle empty lists and NaNs gracefully.
    except (IndexError, TypeError):
        return np.nan
ser.map(get_0th) # map
ser.str[0] # str accessor
pd.Series([x[0] if len(x) > 0 else np.nan for x in ser]) # list comp
pd.Series([get_0th(x) for x in ser]) # list comp safe
Note
If the index matters, you would want to do:
pd.Series([...], index=ser.index)
when reconstructing the series.
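A short sketch of that Note, assuming a toy Series with a non-default string index (the labels are made up):

```python
import numpy as np
import pandas as pd

# Toy column of lists with a custom index.
ser = pd.Series([["a", "b"], [], [1, 2]], index=["x", "y", "z"])

# Rebuild the Series from the list comprehension, reusing the original index
# so the result stays aligned with the source rows.
firsts = pd.Series(
    [lst[0] if len(lst) > 0 else np.nan for lst in ser],
    index=ser.index,
)
```

Without index=ser.index the result would get a fresh 0..n-1 index, and assigning it back to a DataFrame with a different index would misalign the rows.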
List Flattening
A final example is flattening lists. This is another common problem, and demonstrates just how powerful pure python is here.
# Nested list flattening.
pd.DataFrame(ser.tolist()).stack().reset_index(drop=True) # stack
pd.Series(list(chain.from_iterable(ser.tolist()))) # itertools.chain
pd.Series([y for x in ser for y in x]) # nested list comp
Both itertools.chain.from_iterable and the nested list comprehension are pure python constructs, and scale much better than the stack solution.
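For reference, the two pure python approaches can be checked against each other on a small example (toy data, same shape as the ser3 used in the Appendix):

```python
from itertools import chain

import pandas as pd

ser = pd.Series([["a", "b", "c"], ["d", "e"], ["f", "g"]])

# Flatten by chaining the inner lists together.
flat_chain = list(chain.from_iterable(ser.tolist()))

# Flatten with a nested list comprehension (outer loop first).
flat_comp = [y for x in ser for y in x]
```

Both produce the same flat list, in row order, which can then be wrapped in pd.Series if a Series is needed.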
These timings are a strong indication of the fact that pandas is not equipped to work with mixed dtypes, and that you should probably refrain from using it to do so. Wherever possible, data should be present as scalar values (ints/floats/strings) in separate columns.
Lastly, the applicability of these solutions depends heavily on your data. So, the best thing to do would be to test these operations on your data before deciding what to go with. Notice how I have not timed apply on these solutions, because it would skew the graph (yes, it's that slow).
Regex Operations, and .str Accessor Methods
Pandas can apply regex operations such as str.contains, str.extract, and str.extractall, as well as other "vectorized" string operations (such as str.split, str.find, str.translate, and so on) on string columns. These functions are slower than list comprehensions, and are meant to be more convenience functions than anything else.
It is usually much faster to pre-compile a regex pattern with re.compile and iterate over your data (also see Is it worth using Python's re.compile?). The list comp equivalent to str.contains looks something like this:
p = re.compile(...)
ser2 = pd.Series([x for x in ser if p.search(x)])
Or,
ser2 = ser[[bool(p.search(x)) for x in ser]]
If you need to handle NaNs, you can do something like
ser[[bool(p.search(x)) if pd.notnull(x) else False for x in ser]]
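Here is a self-contained sketch of the NaN-safe filter next to its str.contains counterpart. The data and the pattern are made up for illustration:

```python
import re

import numpy as np
import pandas as pd

ser = pd.Series(["apple pie", "banana", np.nan, "pineapple"])  # toy data
p = re.compile(r"apple")  # illustrative pattern

# List comprehension: treat missing values as "no match".
mask = [bool(p.search(x)) if pd.notnull(x) else False for x in ser]
via_listcomp = ser[mask]

# pandas equivalent: na=False also maps missing values to False.
via_contains = ser[ser.str.contains(r"apple", na=False)]
```

Both keep the rows at positions 0 and 3 and preserve their original index labels, because the boolean mask is applied to ser itself.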
The list comp equivalent to str.extract (without groups) will look something like:
df['col2'] = [p.search(x).group(0) for x in df['col']]
If you need to handle no-matches and NaNs, you can use a custom function (still faster!):
def matcher(x):
    m = p.search(str(x))
    if m:
        return m.group(0)
    return np.nan
df['col2'] = [matcher(x) for x in df['col']]
The matcher function is very extensible. It can be adapted to return a list for each capture group, as needed. Just call the group or groups method of the match object.
For str.extractall, change p.search to p.findall.
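A sketch of both extensions, under stated assumptions: the pattern with two capture groups and the helper names matcher_groups and matcher_all are hypothetical and only serve to illustrate the idea.

```python
import re

import numpy as np

# Illustrative pattern with two capture groups: a capital letter and 4 digits.
p = re.compile(r"([A-Z])(\d{4})")

def matcher_groups(x):
    """Return the capture groups of the first match as a list, or NaN."""
    m = p.search(str(x))
    return list(m.groups()) if m else np.nan

def matcher_all(x):
    """Return every match in the string (one tuple of groups per match)."""
    return p.findall(str(x))

first = matcher_groups("test A1234")       # ['A', '1234']
every = matcher_all("A1234 and B5678")     # [('A', '1234'), ('B', '5678')]
nothing = matcher_all("foo xyz")           # []
```

Note that findall returns an empty list rather than NaN when nothing matches, so downstream code may need to treat the two cases differently.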
String Extraction
Consider a simple filtering operation. The idea is to extract 4 digits if they are preceded by an upper case letter.
# Extracting strings.
p = re.compile(r'(?<=[A-Z])(\d{4})')
def matcher(x):
    m = p.search(x)
    if m:
        return m.group(0)
    return np.nan

ser.str.extract(r'(?<=[A-Z])(\d{4})', expand=False) # str.extract
pd.Series([matcher(x) for x in ser]) # list comprehension
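Put together as a runnable sketch on the three sample strings from the Appendix (the pattern is written with \d, the regex digit class):

```python
import re

import numpy as np
import pandas as pd

ser = pd.Series(["foo xyz", "test A1234", "D3345 xtz"])
pattern = r"(?<=[A-Z])(\d{4})"   # 4 digits preceded by an upper case letter
p = re.compile(pattern)

def matcher(x):
    m = p.search(x)
    return m.group(0) if m else np.nan

via_extract = ser.str.extract(pattern, expand=False)
via_listcomp = pd.Series([matcher(x) for x in ser])
```

Because the lookbehind is zero-width, group(0) is exactly the four digits, so both versions yield a missing value for the first row and '1234' and '3345' for the other two.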
More Examples
Full disclosure - I am the author (in part or whole) of these posts listed below.
Fast punctuation removal with pandas
String concatenation of two pandas columns
Remove unwanted parts from strings in a column
Replace all but the last occurrence of a character in a dataframe
Conclusion
As shown by the examples above, iteration shines when working with DataFrames that have few rows, mixed datatypes, and regular expressions.
The speedup you get depends on your data and your problem, so your mileage may vary. The best thing to do is to carefully run tests and see if the payout is worth the effort.
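If you do not want to set up perfplot, a rough comparison on your own data can be sketched with the standard library's timeit. Everything below (the random toy DataFrame, its size, and the repeat counts) is an assumption to be replaced with your real data and workload:

```python
import timeit

import numpy as np
import pandas as pd

# Toy stand-in for your data: 100 rows, two integer columns.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.integers(0, 1000, size=(100, 2)), columns=["A", "B"])

candidates = {
    "vectorized !=": lambda: df[df.A != df.B],
    "list comp": lambda: df[[x != y for x, y in zip(df.A, df.B)]],
    "numpy values": lambda: df[df.A.values != df.B.values],
}

# Best of 3 repeats, 100 calls each; the minimum is the least noisy figure.
timings = {
    name: min(timeit.repeat(fn, number=100, repeat=3))
    for name, fn in candidates.items()
}

for name, seconds in sorted(timings.items(), key=lambda kv: kv[1]):
    print(f"{name:>15}: {seconds:.4f}s per 100 calls")
```

The ranking will change with the number of rows and the dtypes involved, so it is worth rerunning this at the sizes you actually expect.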
The "vectorized" functions shine in their simplicity and readability, so if performance is not critical, you should definitely prefer those.
Another side note, certain string operations deal with constraints that favour the use of NumPy. Here are two examples where careful NumPy vectorization outperforms python:
Create new column with incremental values in a faster and efficient way - Answer by Divakar
Fast punctuation removal with pandas - Answer by Paul Panzer
Additionally, sometimes just operating on the underlying arrays via .values, as opposed to on the Series or DataFrames, can offer a healthy enough speedup for most usual scenarios (see the Note in the Numeric Comparison section above). So, for example, df[df.A.values != df.B.values] would show an instant performance boost over df[df.A != df.B]. Using .values may not be appropriate in every situation, but it is a useful hack to know.
As mentioned above, it's up to you to decide whether these solutions are worth the trouble of implementing.
Appendix: Code Snippets
import perfplot
import operator
import pandas as pd
import numpy as np
import re
from collections import Counter
from itertools import chain
# Boolean indexing with Numeric value comparison.
perfplot.show(
    setup=lambda n: pd.DataFrame(np.random.choice(1000, (n, 2)), columns=['A','B']),
    kernels=[
        lambda df: df[df.A != df.B],
        lambda df: df.query('A != B'),
        lambda df: df[[x != y for x, y in zip(df.A, df.B)]],
        lambda df: df[get_mask(df.A.values, df.B.values)]
    ],
    labels=['vectorized !=', 'query (numexpr)', 'list comp', 'numba'],
    n_range=[2**k for k in range(0, 15)],
    xlabel='N'
)
# Value Counts comparison.
perfplot.show(
    setup=lambda n: pd.Series(np.random.choice(1000, n)),
    kernels=[
        lambda ser: ser.value_counts(sort=False).to_dict(),
        lambda ser: dict(zip(*np.unique(ser, return_counts=True))),
        lambda ser: Counter(ser),
    ],
    labels=['value_counts', 'np.unique', 'Counter'],
    n_range=[2**k for k in range(0, 15)],
    xlabel='N',
    equality_check=lambda x, y: dict(x) == dict(y)
)
# Boolean indexing with string value comparison.
perfplot.show(
    setup=lambda n: pd.DataFrame(np.random.choice(1000, (n, 2)), columns=['A','B'], dtype=str),
    kernels=[
        lambda df: df[df.A != df.B],
        lambda df: df.query('A != B'),
        lambda df: df[[x != y for x, y in zip(df.A, df.B)]],
    ],
    labels=['vectorized !=', 'query (numexpr)', 'list comp'],
    n_range=[2**k for k in range(0, 15)],
    xlabel='N',
    equality_check=None
)
# Dictionary value extraction.
ser1 = pd.Series([{'key': 'abc', 'value': 123}, {'key': 'xyz', 'value': 456}])
perfplot.show(
    setup=lambda n: pd.concat([ser1] * n, ignore_index=True),
    kernels=[
        lambda ser: ser.map(operator.itemgetter('value')),
        lambda ser: pd.Series([x.get('value') for x in ser]),
    ],
    labels=['map', 'list comprehension'],
    n_range=[2**k for k in range(0, 15)],
    xlabel='N',
    equality_check=None
)
# List positional indexing.
ser2 = pd.Series([['a', 'b', 'c'], [1, 2], []])
perfplot.show(
    setup=lambda n: pd.concat([ser2] * n, ignore_index=True),
    kernels=[
        lambda ser: ser.map(get_0th),
        lambda ser: ser.str[0],
        lambda ser: pd.Series([x[0] if len(x) > 0 else np.nan for x in ser]),
        lambda ser: pd.Series([get_0th(x) for x in ser]),
    ],
    labels=['map', 'str accessor', 'list comprehension', 'list comp safe'],
    n_range=[2**k for k in range(0, 15)],
    xlabel='N',
    equality_check=None
)
# Nested list flattening.
ser3 = pd.Series([['a', 'b', 'c'], ['d', 'e'], ['f', 'g']])
perfplot.show(
    # Use ser3 (defined just above); the original passed ser2 here.
    setup=lambda n: pd.concat([ser3] * n, ignore_index=True),
    kernels=[
        lambda ser: pd.DataFrame(ser.tolist()).stack().reset_index(drop=True),
        lambda ser: pd.Series(list(chain.from_iterable(ser.tolist()))),
        lambda ser: pd.Series([y for x in ser for y in x]),
    ],
    labels=['stack', 'itertools.chain', 'nested list comp'],
    n_range=[2**k for k in range(0, 15)],
    xlabel='N',
    equality_check=None
)
# Extracting strings.
ser4 = pd.Series(['foo xyz', 'test A1234', 'D3345 xtz'])
perfplot.show(
    setup=lambda n: pd.concat([ser4] * n, ignore_index=True),
    kernels=[
        lambda ser: ser.str.extract(r'(?<=[A-Z])(\d{4})', expand=False),
        lambda ser: pd.Series([matcher(x) for x in ser])
    ],
    labels=['str.extract', 'list comprehension'],
    n_range=[2**k for k in range(0, 15)],
    xlabel='N',
    equality_check=None
)