如何更新 matplotlib 可滚动多图(multiplot)中的 artist 对象
问题描述
我正在尝试根据这个问题的答案创建一个可滚动的多图:
我不能应用与 plot()
相同的方法,因为以下会产生错误消息:
ln3,=ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
TypeError: 'PolyCollection' object is not iterable
这是它在每一帧上应该呈现的样子
解决方案
fill_between
返回一个 PolyCollection
,它在创建时需要一个(或多个)顶点列表.不幸的是,我还没有找到一种方法来检索用于创建给定 PolyCollection
的顶点,但在您的情况下,直接创建 PolyCollection
很容易(因此避免使用 fill_between
),然后在帧更改时更新其顶点.
下面的代码版本可以满足您的需求:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider
from matplotlib.collections import PolyCollection

# Create 100 random data frames (one per slider frame), 30 rows each.
dfs = {}
for x in range(100):
    col1 = np.random.normal(10, 0.5, 30)
    # Split about 31 samples among the levels 5, 8 and 7 using random
    # Dirichlet proportions and keep the first 30. np.repeat only accepts
    # integer repeat counts, so the rounded floats are cast to int.
    counts = np.round(np.random.dirichlet(np.ones(3), size=1) * 31)[0].astype(int)
    col2 = np.repeat([5, 8, 7], counts)[:30]
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

# Create figure, axis, subplot.
fig = plt.figure()
gs = gridspec.GridSpec(1, 1, hspace=0, wspace=0, left=0.1, bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim([0, 12])

# Slider used to select the displayed frame.
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0, valfmt='%d')

# Line plots; the returned handles are updated in the slider callback.
ln1, = ax.plot(dfs[0].index, dfs[0]['col1'])
ln2, = ax.plot(dfs[0].index, dfs[0]['col2'], c='black')

# Levels of col2 highlighted in red, blue and green respectively.
val_r = 5
val_b = 8
val_g = 7


def update_collection(collection, value, frame=0):
    """Reshape *collection* into the rectangle marking ``col2 == value``.

    The rectangle spans from the smallest to the largest index of
    ``dfs[frame]`` whose ``col2`` equals *value*, and from ``value - 0.5``
    to ``value + 0.5`` vertically. If no row matches, the collection is
    emptied.
    """
    xs = np.asarray(dfs[frame].index)
    ys = np.asarray(dfs[frame]['col2'])
    matching = xs[ys == value]

    # Check for "no match" explicitly rather than relying on np.min raising
    # ValueError on an empty array.
    if matching.size == 0:
        collection.set_verts([])
        return

    minx, maxx = matching.min(), matching.max()
    miny, maxy = value - 0.5, value + 0.5
    collection.set_verts(
        [np.array([[minx, miny], [maxx, miny], [maxx, maxy], [minx, maxy]])]
    )


def _stem_segments(frame):
    """Return the vertical segments (x, 0)-(x, col3) for ``dfs[frame]``."""
    col3 = dfs[frame]['col3']
    return [[(x, 0), (x, y)] for x, y in zip(col3.index, col3)]


# Artists.
# One PolyCollection per highlighted level is used instead of
# ax.fill_between(...), because its vertices can be replaced on every frame.
reds = PolyCollection([], facecolors=['r'], alpha=0.5)
ax.add_collection(reds)
update_collection(reds, val_r)

blues = PolyCollection([], facecolors=['b'], alpha=0.5)
ax.add_collection(blues)
update_collection(blues, val_b)

greens = PolyCollection([], facecolors=['g'], alpha=0.5)
ax.add_collection(greens)
update_collection(greens, val_g)

# Keep the LineCollection returned by vlines so it can be updated per frame.
stems = ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'],
                  color='black')


# Update plots.
def update(val):
    """Slider callback: redraw every artist for the selected frame."""
    frame = int(sframe.val)  # truncate the slider position to a dict key
    ln1.set_ydata(dfs[frame]['col1'])
    ln2.set_ydata(dfs[frame]['col2'])
    ax.set_title('Frame ' + str(frame))
    # Updating the PolyCollections:
    update_collection(reds, val_r, frame)
    update_collection(blues, val_b, frame)
    update_collection(greens, val_g, frame)
    # Updating the vertical lines:
    stems.set_segments(_stem_segments(frame))
    fig.canvas.draw_idle()


