如何更新可滚动 matplotlib 多图(multiplot)中的艺术家对象(artist)

2022-01-24 00:00:00 python matplotlib slider

问题描述

我正在尝试根据这个问题的答案创建一个可滚动的多图:Creating a scrollable multiplot with python's pylab

我不能应用与 plot() 相同的方法,因为以下会产生错误消息:

ln3,=ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
TypeError: 'PolyCollection' object is not iterable

这是它在每一帧上应该呈现的样子

解决方案

fill_between 返回一个 PolyCollection,它在创建时需要一个(或多个)顶点列表.不幸的是,我还没有找到一种方法来检索用于创建给定 PolyCollection 的顶点,但在您的情况下,直接创建 PolyCollection 很容易(因此避免使用 fill_between),然后在帧更改时更新其顶点.

下面的代码版本可以满足您的需求:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider

from matplotlib.collections import PolyCollection

#create dataframes
# One 30-row frame of random data per slider position (0..99).
dfs={}
for x in range(100):
    col1=np.random.normal(10,0.5,30)
    col2=(np.repeat([5,8,7],np.round(np.random.dirichlet(np.ones(3),size=1)*31)[0].tolist()))[:30]
    col3=np.random.randint(4,size=30)
    dfs[x]=pd.DataFrame({'col1':col1,'col2':col2,'col3':col3})

#create figure,axis,subplot
fig = plt.figure()
gs = gridspec.GridSpec(1,1,hspace=0,wspace=0,left=0.1,bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim([0,12])

#slider
frame=0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0,valfmt='%d')

#plots
ln1,=ax.plot(dfs[0].index,dfs[0]['col1'])
ln2,=ax.plot(dfs[0].index,dfs[0]['col2'],c='black')

##additional code to update the PolyCollections
val_r = 5
val_b = 8
val_g = 7

def update_collection(collection, value, frame = 0):
    """Set ``collection`` to one rectangle covering the rows of frame
    ``frame`` whose ``col2`` equals ``value`` (height ``value`` +/- 0.5)."""
    xs = np.array(dfs[frame].index)
    ys = np.array(dfs[frame]['col2'])

    ##we need to catch the case where no points with y == value exist:
    try:
        minx = np.min(xs[ys == value])
        maxx = np.max(xs[ys == value])
        miny = value-0.5
        maxy = value+0.5
        verts = np.array([[minx,miny],[maxx,miny],[maxx,maxy],[minx,maxy]])
    except ValueError:
        verts = np.zeros((0,2))
    finally:
        collection.set_verts([verts])

#artists

# Each PolyCollection below stands in for one
# ax.fill_between(..., where=col2 == value, ...) call of the question.
reds = PolyCollection([],facecolors = ['r'], alpha = 0.5)
ax.add_collection(reds)
update_collection(reds,val_r)

blues = PolyCollection([],facecolors = ['b'], alpha = 0.5)
ax.add_collection(blues)
update_collection(blues, val_b)

greens = PolyCollection([],facecolors = ['g'], alpha = 0.5)
ax.add_collection(greens)
update_collection(greens, val_g)

ax.vlines(x=dfs[0]['col3'].index,ymin=0,ymax=dfs[0]['col3'],color='black')

#update plots
def update(val):
    """Slider callback: redraw the lines and the three bands for the
    frame currently selected on ``sframe``."""
    frame = np.floor(sframe.val)
    ln1.set_ydata(dfs[frame]['col1'])
    ln2.set_ydata(dfs[frame]['col2'])
    ax.set_title('Frame ' + str(int(frame)))

    ##updating the PolyCollections:
    update_collection(reds,val_r, frame)
    update_collection(blues,val_b, frame)
    update_collection(greens,val_g, frame)

    plt.draw()

#connect callback to slider
sframe.on_changed(update)
plt.show()

三个 PolyCollection(reds、blues 和 greens)中的每一个都只有四个顶点(矩形的四个角点),这些顶点根据给定的数据确定(在 update_collection 中完成)。结果如下所示:

在 Python 3.5 中测试

I'm trying to create a scrollable multiplot based on the answer to this question: Creating a scrollable multiplot with python's pylab

Lines created using ax.plot() are updating correctly; however, I'm unable to figure out how to update artists created using vlines() and fill_between().

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider

#create dataframes
# dfs maps a frame number (0..99) to a 30-row DataFrame with columns
# col1, col2 and col3.
dfs = {}
for x in range(100):
    col1 = np.random.normal(10, 0.5, 30)
    # Randomly split roughly 31 rows between the levels 5, 8 and 7.
    # np.repeat needs integer repeat counts, so cast the rounded floats
    # explicitly instead of relying on implicit coercion of a float list.
    counts = np.round(np.random.dirichlet(np.ones(3), size=1) * 31)[0].astype(int)
    # Trim to exactly 30 rows so col2 lines up with col1 and col3.
    col2 = np.repeat([5, 8, 7], counts)[:30]
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

#create figure,axis,subplot
# Single-cell grid; the main axes start at 10% from the bottom, above the
# slider axes created below.
fig = plt.figure()
gs = gridspec.GridSpec(
    1, 1,
    hspace=0, wspace=0,
    left=0.1, bottom=0.1,
)
ax = plt.subplot(gs[0])
ax.set_ylim(0, 12)

#slider
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0, valfmt='%d')

#plots
# Both lines start out showing frame 0; the trailing comma unpacks the
# one-element list returned by ax.plot.
ln1, = ax.plot(dfs[0].index, dfs[0]['col1'])
ln2, = ax.plot(dfs[0].index, dfs[0]['col2'], color='black')

#artists
# One translucent band per col2 level, drawn once from frame 0.
for level, colour in ((5, 'r'), (8, 'b'), (7, 'g')):
    ax.fill_between(
        dfs[0].index,
        y1=dfs[0]['col2'] - 0.5,
        y2=dfs[0]['col2'] + 0.5,
        where=dfs[0]['col2'] == level,
        facecolor=colour,
        edgecolors='none',
        alpha=0.5,
    )
# Vertical stems from 0 up to col3, also drawn once from frame 0.
ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'], color='black')

#update plots
def update(val):
    """Slider callback: show the frame currently selected on ``sframe``.

    Only the two plotted lines and the title are refreshed here; the
    fill_between bands and the vlines stems are not touched, so they keep
    showing frame 0.
    """
    frame = np.floor(sframe.val)
    selected = dfs[frame]
    ln1.set_ydata(selected['col1'])
    ln2.set_ydata(selected['col2'])
    ax.set_title('Frame ' + str(int(frame)))
    plt.draw()


#connect callback to slider
# update() runs on every slider change; plt.show() then displays the figure.
sframe.on_changed(update)
plt.show()

This is what it looks like at the moment

I can't apply the same approach as for plot(), since the following produces an error message:

ln3,=ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
TypeError: 'PolyCollection' object is not iterable

This is what it's meant to look like on each frame

解决方案

fill_between returns a PolyCollection, which expects a list (or several lists) of vertices upon creation. Unfortunately I haven't found a way to retrieve the vertices that were used to create the given PolyCollection, but in your case it is easy enough to create the PolyCollection directly (thereby avoiding the use of fill_between) and then update its vertices upon frame change.

Below is a version of your code that does what you are after:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider

from matplotlib.collections import PolyCollection

#create dataframes
# dfs maps a frame number (0..99) to a 30-row DataFrame with columns
# col1, col2 and col3.
dfs = {}
for x in range(100):
    col1 = np.random.normal(10, 0.5, 30)
    # Randomly split roughly 31 rows between the levels 5, 8 and 7.
    # np.repeat needs integer repeat counts, so cast the rounded floats
    # explicitly instead of relying on implicit coercion of a float list.
    counts = np.round(np.random.dirichlet(np.ones(3), size=1) * 31)[0].astype(int)
    # Trim to exactly 30 rows so col2 lines up with col1 and col3.
    col2 = np.repeat([5, 8, 7], counts)[:30]
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

