It is my first post on StackOverflow.
I am writing a Mayavi Python program. Could anybody tell me how to update/modify the color of a point interactively? For example, in points3d(), changing the color of a point in real-time when I interactively modify its position.
I tried to do something under @on_trait_change, but it doesn't work: the color does not change.
The following is my code:
import mayavi
import mayavi.mlab
from numpy import arange, pi, cos, sin
from traits.api import HasTraits, Range, Instance, \
from traitsui.api import View, Item, HGroup
from mayavi.core.api import PipelineBase
from mayavi.core.ui.api import MayaviScene, SceneEditor, \
def luc_func(x, y, z):
    """Return the scalar value of a point as the sum of its coordinates.

    Passed as the fourth argument to ``points3d`` to supply per-point
    scalars. Works element-wise on arrays as well as on plain numbers,
    because it only uses ``+``.
    """
    return x + y + z
class Visualization(HasTraits):
x1 = Range(1, 30, 5)
z1 = Range(1, 30, 5)
scene = Instance(MlabSceneModel, ())
def __init__(self):
# Do not forget to call the parent's __init__
z = [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]
y = [1,1,1,1,1,2,2,2,2,2,3,3,3,3,3,4,4,4,4,4,5,5,5,5,5]
x = [1,2,3,4,5,1,2,3,4,5,1,2,3,4,5,1,2,3,4,5,1,2,3,4,5]
self.plot = self.scene.mlab.points3d(x, y, z, luc_func, scale_mode = 'none')
#self.plot2 = self.scene.mlab.points3d(z, x, y, color = (0, 0, 1))
def update_plot(self):
x = [1,2,3,4,self.x1,1,2,3,4,self.x1,1,2,3,4,self.x1,1,2,3,4,self.x1,1,2,3,4,self.x1]
z = [1,1,1,1,self.z1,1,1,1,1,self.z1,1,1,1,1,self.z1,1,1,1,1,self.z1,1,1,1,1,self.z1]
luc_func = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,self.z1]
self.plot.mlab_source.reset(x = x, z = z, luc_func = luc_func)
#self.plot2.mlab_source.set(y = y, z = z)
# the layout of the dialog created
view = View(Item('scene', editor=SceneEditor(scene_class=MayaviScene),
height=250, width=300, show_label=False),
'_', 'x1', "z1",
visualization = Visualization()
Thanks for your help!
I have noticed a bug in the interactivity of points3d very similar to what you are describing here. I don't know exactly what is the origin of this bug but I regularly use the following workaround. The basic idea is to avoid mlab.points3d and instead call mlab.pipeline.glyph directly, as in:
def virtual_points3d(coords, figure=None, scale_factor=None, color=None,
                     name=None):
    """Draw sphere glyphs at ``coords`` without going through mlab.points3d.

    Builds a ``scalar_scatter`` source from the three coordinate columns and
    attaches a glyph module to it directly.

    Parameters:
        coords: sequence of (x, y, z) rows; converted with ``np.array`` and
            indexed by column.
        figure: forwarded to both ``scalar_scatter`` and ``glyph``.
        scale_factor: forwarded to ``glyph`` only when it is not None, so the
            library default applies otherwise.
        color: forwarded to ``glyph``.
        name: forwarded to ``glyph``.

    Returns:
        Whatever ``mlab.pipeline.glyph`` returns.
    """
    c = np.array(coords)
    source = mlab.pipeline.scalar_scatter(c[:, 0], c[:, 1], c[:, 2],
                                          figure=figure)
    glyph_kwargs = dict(scale_mode='none', mode='sphere', figure=figure,
                        color=color, name=name)
    # Only pass scale_factor when the caller chose one; passing None
    # explicitly could be rejected by the glyph factory.
    if scale_factor is not None:
        glyph_kwargs['scale_factor'] = scale_factor
    return mlab.pipeline.glyph(source, **glyph_kwargs)
Later you can change the colors by referring to the vtk object directly, rather than the mayavi trait that isn't connected properly:
glyph = virtual_points3d(coords)
glyph.mlab_source.dataset.point_data.scalars = new_values
I am trying to understand the code below, which shows the default TSP path on a picture. I understand most of it except for the polyfit_plot() function. I understand each call inside it separately, but I don't see what the function as a whole contributes. I have even tried deleting the function and the result is exactly the same, and I can't find where the function is called. Can someone explain it to me?
import numpy as np
import math
import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
united_states_map = mpimg.imread(r"C:\Users\user\Downloads\archive\united_states_map.png")
def show_cities(path, w=12, h=8):
if isinstance(path, dict): path = list(path.values())
if isinstance(path[0][0], str): path = [ item[1] for item in path ]
for x0, y0 in path:
plt.plot(x0, y0, 'y*', markersize=15) # y* = yellow star for starting point
fig = plt.gcf()
fig.set_size_inches([w, h])
def show_path(path, starting_city=None, w=12, h=8):
if isinstance(path, dict): path = list(path.values())
if isinstance(path[0][0], str): path = [ item[1] for item in path ]
starting_city = starting_city or path[0]
x, y = list(zip(*path))
#_, (x0, y0) = starting_city
(x0, y0) = starting_city
#plt.plot(x0, y0, 'y*', markersize=15) # y* = yellow star for starting point
plt.plot(x + x[:1], y + y[:1]) # include the starting point at the end of path
fig = plt.gcf()
fig.set_size_inches([w, h])
def polyfit_plot(x, y, deg, **kwargs):
    """Fit a polynomial of degree ``deg`` to (x, y) and plot data plus fit.

    The data points are drawn as circles and the fitted polynomial as a
    line sampled between ``x[0]`` and ``x[-1]``. The x-limits are padded by
    one unit on each side, and the plot title is set to the fitted
    polynomial written out with coefficients rounded to one decimal.

    Parameters:
        x, y: sample coordinates, passed to ``np.polyfit``.
        deg: polynomial degree.
        **kwargs: forwarded unchanged to ``np.polyfit``.
    """
    coefficients = np.polyfit(x, y, deg, **kwargs)
    poly = np.poly1d(coefficients)
    new_x = np.linspace(x[0], x[-1])
    new_y = poly(new_x)
    plt.plot(x, y, "o", new_x, new_y)
    plt.xlim([x[0] - 1, x[-1] + 1])

    # np.polyfit returns the highest power first; walk the coefficients from
    # the constant term upwards so that ``p`` is the power of each term.
    terms = []
    for p, c in enumerate(reversed(coefficients)):
        term = str(round(c, 1))
        if p == 1:
            term += 'x'
        if p >= 2:
            term += 'x^' + str(p)
        # Without this append the title was always the empty string.
        terms.append(term)
    # Reverse again so the title reads from the highest power down.
    plt.title(" + ".join(reversed(terms)))
cities = { "Oklahoma City": (392.8, 356.4), "Montgomery": (559.6, 404.8), "Saint Paul": (451.6, 186.0), "Trenton": (698.8, 239.6), "Salt Lake City": (204.0, 243.2), "Columbus": (590.8, 263.2), "Austin": (389.2, 448.4), "Phoenix": (179.6, 371.2), "Hartford": (719.6, 205.2), "Baton Rouge": (489.6, 442.0), "Salem": (80.0, 139.2), "Little Rock": (469.2, 367.2), "Richmond": (673.2, 293.6), "Jackson": (501.6, 409.6), "Des Moines": (447.6, 246.0), "Lansing": (563.6, 216.4), "Denver": (293.6, 274.0), "Boise": (159.6, 182.8), "Raleigh": (662.0, 328.8), "Atlanta": (585.6, 376.8), "Madison": (500.8, 217.6), "Indianapolis": (548.0, 272.8), "Nashville": (546.4, 336.8), "Columbia": (632.4, 364.8), "Providence": (735.2, 201.2), "Boston": (738.4, 190.8), "Tallahassee": (594.8, 434.8), "Sacramento": (68.4, 254.0), "Albany": (702.0, 193.6), "Harrisburg": (670.8, 244.0) }
cities = list(sorted(cities.items()))
It seems that the function polyfit_plot is unused in the code, so it is never run and as you say it does not affect the outcome.
I'm trying to make a simple button that "resets" widgets to certain default values. I'm using the @interact decorator in a JupyterLab environment. The problem is that the widget identifiers are shadowed by the float parameters of the same names inside the decorated function, so I can no longer access the widgets within that scope. Here is a short example (not working):
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, Button
#interact(starts_at=(0, np.pi*0.9, np.pi*0.1), ends_at=(np.pi, 2*np.pi, np.pi*0.1))
def plot_graph(starts_at=0, ends_at=2*np.pi):
def on_button_clicked(_):
# instructions when clicking the button (this cannot work)
starts_at = 0
ends_at = 2*np.pi
button = Button(description="Reset")
f = lambda x : sum(1/a*np.sin(a*x + np.pi/a) for a in range(1,6))
x = np.linspace(0, 2*np.pi, 1000)
plt.plot(x, f(x))
plt.xlim([starts_at, ends_at])
Does anybody know how to send to the scope of the decorated function a reference to the original widget objects? I'll be accepting also simple ways of implementing a button to reset those sliders.
Edit: corrected text flow
To accomplish this you'll have to use the more manual interactive_output function. That function allows you to pre-create the widgets and then pass them in:
import ipywidgets as widgets
import numpy as np
import matplotlib.pyplot as plt
start_slider = widgets.FloatSlider(
val = 0,
min = 0,
max = np.pi*0.9,
step = np.pi*0.1,
description = 'Starts at'
end_slider = widgets.FloatSlider(
val = np.pi,
min = np.pi,
max = 2*np.pi,
step = np.pi*0.1,
description = 'Ends at'
def on_button_clicked(_):
start_slider.value = 0
end_slider.value = 2*np.pi
button = Button(description="Reset")
def plot_graph(starts_at=0, ends_at=2*np.pi):
f = lambda x : sum(1/a*np.sin(a*x + np.pi/a) for a in range(1,6))
x = np.linspace(0, 2*np.pi, 1000)
plt.plot(x, f(x))
plt.xlim([starts_at, ends_at])
display(widgets.VBox([start_slider, end_slider, button]))
widgets.interactive_output(plot_graph, {'starts_at': start_slider, 'ends_at':end_slider})
However, this will regenerate the plot entirely every time you update it, which can lead to a choppy experience. So you can also rewrite this to use matplotlib methods like .set_data if you use an interactive matplotlib backend in the notebook. For instance, if you use ipympl you can follow the examples in this example notebook.
Via another library
I wrote a library mpl-interactions to make it easier to control matplotlib plots using ipywidgets sliders. It provides a function analogous to ipywidgets.interact in that it handles creating the widgets for you, but it has the advantage of being matplotlib focused so all you need to provide is the data. More about the differences to ipywidgets here
%matplotlib ipympl
import mpl_interactions.ipyplot as iplt
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
def plot_graph(starts_at=0, ends_at=2*np.pi):
x = np.linspace(starts_at, ends_at, 1000)
f = lambda x : sum(1/a*np.sin(a*x + np.pi/a) for a in range(1,6))
return np.array([x, f(x)]).T
fig, ax = plt.subplots()
button = widgets.Button(description = 'reset')
controls = iplt.plot(plot_graph, starts_at = (0, np.pi), ends_at = (np.pi, 2*np.pi), xlim='auto', parametric=True)
def on_click(event):
for hbox in controls.controls.values():
slider = hbox.children[0]
slider.value = slider.min
I managed to get it working without the @interact decorator, but I'm not happy with the final result. So I'm still willing to accept an answer from anyone who can provide a clearer, more Pythonic version of this.
Anyway, here is the working code:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, Button, FloatSlider
def plot_graph(starts_at, ends_at):
f = lambda x : sum(1/a*np.sin(a*x + np.pi/a) for a in range(1,6))
x = np.linspace(0, 2*np.pi, 1000)
plt.plot(x, f(x))
plt.xlim([starts_at, ends_at])
starts_at = FloatSlider(min=0, max=np.pi*0.9, value=0, step=np.pi*0.1)
ends_at = FloatSlider(min=np.pi, max=2*np.pi, value=2*np.pi, step=np.pi*0.1)
def on_button_clicked(_):
starts_at.value = 0
ends_at.value = 2*np.pi
button = Button(description="Reset")
_ = interact(plot_graph, starts_at=starts_at, ends_at=ends_at)
EDIT: NEW APPROACH FROM THIS POINT ==============================
I'm accepting @Ianhi's answer as the correct one because he pointed out the issues that need to be considered in the context of my problem. Thanks!
Anyway, I'm posting here the final scheme I'm using which is simple enough for my needs and I can reuse my reset button in all my interacts:
# Preamble ----
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, Button, FloatSlider
def reset_button(defaults=None):
    """Create and show a "Reset" button that restores widget values.

    Parameters:
        defaults: mapping of widget object -> value. On every click each
            widget's ``value`` attribute is set to its mapped value.
            ``None`` (the default) means nothing is reset.

    Returns:
        The created ``Button``, so callers can place it in a layout.
    """
    # Avoid a shared mutable default argument.
    defaults = {} if defaults is None else defaults

    def on_button_clicked(_):
        for widget, value in defaults.items():
            widget.value = value

    button = Button(description='Reset')
    # The handler must be registered, otherwise clicking does nothing.
    button.on_click(on_button_clicked)
    # ``display`` is the notebook-provided function (available without an
    # import in Jupyter); it makes the button visible.
    display(button)
    return button
# Code ----
slider1 = FloatSlider(min=0, max=np.pi*0.9, value=0, step=np.pi*0.1)
slider2 = FloatSlider(min=np.pi, max=2*np.pi, value=2*np.pi, step=np.pi*0.1)
reset_button(defaults={slider1: 0, slider2: 2*np.pi})
#interact(starts_at=slider1, ends_at=slider2)
def plot_graph(starts_at, ends_at):
f = lambda x : sum(1/a*np.sin(a*x + np.pi/a) for a in range(1,6))
x = np.linspace(0, 2*np.pi, 1000)
plt.plot(x, f(x))
plt.xlim([starts_at, ends_at])
Building on @iperetta's answer, you can create a simple decorator that adds a reset button for each use:
Reset button as per above answer:
def reset_button(defaults=None):
    """Create and show a "Reset" button that restores widget values.

    Parameters:
        defaults: mapping of widget object -> value. On every click each
            widget's ``value`` attribute is set to its mapped value.
            ``None`` (the default) means nothing is reset.

    Returns:
        The created ``widgets.Button``.
    """
    # Avoid a shared mutable default argument.
    defaults = {} if defaults is None else defaults

    def on_button_clicked(_):
        for widget, value in defaults.items():
            widget.value = value

    button = widgets.Button(description='Reset')
    # The handler must be registered, otherwise clicking does nothing.
    button.on_click(on_button_clicked)
    # ``display`` is the notebook-provided function (available without an
    # import in Jupyter); it makes the button visible.
    display(button)
    return button
def interact_plus_reset(**_widgets):
    """Decorator factory: ``interact`` plus a button that resets the widgets.

    Parameters:
        **_widgets: keyword name -> widget, exactly as they would be passed
            to ``widgets.interact``. Each widget's current ``value`` is
            recorded as the value the reset button restores.

    Returns:
        A decorator that wires the decorated function to the widgets via
        ``widgets.interact`` and returns the wrapped function.
    """
    from functools import wraps

    # Snapshot the initial values now, before the user moves anything.
    default_vals = {wid: wid.value for wid in _widgets.values()}
    reset_button(defaults=default_vals)

    def wrap(func):
        @wraps(func)
        @widgets.interact(**_widgets)
        def inner(*args, **kwargs):
            return func(*args, **kwargs)
        return inner
    return wrap
Then use it as below:
a = widgets.FloatSlider(min=1, max=2000, value=35, step=1),
b = widgets.FloatSlider(min=1, max=2000, value=1000, step=1),
def run(a, b):
print(a + b)
I also posted in the pyqtgraph forum here.
My overall goal is to have several clickable regions overlaid on an image, and if the plot boundary of any region is clicked I get a signal with the ID of that region. Something like this:
If I use only one PlotDataItem with nan-separated curves then each boundary sends the same signal. However, using a separate PlotDataItem for each boundary makes the application extremely sluggish.
I ended up subclassing ScatterPlotItem and rewriting the pointsAt function, which does what I want. The problem now is I can't figure out the appropriate way to change the ScatterPlotItem's boundingRect. Am I on the right approach? Is there a better way of doing this?
import numpy as np
import pyqtgraph as pg
from pyqtgraph.Qt import QtCore, QtGui
class CustScatter(pg.ScatterPlotItem):
def pointsAt(self, pos: QtCore.QPointF):
The default implementation only checks a square around each spot. However, this is not
precise enough for my needs. It also triggers when clicking *inside* the spot boundary,
which I don't want.
pts = []
for spot in self.points(): # type: pg.SpotItem
symb = QtGui.QPainterPath(spot.symbol())
stroker = QtGui.QPainterPathStroker()
mousePath = stroker.createStroke(symb)
# Only trigger when clicking a boundary, not the inside of the shape
if mousePath.contains(pos):
return pts[::-1]
"""Make some sample data"""
tri = np.array([[0,2.3,0,1,4,5,0], [0,4,4,8,8,3,0]]).T
tris = []
xyLocs = []
datas = []
for ii in np.arange(0, 16, 5):
curTri = tri + ii
def ptsClicked(item, pts):
print(f'ID {pts[0].data()} Clicked!')
"""Logic for making spot shapes from a list of (x,y) vertices"""
def makeSymbol(verts: np.ndarray):
outSymbol = QtGui.QPainterPath()
symPath = pg.arrayToQPath(*verts.T)
# From pyqtgraph.examples for plotting text
br = outSymbol.boundingRect()
tr = QtGui.QTransform()
tr.translate(-br.x(), -br.y())
outSymbol =
return outSymbol
app = pg.mkQApp()
pg.setConfigOption('background', 'w')
symbs = []
for xyLoc, tri in zip(xyLocs, tris):
"""Create the scatterplot"""
xyLocs = np.vstack(xyLocs)
tri2 = pg.PlotDataItem()
scat = CustScatter(*xyLocs.T, symbol=symbs, data=datas, connect='finite',
pxMode=False, brush=None, pen=pg.mkPen(width=5), size=1)
# Now each 'point' is one of the triangles, hopefully
"""Construct GUI window"""
w = pg.PlotWindow()
plt: pg.PlotItem = w.plotItem
plt.showGrid(True, True, 1)
Solved! It turns out unless you specify otherwise, the boundingRect of each symbol in the dataset is assumed to be 1 and that the spot size is the limiting factor. After overriding measureSpotSizes as well my solution works:
import numpy as np
import pyqtgraph as pg
from pyqtgraph.Qt import QtCore, QtGui
class CustScatter(pg.ScatterPlotItem):
    """ScatterPlotItem whose spots react to clicks on their outline only."""

    def pointsAt(self, pos: QtCore.QPointF):
        """Return the spots whose symbol outline contains ``pos``.

        The default implementation only checks a square around each spot.
        However, this is not precise enough for my needs. It also triggers
        when clicking *inside* the spot boundary, which I don't want.

        The result is returned in reverse iteration order of
        ``self.points()``.
        """
        pts = []
        for spot in self.points():  # type: pg.SpotItem
            symb = QtGui.QPainterPath(spot.symbol())
            # NOTE(review): the symbol path is assumed to be relative to the
            # spot position, so it is moved there before hit-testing - confirm.
            symb.translate(spot.pos())
            stroker = QtGui.QPainterPathStroker()
            mousePath = stroker.createStroke(symb)
            # Only trigger when clicking a boundary, not the inside of the shape
            if mousePath.contains(pos):
                pts.append(spot)
        return pts[::-1]

    def measureSpotSizes(self, dataSet):
        """Override the method so that it takes symbol size into account.

        For every record the size is taken as twice the larger side of the
        symbol's bounding rectangle, and the running maxima
        ``_maxSpotWidth`` / ``_maxSpotPxWidth`` are updated. The cached
        bounds are cleared at the end.
        """
        for rec in dataSet:
            ## keep track of the maximum spot size and pixel size
            symbol, size, pen, brush = self.getSpotOpts(rec)
            br = symbol.boundingRect()
            size = max(br.width(), br.height()) * 2
            width = 0
            pxWidth = 0
            if self.opts['pxMode']:
                pxWidth = size + pen.widthF()
            else:
                width = size
                # A cosmetic pen adds to the pixel width; otherwise the pen
                # width adds to the data-space width.
                if pen.isCosmetic():
                    pxWidth += pen.widthF()
                else:
                    width += pen.widthF()
            self._maxSpotWidth = max(self._maxSpotWidth, width)
            self._maxSpotPxWidth = max(self._maxSpotPxWidth, pxWidth)
        self.bounds = [None, None]
"""Make some sample data"""
tri = np.array([[0,2.3,0,1,4,5,0], [0,4,4,8,8,3,0]]).T
tris = []
xyLocs = []
datas = []
for ii in np.arange(0, 16, 5):
curTri = tri + ii
def ptsClicked(item, pts):
print(f'ID {pts[0].data()} Clicked!')
"""Logic for making spot shapes from a list of (x,y) vertices"""
def makeSymbol(verts: np.ndarray):
plotVerts = verts - verts.min(0, keepdims=True)
symPath = pg.arrayToQPath(*plotVerts.T)
return symPath
app = pg.mkQApp()
pg.setConfigOption('background', 'd')
symbs = []
for xyLoc, tri in zip(xyLocs, tris):
"""Create the scatterplot"""
xyLocs = np.vstack(xyLocs)
tri2 = pg.PlotDataItem()
scat = CustScatter(*xyLocs.T, symbol=symbs, data=datas, connect='finite',
pxMode=False, brush=None, pen=pg.mkPen(width=5), size=1)
# Now each 'point' is one of the triangles, hopefully
"""Construct GUI window"""
w = pg.PlotWindow()
plt: pg.PlotItem = w.plotItem
plt.showGrid(True, True, 1)
I am trying to update a 3D plot using matplotlib. I am collecting data using ROS. I want to update the plot as I get data. I have looked around and found this,
Dynamically updating plot in matplotlib
but I cannot get it to work. I am very new to Python and do not fully understand how it works yet. I apologize if my code is messy.
I keep getting this error:
[ERROR] [WallTime: 1435801577.604410] bad callback: <function usbl_move at 0x7f1e45c4c5f0>
Traceback (most recent call last):
File "/opt/ros/indigo/lib/python2.7/dist-packages/rospy/", line 709, in _invoke_callback
cb(msg, cb_args)
File "/home/nathaniel/simulation/src/move_videoray/src/", line 63, in usbl_move
if filter(pos.pose.position.x,pos.pose.position.y,current.position.z):
File "/home/nathaniel/simulation/src/move_videoray/src/", line 127, in filter
File "/usr/lib/pymodules/python2.7/matplotlib/", line 555, in draw
File "/usr/lib/pymodules/python2.7/matplotlib/backends/", line 349, in draw
tkagg.blit(self._tkphoto, self.renderer._renderer, colormode=2)
File "/usr/lib/pymodules/python2.7/matplotlib/backends/", line 13, in blit"PyAggImagePhoto", photoimage, id(aggimage), colormode, id(bbox_array))
RuntimeError: main thread is not in main loop
This is the code I am trying to run
#!/usr/bin/env python
Ths program moves the videoray model in rviz using
data from the /usble_pose node
based on "Using urdf with robot_state_publisher" tutorial
import rospy
import roslib
import math
import tf
#import outlier_filter
from geometry_msgs.msg import Twist, Vector3, Pose, PoseStamped, TransformStamped
from matplotlib import matplotlib_fname
from mpl_toolkits.mplot3d import Axes3D
import sys
from matplotlib.pyplot import plot
from numpy import mean, std
import matplotlib as mpl
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import matplotlib
mpl.rc("savefig", dpi=150)
import matplotlib.animation as animation
import time
#filter stuff
#window size
n = 10
#make some starting values
#random distance
md =[random.random() for _ in range(0, n)]
#random points
x_list = [random.random() for _ in range(0, n)]
y_list =[random.random() for _ in range(0, n)]
#set up graph
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_title('XY Outlier rejection \n with Mahalanobis distance and rolling mean3')
#set the robot at the center
#//move the videoray using the data from the /pose_only node
def usbl_move(pos,current):
broadcaster = tf.TransformBroadcaster()
if filter(pos.pose.position.x,pos.pose.position.y,current.position.z):
current.position.x = pos.pose.position.x
current.position.y = pos.pose.position.y
broadcaster.sendTransform( (current.position.x,current.position.y,current.position.z),
(current.orientation.x,current.orientation.y,current.orientation.z,current.orientation.w),, "odom", "body" )
#move the videoray using the data from the /pose_only node
def pose_move(pos,current):
#pos.position.z is in kPa, has to be convereted to depth
# P = P0 + pgz ----> pos.position.z = P0 + pg*z_real
z_real = -1*(pos.position.z -101.325)/9.81;
#update the movement
broadcaster = tf.TransformBroadcaster()
current.orientation.x = pos.orientation.x
current.orientation.y = pos.orientation.y
current.orientation.z = pos.orientation.z
current.orientation.w = pos.orientation.w
current.position.z = z_real
broadcaster.sendTransform( (current.position.x,current.position.y,current.position.z),
(current.orientation.x,current.orientation.y,current.orientation.z,current.orientation.w),, "odom", "body" )
#call the fitle the date
def filter(x, y, z):
    """Return True when the newest (x, y) fix is not an outlier.

    The fix is pushed into the module-level rolling windows ``x_list`` and
    ``y_list``. Its Mahalanobis distance to the window mean is computed and
    pushed into the rolling window ``md``. The fix is accepted when that
    distance is below ``mean(md) + 1.5 * std(md)``.

    Parameters:
        x, y: coordinates of the new fix.
        z: accepted for the caller's convenience; not used by the test.

    Returns:
        bool: True if the fix is considered good.

    Raises:
        numpy.linalg.LinAlgError: if the window covariance matrix is singular.

    NOTE(review): the window-update statements were missing from the source
    and are reconstructed here as "append newest, drop oldest" - confirm.
    """
    # update the window: newest sample in, oldest out, so the length is kept
    x_list.append(x)
    y_list.append(y)
    del x_list[0]
    del y_list[0]

    # inverse of the covariance matrix of the windowed samples
    v = np.linalg.inv(np.cov(x_list, y_list, rowvar=0))
    # offset of the newest sample from the window mean
    diff = np.array([x_list[-1] - np.mean(x_list),
                     y_list[-1] - np.mean(y_list)])
    # Mahalanobis distance: sqrt(d^T * V^-1 * d)
    dis = np.sqrt(np.dot(np.dot(diff, v), diff))

    # update the distance window the same way
    md.append(dis)
    del md[0]

    # mean and standard deviation of the distance window
    mu = np.mean(md)
    sigma = np.std(md)
    # threshold deciding whether the sample is an outlier
    return bool(dis < mu + 1.5 * sigma)
if __name__ == '__main__':
#set up the node
rospy.init_node('move_unfiltered', anonymous=True)
#make a broadcaster foir the tf frame
broadcaster = tf.TransformBroadcaster()
#make intilial values
current = Pose()
current.position.x = 0
current.position.y = 0
current.position.z = 0
current.orientation.x = 0
current.orientation.y = 0
current.orientation.z = 0
current.orientation.w = 0
#send the tf frame
broadcaster.sendTransform( (current.position.x,current.position.y,current.position.z),
(current.orientation.x,current.orientation.y,current.orientation.z,current.orientation.w),, "odom", "body" )
#listen for information
rospy.Subscriber("/usbl_pose", PoseStamped, usbl_move,current)
rospy.Subscriber("/pose_only", Pose, pose_move, current);
Since this is an old post that still seems to be active in the community, I am going to provide a general example of how to do real-time plotting. Here I use matplotlib's FuncAnimation function.
import matplotlib.pyplot as plt
import rospy
import tf
from nav_msgs.msg import Odometry
from tf.transformations import quaternion_matrix
import numpy as np
from matplotlib.animation import FuncAnimation
class Visualiser:
    """Collect yaw angles from odometry messages and plot them live.

    ``odom_callback`` is meant to be used as the subscriber callback and
    only stores data. ``plot_init`` and ``update_plot`` are meant to be
    handed to ``FuncAnimation``, which does the drawing.
    """

    def __init__(self):
        # plt.subplots() returns a (figure, axes) pair; unpacking it into a
        # single name raises ValueError, so both are kept.
        self.fig, self.ax = plt.subplots()
        self.ln, = plt.plot([], [], 'ro')
        self.x_data, self.y_data = [], []

    def plot_init(self):
        """Set the axis limits and return the line artist."""
        # NOTE(review): limits reconstructed from the garbled original
        # ("..., 10000)" and "..., 7)") - confirm the lower bounds.
        self.ax.set_xlim(0, 10000)
        self.ax.set_ylim(-7, 7)
        return self.ln

    def getYaw(self, pose):
        """Return the third Euler angle of ``pose.orientation``."""
        quaternion = (pose.orientation.x, pose.orientation.y,
                      pose.orientation.z, pose.orientation.w)
        euler = tf.transformations.euler_from_quaternion(quaternion)
        yaw = euler[2]
        return yaw

    def odom_callback(self, msg):
        """Store the yaw of ``msg.pose.pose`` against a running sample index."""
        yaw_angle = self.getYaw(msg.pose.pose)
        self.y_data.append(yaw_angle)
        x_index = len(self.x_data)
        self.x_data.append(x_index + 1)

    def update_plot(self, frame):
        """Push the collected samples into the line artist and return it."""
        self.ln.set_data(self.x_data, self.y_data)
        return self.ln
vis = Visualiser()
sub = rospy.Subscriber('/dji_sdk/odometry', Odometry, vis.odom_callback)
ani = FuncAnimation(vis.fig, vis.update_plot, init_func=vis.plot_init)
Using the minimal example below, the line plot of a large (some 110k points) plot I get (with python 2.7, numpy 1.5.1, chaco/enable/traits 4.3.0) is this:
However, that is bizarre, because it is a line plot, and there shouldn't be any filled areas in there? Especially since the data is sawtooth-ish signal? It's as if there is a line at y~=37XX, above which there is color filling?! But sure enough, if I zoom into an area, I get the rendering I expect - without the unexpected fill:
Is this a bug - or is there something I'm doing wrong? I tried to use use_downsampling, but it makes no difference...
The test code:
import numpy as np
import numpy.random as npr
from pprint import pprint
from traits.api import HasTraits, Instance
from chaco.api import Plot, ArrayPlotData, VPlotContainer
from traitsui.api import View, Item
from enable.component_editor import ComponentEditor
from import PanTool, BetterSelectingZoom
tlen = 112607
alr = npr.randint(0, 4000, tlen)
tx = np.arange(0.0, 30.0-0.00001, 30.0/tlen)
ty = np.arange(0, tlen, 1) % 10000 + alr
class ChacoTest(HasTraits):
container = Instance(VPlotContainer)
traits_view = View(
Item('container', editor=ComponentEditor(), show_label=False),
width=800, height=500, resizable=True,
title="Chaco Test"
def __init__(self):
super(ChacoTest, self).__init__()
self.plotdata = ArrayPlotData(x = tx, y = ty)
self.plotobj = Plot(self.plotdata)
self.plotA = self.plotobj.plot(("x", "y"), type="line", color=(0,0.99,0), spacing=0, padding=0, alpha=0.7, use_downsampling=True)
self.container = VPlotContainer(self.plotobj, spacing=5, padding=5, bgcolor="lightgray")
#~ container.add(plot)
if __name__ == "__main__":
I am able to reproduce the error and talking with John Wiggins (maintainer of Enable), it is a bug in kiva (which chaco uses to paint on the screen):
The good news is that this bug is in only one of the kiva backends you can use. So to work around the issue, you can run your script with a different backend:
ETS_TOOLKIT=qt4.qpainter python <NAME OF YOUR SCRIPT>
if you use qpainter or quartz, the plot looks (on my machine) as expected. If you choose qt4.image (the Agg backend), you will reproduce the issue. Unfortunately, the Agg backend is the default one. To change that, you can set the ETS_TOOLKIT environment variable to that value:
export ETS_TOOLKIT=qt4.qpainter
The bad news is that fixing this isn't going to be an easy task. Please feel free to report the bug on GitHub if you want to be involved in this. If you don't, I will log it in the next couple of days. Thanks for reporting it!
Just a note - I found this:
[Enthought-Dev] is chaco faster than matplotlib
I recall reading somewhere that you are expected to implement the
_downsample method because the optimal algorithm depends on the type
of data you're collecting.
And as I couldn't find any examples with _downsample implementation other than referred in that post, which isn't standalone - I tried and built a standalone example, included below.
The example basically has messed up drag and zoom, (plot disappears if you go out of range, or stretches upon a drag move) - and it starts zoomed in; but it is possible to zoom it out in the range shown in the OP - and then it displays the exact same plot rendering problem. So downsampling isn't the solution per se, so this is likely a bug?
import numpy as np
import numpy.random as npr
from pprint import pprint
from traits.api import HasTraits, Instance
from chaco.api import Plot, ArrayPlotData, VPlotContainer
from traitsui.api import View, Item
from enable.component_editor import ComponentEditor
from import PanTool, BetterSelectingZoom
from chaco.api import BaseXYPlot, LinearMapper, AbstractPlotData
from enable.api import black_color_trait, LineStyle
from traits.api import Float, Enum, Int, Str, Trait, Event, Property, Array, cached_property, Bool, Dict
from chaco.abstract_mapper import AbstractMapper
from chaco.abstract_data_source import AbstractDataSource
from chaco.array_data_source import ArrayDataSource
from chaco.data_range_1d import DataRange1D
tlen = 112607
alr = npr.randint(0, 4000, tlen)
tx = np.arange(0.0, 30.0-0.00001, 30.0/tlen)
ty = np.arange(0, tlen, 1) % 10000 + alr
class ChacoTest(HasTraits):
container = Instance(VPlotContainer)
traits_view = View(
Item('container', editor=ComponentEditor(), show_label=False),
width=800, height=500, resizable=True,
title="Chaco Test"
downsampling_cutoff = Int(4)
def __init__(self):
super(ChacoTest, self).__init__()
self.plotdata = ArrayPlotData(x = tx, y = ty)
self.plotobj = TimeSeriesPlot(self.plotdata)
self.plotobj.setplotranges("x", "y")
self.container = VPlotContainer(self.plotobj, spacing=5, padding=5, bgcolor="lightgray")
# decimate from:
def decimate(data, screen_width, downsampling_cutoff=4, mode='extremes'):
    """Downsample ``data`` for display when it is much wider than the screen.

    The downsampling factor is ``floor(data.shape[-1] / screen_width / 4)``.
    When it exceeds ``downsampling_cutoff`` the work is delegated to the
    module-level function named ``'decimate_' + mode``; otherwise ``data``
    is returned unchanged.

    Parameters:
        data: array-like with a ``shape`` attribute; the last axis is the
            sample axis.
        screen_width: available width in screen units. Non-positive widths
            return ``data`` unchanged instead of dividing by zero.
        downsampling_cutoff: factor above which decimation is applied.
        mode: suffix of the ``decimate_*`` function to call.

    Returns:
        ``data`` itself, or whatever the selected ``decimate_*`` function
        returns.
    """
    if screen_width <= 0:
        return data
    data_width = data.shape[-1]
    # np.floor yields a float, but the decimate_* helpers use the factor as
    # a reshape dimension and modulo operand, so hand them a real int.
    downsample = int(np.floor((data_width / screen_width) / 4.))
    if downsample > downsampling_cutoff:
        return globals()['decimate_' + mode](data, downsample)
    return data
def decimate_extremes(data, downsample):
    """Reduce ``data`` to the per-bin minimum and maximum along the last axis.

    Leading samples that do not fill a whole bin are dropped, then the last
    axis is split into bins of ``downsample`` samples.

    Parameters:
        data: 1-D array, or 2-D array whose rows are decimated independently.
        downsample: bin size; converted to ``int``.

    Returns:
        tuple ``(data_min, data_max)`` with one value per bin (per row for
        2-D input).
    """
    downsample = int(downsample)
    # Drop the remainder at the front so the length is a multiple of the bin.
    offset = data.shape[-1] % downsample
    if data.ndim == 2:
        shape = (len(data), -1, downsample)
    else:
        shape = (-1, downsample)
    binned = data[..., offset:].reshape(shape)
    # The bin contents always live on the last axis after the reshape.
    data_min = binned.min(-1)
    data_max = binned.max(-1)
    return data_min, data_max
def decimate_mean(data, downsample):
    """Reduce ``data`` to the mean of consecutive bins along the first axis.

    Leading samples that do not fill a whole bin are dropped, then the
    first axis is split into bins of ``downsample`` samples and averaged.

    Parameters:
        data: 1-D array, or 2-D array whose columns are averaged
            independently.
        downsample: bin size; converted to ``int``.

    Returns:
        Array of bin means (one row per bin for 2-D input).

    NOTE(review): this bins along axis 0, whereas decimate_extremes bins
    along the last axis - confirm which layout callers pass for 2-D data.
    """
    downsample = int(downsample)
    # Drop the remainder at the front so the length is a multiple of the bin.
    offset = len(data) % downsample
    if data.ndim == 2:
        shape = (-1, downsample, data.shape[-1])
    else:
        shape = (-1, downsample)
    binned = data[offset:].reshape(shape)
    return binned.mean(1)
# based on class from, also
# neurobehavior/cns/chaco_exts/ ;
# + some other code from chaco
class TimeSeriesPlot(BaseXYPlot):
color = black_color_trait
line_width = Float(1.0)
line_style = LineStyle
reference = Enum('most_recent', 'trigger')
traits_view = View("color#", "line_width")
downsampling_cutoff = Int(100)
signal_trait = "updated"
decimate_mode = Str('extremes')
ch_index = Trait(None, Int, None)
# Mapping of data names from to their respective datasources.
datasources = Dict(Str, Instance(AbstractDataSource))
index_mapper = Instance(AbstractMapper)
value_mapper = Instance(AbstractMapper)
def __init__(self, data=None, **kwargs):
super(TimeSeriesPlot, self).__init__(**kwargs)
self._index_mapper_changed(None, self.index_mapper)
self._plot_ui_info = None
def setplotdata(self, data):
if data is not None:
if isinstance(data, AbstractPlotData): = data
elif type(data) in (ndarray, tuple, list): = ArrayPlotData(data)
raise ValueError, "Don't know how to create PlotData for data" \
"of type " + str(type(data))
def setplotranges(self, index_name, value_name):
self.index_name = index_name
self.value_name = value_name
index = self._get_or_create_datasource(index_name)
value = self._get_or_create_datasource(value_name)
if not(self.index_mapper):
imap = LinearMapper()#(range=self.index_range)
self.index_mapper = imap
if not(self.value_mapper):
vmap = LinearMapper()#(range=self.value_range)
self.value_mapper = vmap
if not(self.index_range): self.index_range = DataRange1D() # calls index_mapper
if not(self.value_range): self.value_range = DataRange1D()
self.index_range.add(index) # calls index_mapper!
# now do it (right?):
self.index_mapper = LinearMapper(range=self.index_range)
self.value_mapper = LinearMapper(range=self.value_range)
def _get_or_create_datasource(self, name):
if name not in self.datasources:
data =
if type(data) in (list, tuple):
data = array(data)
if isinstance(data, np.ndarray):
if len(data.shape) == 1:
ds = ArrayDataSource(data, sort_order="none")
elif len(data.shape) == 2:
ds = ImageData(data=data, value_depth=1)
elif len(data.shape) == 3:
if data.shape[2] in (3,4):
ds = ImageData(data=data, value_depth=int(data.shape[2]))
raise ValueError("Unhandled array shape in creating new plot: " \
+ str(data.shape))
elif isinstance(data, AbstractDataSource):
ds = data
raise ValueError("Couldn't create datasource for data of type " + \
self.datasources[name] = ds
return self.datasources[name]
def get_screen_points(self):
return self._downsample()
def _data_changed(self):
self._cache_valid = False
self._screen_cache_valid = False
def _gather_points(self):
if not self._cache_valid:
range = self.index_mapper.range
#if self.reference == 'most_recent':
# values, t_lb, t_ub = self.get_recent_range(range.low, range.high)
# values, t_lb, t_ub = self.get_range(range.low, range.high, -1)
values, t_lb, t_ub =[self.value_name][range.low:range.high], range.low, range.high
#if self.ch_index is None:
# self._cached_data = values
# #self._cached_data = values[:,self.ch_index]
self._cached_data = values
self._cached_data_bounds = t_lb, t_ub
self._cache_valid = True
self._screen_cache_valid = False
def _downsample(self):
if not self._screen_cache_valid:
val_pts = self._cached_data
screen_min, screen_max = self.index_mapper.screen_bounds
screen_width = screen_max-screen_min
values = decimate(val_pts, screen_width, self.downsampling_cutoff,
if type(values) == type(()):
n = len(values[0])
s_val_min = self.value_mapper.map_screen(values[0])
s_val_max = self.value_mapper.map_screen(values[1])
self._cached_screen_data = s_val_min, s_val_max
s_val_pts = self.value_mapper.map_screen(values)
self._cached_screen_data = s_val_pts
n = len(values)
t = np.linspace(*self._cached_data_bounds, num=n)
t_screen = self.index_mapper.map_screen(t)
self._cached_screen_index = t_screen
self._screen_cache_valid = True
return [self._cached_screen_index, self._cached_screen_data]
def _render(self, gc, points):
idx, val = points
if len(idx) == 0:
gc.clip_to_rect(self.x, self.y, self.width, self.height)
#if len(val) == 2:
if type(val) == type(()):
starts = np.column_stack((idx, val[0]))
ends = np.column_stack((idx, val[1]))
gc.line_set(starts, ends)
gc.lines(np.column_stack((idx, val)))
if __name__ == "__main__":