# Connect the callback to the slider, then show the figure.
sframe.on_changed(update)
plt.show()
三个 PolyCollections
(reds
、blues
和 greens
)中的每一个都只有四个顶点(矩形的四个角),这些顶点是根据给定的数据确定的(在 update_collection
中完成).结果如下所示:
在 Python 3.5 中测试
I'm trying to create a scrollable multiplot based on the answer to this question: Creating a scrollable multiplot with python's pylab
Lines created using ax.plot()
are updating correctly, however I'm unable to figure out how to update artists created using vlines()
and fill_between()
.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider
# Create 100 random data frames (one per slider frame), 30 rows each.
dfs = {}
for x in range(100):
    col1 = np.random.normal(10, 0.5, 30)
    # Split about 31 samples among the levels 5, 8 and 7 using random
    # Dirichlet proportions and keep the first 30. np.repeat only accepts
    # integer repeat counts, so the rounded floats are cast to int.
    counts = np.round(np.random.dirichlet(np.ones(3), size=1) * 31)[0].astype(int)
    col2 = np.repeat([5, 8, 7], counts)[:30]
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

# Create figure, axis, subplot.
fig = plt.figure()
gs = gridspec.GridSpec(1, 1, hspace=0, wspace=0, left=0.1, bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim([0, 12])

# Slider used to select the displayed frame.
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0, valfmt='%d')

# Line plots; the returned handles are kept so their data can be replaced.
ln1, = ax.plot(dfs[0].index, dfs[0]['col1'])
ln2, = ax.plot(dfs[0].index, dfs[0]['col2'], c='black')

# Artists drawn from frame 0; their return values are not kept.
ax.fill_between(dfs[0].index,
                y1=dfs[0]['col2'] - 0.5, y2=dfs[0]['col2'] + 0.5,
                where=dfs[0]['col2'] == 5,
                facecolor='r', edgecolors='none', alpha=0.5)
ax.fill_between(dfs[0].index,
                y1=dfs[0]['col2'] - 0.5, y2=dfs[0]['col2'] + 0.5,
                where=dfs[0]['col2'] == 8,
                facecolor='b', edgecolors='none', alpha=0.5)
ax.fill_between(dfs[0].index,
                y1=dfs[0]['col2'] - 0.5, y2=dfs[0]['col2'] + 0.5,
                where=dfs[0]['col2'] == 7,
                facecolor='g', edgecolors='none', alpha=0.5)
ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'], color='black')
# Update plots.
def update(val):
    """Slider callback: show the data of the frame selected by the slider.

    Replaces the y-data of the two line plots and sets the axes title;
    no other artist is touched here.
    """
    frame = int(sframe.val)  # truncate the slider position to a dict key
    ln1.set_ydata(dfs[frame]['col1'])
    ln2.set_ydata(dfs[frame]['col2'])
    ax.set_title('Frame ' + str(frame))
    plt.draw()
# Register `update` as the slider's on-change callback, then show the figure.
sframe.on_changed(update)
plt.show()
This is what it looks like at the moment
I can't apply the same approach as for plot()
, since the following produces an error message:
ln3,=ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
TypeError: 'PolyCollection' object is not iterable
This is what it's meant to look like on each frame
解决方案
fill_between
returns a PolyCollection
, which expects a list (or several lists) of vertices upon creation. Unfortunately I haven't found a way to retrieve the vertices that were used to create the given PolyCollection
, but in your case it is easy enough to create the PolyCollection
directly (thereby avoiding the use of fill_between
) and then update its vertices upon frame change.
Below is a version of your code that does what you are after:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider
from matplotlib.collections import PolyCollection
# Create 100 random data frames (one per slider frame), 30 rows each.
dfs = {}
for x in range(100):
    col1 = np.random.normal(10, 0.5, 30)
    # Split about 31 samples among the levels 5, 8 and 7 using random
    # Dirichlet proportions and keep the first 30. np.repeat only accepts
    # integer repeat counts, so the rounded floats are cast to int.
    counts = np.round(np.random.dirichlet(np.ones(3), size=1) * 31)[0].astype(int)
    col2 = np.repeat([5, 8, 7], counts)[:30]
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