#create figure,axis,subplot
# Single-cell grid; the main axes start at 10% from the bottom, above the
# slider axes created below.
fig = plt.figure()
gs = gridspec.GridSpec(
    1, 1,
    hspace=0, wspace=0,
    left=0.1, bottom=0.1,
)
ax = plt.subplot(gs[0])
ax.set_ylim(0, 12)

#slider
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0, valfmt='%d')

#plots
# Both lines start out showing frame 0; the trailing comma unpacks the
# one-element list returned by ax.plot.
ln1, = ax.plot(dfs[0].index, dfs[0]['col1'])
ln2, = ax.plot(dfs[0].index, dfs[0]['col2'], color='black')

##additional code to update the PolyCollections
# col2 levels highlighted by the red, blue and green bands respectively.
val_r = 5
val_b = 8
val_g = 7

def update_collection(collection, value, frame=0):
    """Redraw ``collection`` as one rectangle marking ``value`` in a frame.

    The rectangle spans, in x, from the smallest to the largest index of
    ``dfs[frame]`` whose ``col2`` equals ``value`` and, in y, from
    ``value - 0.5`` to ``value + 0.5``. If no row matches, the collection
    receives a single polygon with zero vertices.

    Parameters
    ----------
    collection : object with a ``set_verts`` method (a PolyCollection here)
    value : the ``col2`` level to highlight
    frame : key into ``dfs`` selecting the frame to read (default 0)
    """
    xs = np.asarray(dfs[frame].index)
    ys = np.asarray(dfs[frame]['col2'])
    hits = xs[ys == value]

    # Check for "no matching rows" explicitly rather than catching the
    # ValueError np.min raises on an empty array: an unrelated error then
    # propagates as-is instead of being masked by a failing ``finally``.
    if hits.size == 0:
        verts = np.zeros((0, 2))
    else:
        minx = np.min(hits)
        maxx = np.max(hits)
        miny = value - 0.5
        maxy = value + 0.5
        verts = np.array([[minx, miny], [maxx, miny], [maxx, maxy], [minx, maxy]])

    collection.set_verts([verts])

#artists

# Each PolyCollection below replaces one
# ax.fill_between(..., where=col2 == value, ...) call of the question:
# it starts empty, is added to the axes, and is then filled from frame 0.
reds = PolyCollection([], facecolors=['r'], alpha=0.5)
ax.add_collection(reds)
update_collection(reds, val_r)

blues = PolyCollection([], facecolors=['b'], alpha=0.5)
ax.add_collection(blues)
update_collection(blues, val_b)

greens = PolyCollection([], facecolors=['g'], alpha=0.5)
ax.add_collection(greens)
update_collection(greens, val_g)

# ax.vlines returns a LineCollection; keep the handle so the stems can be
# moved when the frame changes.
stems = ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'], color='black')

#update plots
def update(val):
    """Slider callback: redraw every artist for the selected frame.

    ``val`` is the value passed by the slider; the frame number is read
    from ``sframe.val`` and floored to an integer.
    """
    frame = int(np.floor(sframe.val))
    df = dfs[frame]

    ln1.set_ydata(df['col1'])
    ln2.set_ydata(df['col2'])
    ax.set_title('Frame ' + str(frame))

    ##updating the PolyCollections:
    update_collection(reds, val_r, frame)
    update_collection(blues, val_b, frame)
    update_collection(greens, val_g, frame)

    ##updating the vlines: one segment from (x, 0) to (x, col3) per row
    stems.set_segments([[(x, 0), (x, y)] for x, y in zip(df.index, df['col3'])])

    plt.draw()

#connect callback to slider
sframe.on_changed(update)
plt.show()

Each of the three PolyCollections (reds, blues, and greens) has only four vertices (the corners of a rectangle), which are determined from the given data (this is done in update_collection). The result looks like this:

Tested in Python 3.5

相关文章