# Create figure, axis, subplot.
fig = plt.figure()
gs = gridspec.GridSpec(1, 1, hspace=0, wspace=0, left=0.1, bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim([0, 12])

# Slider used to select the displayed frame.
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0, valfmt='%d')

# Line plots; the returned handles are kept so their data can be replaced.
ln1, = ax.plot(dfs[0].index, dfs[0]['col1'])
ln2, = ax.plot(dfs[0].index, dfs[0]['col2'], c='black')

# Additional code to update the PolyCollections:
# the col2 levels that get a highlighted band.
val_r = 5
val_b = 8
val_g = 7
def update_collection(collection, value, frame=0):
    """Reshape *collection* into the rectangle marking ``col2 == value``.

    The rectangle spans from the smallest to the largest index of
    ``dfs[frame]`` whose ``col2`` equals *value*, and from ``value - 0.5``
    to ``value + 0.5`` vertically. If no row matches, the collection is
    emptied.

    Parameters
    ----------
    collection
        Object exposing ``set_verts`` (a PolyCollection in this script).
    value
        Level of ``col2`` to highlight.
    frame
        Key into the module-level ``dfs`` dict; defaults to 0.
    """
    xs = np.asarray(dfs[frame].index)
    ys = np.asarray(dfs[frame]['col2'])
    matching = xs[ys == value]

    # Check for "no match" explicitly rather than relying on np.min raising
    # ValueError on an empty array.
    if matching.size == 0:
        collection.set_verts([])
        return

    minx, maxx = matching.min(), matching.max()
    miny, maxy = value - 0.5, value + 0.5
    collection.set_verts(
        [np.array([[minx, miny], [maxx, miny], [maxx, maxy], [minx, maxy]])]
    )
def _stem_segments(frame):
    """Return the vertical segments (x, 0)-(x, col3) for ``dfs[frame]``."""
    col3 = dfs[frame]['col3']
    return [[(x, 0), (x, y)] for x, y in zip(col3.index, col3)]


# Artists.
# One PolyCollection per highlighted level is used instead of
# ax.fill_between(...), because its vertices can be replaced on every frame.
reds = PolyCollection([], facecolors=['r'], alpha=0.5)
ax.add_collection(reds)
update_collection(reds, val_r)

blues = PolyCollection([], facecolors=['b'], alpha=0.5)
ax.add_collection(blues)
update_collection(blues, val_b)

greens = PolyCollection([], facecolors=['g'], alpha=0.5)
ax.add_collection(greens)
update_collection(greens, val_g)

# Keep the LineCollection returned by vlines so it can be updated per frame.
stems = ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'],
                  color='black')


# Update plots.
def update(val):
    """Slider callback: redraw every artist for the selected frame."""
    frame = int(sframe.val)  # truncate the slider position to a dict key
    ln1.set_ydata(dfs[frame]['col1'])
    ln2.set_ydata(dfs[frame]['col2'])
    ax.set_title('Frame ' + str(frame))
    # Updating the PolyCollections:
    update_collection(reds, val_r, frame)
    update_collection(blues, val_b, frame)
    update_collection(greens, val_g, frame)
    # Updating the vertical lines:
    stems.set_segments(_stem_segments(frame))
    fig.canvas.draw_idle()


# Connect the callback to the slider, then show the figure.
sframe.on_changed(update)
plt.show()
Each of the three PolyCollections
(reds
, blues
, and greens
) has only four vertices (the corners of the rectangle), which are determined based on the given data (which is done in update_collection
). The result looks like this:
Tested in Python 3.5
相关文章