Related
I'm trying to parse tables in a PDF using Camelot. The cells have multiple lines of texts in them, and some have an empty line separating portions of the text:
First line
Second line
Third line
I would expect this to be parsed as First line\nSecond line\n\nThird line (notice the double line breaks), but I get this instead: T\nFirst line\nSecond line\nhird line. The first character after a double-line-break moves to the beginning of the text, and I only get a single line-break instead.
I also tried using tabula, but that one messes up the entire table (the data frame, actually) when there is an empty row in the table, and for some words it also puts a space between the characters.
EDIT:
My main issue is the removal of multiple line-breaks. The other one I could fix from code if I knew where the empty lines were.
my friend, can you check the example here
https://camelot-py.readthedocs.io/en/master/user/advanced.html#improve-guessed-table-rows
# Stream-flavor parse with row_tol=10, as in the linked "improve guessed table rows" example.
tables = camelot.read_pdf('group_rows.pdf', flavor='stream', row_tol=10)
# Inspect the first extracted table through its .df attribute.
tables[0].df
I solved the same problem with the code below
# strip_text lists the characters removed from every cell's text. The first
# entry must be the tab escape '\t'; a bare 't' would delete the letter "t"
# from all extracted text.
tables = camelot.read_pdf(file, flavor='stream', table_areas=['24,618,579,93'], columns=['67,315,369,483,571'], row_tol=10, strip_text='\t\r\n\v')
I also ran into the same problem with double line breaks: characters were being switched around, just as in your case. I spent some time looking at the Camelot code, made some changes and fixed the issue. You can use the code below.
After adding the code below, call the custom function read_pdf_custom() instead of camelot.read_pdf.
For a better experience, I suggest using camelot version 0.8.2.
import os
import sys
import warnings

import pandas as pd

import camelot
from camelot import handlers
from camelot import read_pdf
from camelot.core import Table
from camelot.core import TableList
from camelot.image_processing import (
    adaptive_threshold,
    find_contours,
    find_joints,
    find_lines,
)
from camelot.parsers import Lattice
from camelot.parsers import Stream
from camelot.parsers.base import BaseParser
from camelot.utils import (
    TemporaryDirectory,
    compute_accuracy,
    compute_whitespace,
    download_url,
    get_page_layout,
    get_rotation,
    get_table_index,
    get_text_objects,
    is_url,
    merge_close_lines,
    remove_extra,
    scale_image,
    scale_pdf,
    segments_in_bbox,
    text_in_bbox,
    validate_input,
)
class custom_lattice(Lattice):
    """Lattice parser that folds vertical text lines into the horizontal ones.

    Before the table grid is built, every vertical text line whose bounding-box
    key matches a horizontal text line is moved into ``self.horizontal_text``
    and ``self.vertical_text`` is emptied, so all text is assigned to cells as
    horizontal text. Vertical lines without a matching horizontal line are
    dropped.
    """

    @staticmethod
    def find_between(text, start, end):
        """Return the part of ``text`` between ``start`` and the next ``end``.

        Returns an empty string when either marker is missing.
        """
        try:
            begin = text.index(start) + len(start)
            return text[begin:text.index(end, begin)]
        except ValueError:
            return ""

    def _bbox_key(self, text_line, marker):
        """Return the second comma-separated field that follows ``marker`` in
        ``str(text_line)``, or ``None`` when the printed form cannot be parsed.

        NOTE(review): presumably this field is the line's y0 coordinate as
        printed by pdfminer -- confirm against the layout object's repr.
        """
        fields = self.find_between(str(text_line), marker, "'").strip().split(",")
        return fields[1] if len(fields) > 1 else None

    def _generate_columns_and_rows(self, table_idx, tk):
        """Collect the text and segments inside table ``tk`` and build its grid.

        Parameters
        ----------
        table_idx : int
            Index of the table on the page (unused here).
        tk : tuple
            Table bounding box; its four entries are added to the column and
            row boundaries.

        Returns
        -------
        tuple
            ``(cols, rows, v_s, h_s)``: column and row intervals plus the
            vertical and horizontal segments inside ``tk``.
        """
        # select elements which lie within table_bbox
        t_bbox = {}
        v_s, h_s = segments_in_bbox(
            tk, self.vertical_segments, self.horizontal_segments
        )

        # One key per horizontal line; None keeps the list aligned with
        # self.horizontal_text when a line cannot be parsed.
        horizontal_keys = [
            self._bbox_key(line, "LTTextLineHorizontal")
            for line in self.horizontal_text
        ]
        for line in self.vertical_text:
            key = self._bbox_key(line, "LTTextLineVertical")
            if key is None or key not in horizontal_keys:
                continue
            position = horizontal_keys.index(key)
            # Insert into both lists so later positions stay in step.
            self.horizontal_text.insert(position, line)
            horizontal_keys.insert(position, key)
        self.vertical_text = []

        t_bbox["horizontal"] = text_in_bbox(tk, self.horizontal_text)
        t_bbox["vertical"] = text_in_bbox(tk, self.vertical_text)
        t_bbox["horizontal"].sort(key=lambda x: (-x.y0, x.x0))
        t_bbox["vertical"].sort(key=lambda x: (x.x0, -x.y0))
        self.t_bbox = t_bbox

        cols, rows = zip(*self.table_bbox[tk])
        cols, rows = list(cols), list(rows)
        cols.extend([tk[0], tk[2]])
        rows.extend([tk[1], tk[3]])
        cols = merge_close_lines(sorted(cols), line_tol=self.line_tol)
        rows = merge_close_lines(sorted(rows, reverse=True), line_tol=self.line_tol)
        cols = [(cols[i], cols[i + 1]) for i in range(0, len(cols) - 1)]
        rows = [(rows[i], rows[i + 1]) for i in range(0, len(rows) - 1)]
        return cols, rows, v_s, h_s

    def _generate_table(self, table_idx, cols, rows, **kwargs):
        """Build a ``Table`` from the grid and assign the collected text to it.

        Parameters
        ----------
        table_idx : int
            Zero-based table index; stored on the table as ``order`` (plus one).
        cols, rows : list
            Column and row intervals from ``_generate_columns_and_rows``.
        **kwargs
            Must contain ``v_s`` and ``h_s``, the table's segments.

        Raises
        ------
        ValueError
            If ``v_s`` or ``h_s`` is missing.
        """
        v_s = kwargs.get("v_s")
        h_s = kwargs.get("h_s")
        if v_s is None or h_s is None:
            raise ValueError("No segments found on {}".format(self.rootname))

        table = Table(cols, rows)
        table = table.set_edges(v_s, h_s, joint_tol=self.joint_tol)
        table = table.set_border()
        table = table.set_span()

        pos_errors = []
        for direction in ["vertical", "horizontal"]:
            for t in self.t_bbox[direction]:
                indices, error = get_table_index(
                    table,
                    t,
                    direction,
                    split_text=self.split_text,
                    flag_size=self.flag_size,
                    strip_text=self.strip_text,
                )
                if indices[:2] == (-1, -1):
                    continue
                pos_errors.append(error)
                indices = Lattice._reduce_index(
                    table, indices, shift_text=self.shift_text
                )
                for r_idx, c_idx, text in indices:
                    # A fragment that is a single character once whitespace and
                    # newlines are removed is stored without them.
                    temp_text = text.strip().replace("\n", "")
                    if len(temp_text) == 1:
                        text = temp_text
                    table.cells[r_idx][c_idx].text = text
        accuracy = compute_accuracy([[100, pos_errors]])

        if self.copy_text is not None:
            table = Lattice._copy_spanning_text(table, copy_text=self.copy_text)

        data = table.data
        table.df = pd.DataFrame(data)
        table.shape = table.df.shape

        whitespace = compute_whitespace(data)
        table.flavor = "lattice"
        table.accuracy = accuracy
        table.whitespace = whitespace
        table.order = table_idx + 1
        table.page = int(os.path.basename(self.rootname).replace("page-", ""))

        # for plotting
        _text = []
        _text.extend([(t.x0, t.y0, t.x1, t.y1) for t in self.horizontal_text])
        _text.extend([(t.x0, t.y0, t.x1, t.y1) for t in self.vertical_text])
        table._text = _text
        table._image = (self.image, self.table_bbox_unscaled)
        table._segments = (self.vertical_segments, self.horizontal_segments)
        table._textedges = None

        return table
class PDFHandler(handlers.PDFHandler):
    """PDF handler that parses lattice-flavor pages with ``custom_lattice``."""

    def parse(
        self, flavor="lattice", suppress_stdout=False, layout_kwargs=None, **kwargs
    ):
        """Extract tables from every selected page.

        Parameters
        ----------
        flavor : str
            ``"lattice"`` uses ``custom_lattice``; anything else uses ``Stream``.
        suppress_stdout : bool
            Forwarded to the parser's ``extract_tables``.
        layout_kwargs : dict, optional
            Forwarded to the parser's ``extract_tables``; defaults to an empty
            dict (created per call rather than shared between calls).
        **kwargs
            Passed to the parser constructor.

        Returns
        -------
        TableList
            The sorted tables found on all pages.
        """
        if layout_kwargs is None:
            layout_kwargs = {}
        tables = []
        with TemporaryDirectory() as tempdir:
            # Each selected page is first written to its own PDF in tempdir.
            for p in self.pages:
                self._save_page(self.filepath, p, tempdir)
            pages = [
                os.path.join(tempdir, f"page-{p}.pdf") for p in self.pages
            ]
            parser = custom_lattice(**kwargs) if flavor == "lattice" else Stream(**kwargs)
            for p in pages:
                t = parser.extract_tables(
                    p, suppress_stdout=suppress_stdout, layout_kwargs=layout_kwargs
                )
                tables.extend(t)
        return TableList(sorted(tables))
def read_pdf_custom(
    filepath,
    pages="1",
    password=None,
    flavor="lattice",
    suppress_stdout=False,
    layout_kwargs=None,
    **kwargs
):
    """Read tables from a PDF using the customised ``PDFHandler``.

    Parameters
    ----------
    filepath : str
        Passed to ``PDFHandler``.
    pages : str
        Page selection passed to ``PDFHandler``.
    password : str, optional
        Passed to ``PDFHandler``.
    flavor : str
        Either ``"lattice"`` or ``"stream"``.
    suppress_stdout : bool
        When true, warnings are ignored while parsing; also forwarded to
        ``PDFHandler.parse``.
    layout_kwargs : dict, optional
        Forwarded to ``PDFHandler.parse``; defaults to an empty dict created
        per call.
    **kwargs
        Parser options; validated with ``validate_input`` and filtered with
        ``remove_extra`` before being forwarded.

    Returns
    -------
    The value returned by ``PDFHandler.parse``.

    Raises
    ------
    NotImplementedError
        If ``flavor`` is neither ``"lattice"`` nor ``"stream"``.
    """
    if layout_kwargs is None:
        layout_kwargs = {}

    if flavor not in ["lattice", "stream"]:
        raise NotImplementedError(
            "Unknown flavor specified." " Use either 'lattice' or 'stream'"
        )

    with warnings.catch_warnings():
        if suppress_stdout:
            warnings.simplefilter("ignore")

        validate_input(kwargs, flavor=flavor)
        p = PDFHandler(filepath, pages=pages, password=password)
        kwargs = remove_extra(kwargs, flavor=flavor)
        tables = p.parse(
            flavor=flavor,
            suppress_stdout=suppress_stdout,
            layout_kwargs=layout_kwargs,
            **kwargs
        )
        return tables
I also posted in the pyqtgraph forum here.
My overall goal is to have several clickable regions overlaid on an image, and if the plot boundary of any region is clicked I get a signal with the ID of that region. Something like this:
If I use only one PlotDataItem with nan-separated curves then each boundary sends the same signal. However, using a separate PlotDataItem for each boundary makes the application extremely sluggish.
I ended up subclassing ScatterPlotItem and rewriting the pointsAt function, which does what I want. The problem now is I can't figure out the appropriate way to change the ScatterPlotItem's boundingRect. Am I on the right approach? Is there a better way of doing this?
import numpy as np
import pyqtgraph as pg
from pyqtgraph.Qt import QtCore, QtGui
class CustScatter(pg.ScatterPlotItem):
    """Scatter plot item whose spots respond to clicks on their outline only."""

    def pointsAt(self, pos: QtCore.QPointF):
        """Return the spots whose stroked outline contains ``pos``.

        Replaces the default hit test, which (per the original author) only
        checks a square around each spot and also fires for clicks inside the
        shape. Spots are returned in reverse of the ``self.points()`` order.
        """
        def outline_contains(spot):
            # Move the spot's symbol path to the spot position, then test the
            # stroke of that path so the interior does not count as a hit.
            outline = QtGui.QPainterPath(spot.symbol())
            outline.translate(spot.pos())
            stroke = QtGui.QPainterPathStroker().createStroke(outline)
            return stroke.contains(pos)

        hits = [spot for spot in self.points() if outline_contains(spot)]
        hits.reverse()
        return hits
"""Make some sample data"""
# One base outline (7 vertices, one per row) shifted by 0, 5, 10 and 15 along
# both axes; each shifted copy is one clickable region.
tri = np.array([[0, 2.3, 0, 1, 4, 5, 0], [0, 4, 4, 8, 8, 3, 0]]).T
# Region IDs double as the offsets applied to the base outline.
datas = list(np.arange(0, 16, 5))
tris = [tri + offset for offset in datas]
# Lower-left corner of each outline, used as the spot position.
xyLocs = [outline.min(0) for outline in tris]
def ptsClicked(item, pts):
    """Print the data ID of the first spot in ``pts``."""
    clicked_id = pts[0].data()
    print(f'ID {clicked_id} Clicked!')
"""Logic for making spot shapes from a list of (x,y) vertices"""
def makeSymbol(verts: np.ndarray):
    """Build a painter path from ``verts`` and shift it so that the top-left
    corner of its bounding rectangle sits at the origin.

    ``verts`` is transposed and unpacked into ``pg.arrayToQPath``, so it is
    expected to hold one vertex per row.
    """
    path = QtGui.QPainterPath()
    path.addPath(pg.arrayToQPath(*verts.T))
    # From pyqtgraph.examples for plotting text
    bounds = path.boundingRect()
    shift = QtGui.QTransform()
    shift.translate(-bounds.x(), -bounds.y())
    return shift.map(path)
app = pg.mkQApp()
pg.setConfigOption('background', 'w')

# One symbol per outline.
symbs = [makeSymbol(outline) for outline in tris]

# Create the scatterplot
xyLocs = np.vstack(xyLocs)
tri2 = pg.PlotDataItem()
scatter_options = dict(
    symbol=symbs, data=datas, connect='finite',
    pxMode=False, brush=None, pen=pg.mkPen(width=5), size=1,
)
scat = CustScatter(*xyLocs.T, **scatter_options)
# Now each 'point' is one of the triangles, hopefully
scat.sigClicked.connect(ptsClicked)

# Construct GUI window
w = pg.PlotWindow()
plt: pg.PlotItem = w.plotItem
plt.addItem(scat)
plt.showGrid(x=True, y=True, alpha=1)
w.show()
app.exec()
Solved! It turns out unless you specify otherwise, the boundingRect of each symbol in the dataset is assumed to be 1 and that the spot size is the limiting factor. After overriding measureSpotSizes as well my solution works:
import numpy as np
import pyqtgraph as pg
from pyqtgraph.Qt import QtCore, QtGui
class CustScatter(pg.ScatterPlotItem):
    """Scatter plot item with outline-only hit testing and symbol-based sizes."""

    def pointsAt(self, pos: QtCore.QPointF):
        """Return the spots whose stroked outline contains ``pos``.

        Replaces the default hit test, which (per the original author) only
        checks a square around each spot and also fires for clicks inside the
        shape. Spots are returned in reverse of the ``self.points()`` order.
        """
        def outline_contains(spot):
            # Move the spot's symbol path to the spot position, then test the
            # stroke of that path so the interior does not count as a hit.
            outline = QtGui.QPainterPath(spot.symbol())
            outline.translate(spot.pos())
            stroke = QtGui.QPainterPathStroker().createStroke(outline)
            return stroke.contains(pos)

        hits = [spot for spot in self.points() if outline_contains(spot)]
        hits.reverse()
        return hits

    def measureSpotSizes(self, dataSet):
        """Update the maximum spot widths using each symbol's bounding rectangle.

        The size reported by ``getSpotOpts`` is replaced by twice the larger
        side of the symbol's bounding rectangle; the pen width is added in data
        or pixel units depending on ``pxMode`` and on whether the pen is
        cosmetic. ``self.bounds`` is reset afterwards.
        """
        pixel_mode = self.opts['pxMode']
        for rec in dataSet:
            symbol, size, pen, brush = self.getSpotOpts(rec)
            rect = symbol.boundingRect()
            size = 2 * max(rect.width(), rect.height())
            pen_width = pen.widthF()
            if pixel_mode:
                width, pxWidth = 0, size + pen_width
            elif pen.isCosmetic():
                width, pxWidth = size, 0 + pen_width
            else:
                width, pxWidth = size + pen_width, 0
            # keep track of the maximum spot size and pixel size
            self._maxSpotWidth = max(self._maxSpotWidth, width)
            self._maxSpotPxWidth = max(self._maxSpotPxWidth, pxWidth)
        self.bounds = [None, None]
"""Make some sample data"""
# One base outline (7 vertices, one per row) shifted by 0, 5, 10 and 15 along
# both axes; each shifted copy is one clickable region.
tri = np.array([[0, 2.3, 0, 1, 4, 5, 0], [0, 4, 4, 8, 8, 3, 0]]).T
# Region IDs double as the offsets applied to the base outline.
datas = list(np.arange(0, 16, 5))
tris = [tri + offset for offset in datas]
# Lower-left corner of each outline, used as the spot position.
xyLocs = [outline.min(0) for outline in tris]
def ptsClicked(item, pts):
    """Print the data ID of the first spot in ``pts``."""
    clicked_id = pts[0].data()
    print(f'ID {clicked_id} Clicked!')
"""Logic for making spot shapes from a list of (x,y) vertices"""
def makeSymbol(verts: np.ndarray):
    """Return a painter path for ``verts`` shifted so its minimum corner is at
    the origin.

    The per-column minimum is subtracted from every vertex before the
    transposed array is unpacked into ``pg.arrayToQPath``.
    """
    origin = verts.min(0, keepdims=True)
    shifted = verts - origin
    return pg.arrayToQPath(*shifted.T)
app = pg.mkQApp()
pg.setConfigOption('background', 'd')

# One symbol per outline.
symbs = [makeSymbol(outline) for outline in tris]

# Create the scatterplot
xyLocs = np.vstack(xyLocs)
tri2 = pg.PlotDataItem()
scatter_options = dict(
    symbol=symbs, data=datas, connect='finite',
    pxMode=False, brush=None, pen=pg.mkPen(width=5), size=1,
)
scat = CustScatter(*xyLocs.T, **scatter_options)
# Now each 'point' is one of the triangles, hopefully
scat.sigClicked.connect(ptsClicked)

# Construct GUI window
w = pg.PlotWindow()
plt: pg.PlotItem = w.plotItem
plt.addItem(scat)
plt.showGrid(x=True, y=True, alpha=1)
w.show()
app.exec()
A list helped me write a loop that grabs part of the screen and compares it to an existing RGB value. The final code is in the final edit; the full learning process is shown below.
I'm new to coding in general. I am using pixels to verify the RGB value of a pixel that indicates the position of a player. Instead of re-typing location = ImageGrab.grab(position) for every seat, can I change the position?
# One 4-tuple per seat; each is meant to be passed to ImageGrab.grab.
seat0 = (658,848,660,850) #btn
seat1 = (428,816,429,817) #co
seat2 = (406,700,407,701) #hj
seat3 = (546,653,547,654) #lj
seat4 = (798,705,799,706) #bb
seat5 = (769,816,770,817) #sb
# Grab the seat0 region and get pixel access to it.
btnpos = ImageGrab.grab(seat0)
position = btnpos.load()
# RGB value that marks the button.
btn = (251,0,8)
# Compare the first pixel of the grabbed region with the button colour.
if position[0,0] == btn:
print('Button')
# Placeholder from the question: the check for the remaining seats is not written yet.
elif ????
So I want to change seat0 (in imagegrab) to seat1 without re-typing the code. Thanks for the help
EDIT1: So I tried Marks suggestion, not sure what I'm doing wrong.
# NOTE: every entry here is a quoted string, not a tuple of four numbers;
# presumably that is what triggers the ValueError reported below when the
# entry is passed to ImageGrab.grab.
positie = ["(658,848,660,850)","(428,816,430,818)","(406,700,408,702)","(546,653,548,655)","(798,705,799,706)","(769,816,770,817)"]
# RGB value that marks the button.
btn = (251,0,8)
btnpos = ImageGrab.grab(positie[0])
btncheck = btnpos.load()
# NOTE: this indexes the btn tuple rather than the grabbed pixels in btncheck.
btn [0,0] not in positie[0]
I get "ValueError: too many values to unpack (expected 4)"
I am trying to get the RGB value at positie[0], and it has to match btn. If it doesn't match btn, then positie[1] has to be checked, and so on until a match is found.
Edit2:
Thanks so far Mark. I am so bad at loops.
positie = [(658,848,660,850),(428,816,430,818),(406,700,408,702),(546,653,548,655),(798,705,799,706),(769,816,770,817)]
# RGB value that marks the button.
btn = (251,0,8)
# NOTE: the region is grabbed once, before the loop, so every iteration below
# tests the same pixels from positie[0].
btnpos = ImageGrab.grab(positie[0])
btncheck = btnpos.load()
# The loop variable x is never used inside the body.
for x in positie:
#btncheck = btnpos.load()
if btncheck[0,0] ==btn:
print("positie is BTN")
else:
# NOTE: ++1 is just the integer 1 (unary plus applied twice), so this always
# selects positie[1] instead of advancing to the next entry.
btnpos = (positie[++1])
print(btnpos)
how do I change btnpos with adding up instead of re-typing everything?
So it's supposed to grab first item in list and compare it to btn. If it's correct it can break, if it's wrong it has to compare btn with second item in list.
If true it gives me the position. If false:
Now it prints list[1] six times.
Thank you for your patience
FINAL EDIT: Thank you Mark.
# Screen regions to test, in order.
positie = [(658,848,660,850),(428,816,430,818),(406,700,408,702),(546,653,548,655),(798,705,799,706),(769,816,770,817)]
# RGB value that marks the button.
btn = (251,0,8)
# Grab each region in turn and stop at the first one whose first pixel
# matches the button colour.
for regio in positie:
    btnpos = ImageGrab.grab(regio)
    btncheck = btnpos.load()
    if btncheck[0,0] == btn:
        print("Positie")
        break
    print("controleert de volgende")
I want to select features and to zoom on them and do all these steps using PyQgis.
I'm able to do each of them separately, but it doesn't seem to work when I try to combine the two.
Both of the codes I use for them are from the internet. Here's what I use to select features of a layer :
from qgis.core import *
import qgis.utils
# Load the vector layer and add it to the map layer registry.
lyrMap = QgsVectorLayer('C:/someplace', 'MapName', 'ogr')
QgsMapLayerRegistry.instance().addMapLayer(lyrMap)
# In a QGIS expression a field name takes double quotes; 'Attribute' in single
# quotes is a string literal, which is never NULL, so every feature would match.
expr = QgsExpression('"Attribute" IS NOT NULL')
it = lyrMap.getFeatures(QgsFeatureRequest(expr))
# Select only the features for which the expression is true.
ids = [feature.id() for feature in it]
lyrMap.setSelectedFeatures(ids)
And it seems to do the trick as features appear selected on QGis.
In order to zoom the code is much more simple, it's just :
# Get the map canvas from the QGIS interface and ask it to zoom to the
# features selected on lyrMap.
canvas = qgis.utils.iface.mapCanvas()
canvas.zoomToSelected(lyrMap)
But it seems that the canvas doesn't recognise that there is a selection on lyrMap and simply does nothing. I've tried making the selection manually in QGIS and then zooming with zoomToSelected, and that worked.
But my objective is to do it without needing to do the selection manually...
Note : I don't think that's the issue, but the attribute I'm doing the selection on is from a join between lyrMap and another layer (I didn't put the code here because I don't think it's linked).
Thanks in advances for answers, clues or anything really :) !
This is working for my plugin. I am using Python 2.7 and QGIS 1.8 and 2.0.1. You can use this code after loading the vector file and adding it to the registry.
# No rubber band for the point marker exists yet; only the attribute is set up here.
self.rubberBand = None
# Vertex marker for the point (used on older QGIS versions).
self.vMarker = None
# Line rubber band on the map canvas, drawn in black.
self.crossRb = QgsRubberBand(iface.mapCanvas(),QGis.Line)
self.crossRb.setColor(Qt.black)
def pan(self):
    """Centre the canvas on the X/Y typed into the dialog and mark the point.

    Does nothing when either text field is empty. The current extent is
    shifted, keeping its size, so that its centre lands on the requested
    coordinates; ``self.zoom`` is then called with that point and the canvas
    is refreshed.
    """
    print("pan button clicked!")
    x = self.dlg.ui.mTxtX.text()
    y = self.dlg.ui.mTxtY.text()
    if not x or not y:
        return
    print(x + "," + y)

    canvas = self.canvas
    extent = canvas.extent()
    center = extent.center()
    target_x, target_y = float(x), float(y)
    # Offset from the current centre to the requested point.
    dx = target_x - center.x()
    dy = target_y - center.y()
    shifted = QgsRectangle(
        extent.xMinimum() + dx,
        extent.yMinimum() + dy,
        extent.xMaximum() + dx,
        extent.yMaximum() + dy,
    )
    canvas.setExtent(shifted)
    self.zoom(QgsPoint(target_x, target_y))
    canvas.refresh()
def zoom(self,point):
    """Draw a crosshair through ``point`` and flash a marker on it.

    The crosshair spans the current canvas extent horizontally and vertically.
    On QGIS >= 1.9 the point is shown with ``self.rubberBand``; on older
    versions a ``QgsVertexMarker`` is created instead. ``resetRubberbands`` is
    scheduled 500 ms later to remove the markers again.
    """
    canvas = self.canvas
    currExt = canvas.extent()

    leftPt = QgsPoint(currExt.xMinimum(),point.y())
    rightPt = QgsPoint(currExt.xMaximum(),point.y())
    topPt = QgsPoint(point.x(),currExt.yMaximum())
    bottomPt = QgsPoint(point.x(),currExt.yMinimum())

    horizLine = QgsGeometry.fromPolyline( [ leftPt , rightPt ] )
    vertLine = QgsGeometry.fromPolyline( [ topPt , bottomPt ] )

    self.crossRb.reset(QGis.Line)
    self.crossRb.addGeometry(horizLine,None)
    self.crossRb.addGeometry(vertLine,None)

    if QGis.QGIS_VERSION_INT >= 10900:
        # The attribute may still be None if no point rubber band has been
        # created yet; build one on first use instead of failing on None.
        if self.rubberBand is None:
            self.rubberBand = QgsRubberBand(canvas, QGis.Point)
        rb = self.rubberBand
        rb.reset(QGis.Point)
        rb.addPoint(point)
    else:
        self.vMarker = QgsVertexMarker(self.canvas)
        self.vMarker.setIconSize(10)
        self.vMarker.setCenter(point)
        self.vMarker.show()

    # wait .5 seconds to simulate a flashing effect
    QTimer.singleShot(500,self.resetRubberbands)
def resetRubberbands(self):
    """Remove the point marker and the crosshair drawn by ``zoom``.

    On QGIS >= 1.9 the point rubber band is reset; on older versions the
    vertex marker is hidden and removed from the canvas scene. The crosshair
    rubber band is reset in both cases.
    """
    print("resetting rubberbands..")
    map_canvas = self.canvas
    uses_rubber_band = QGis.QGIS_VERSION_INT >= 10900
    if not uses_rubber_band:
        marker = self.vMarker
        marker.hide()
        map_canvas.scene().removeItem(marker)
    else:
        self.rubberBand.reset()
    self.crossRb.reset()
    print("completed resetting..")
Using the minimal example below, the line plot of a large (some 110k points) plot I get (with python 2.7, numpy 1.5.1, chaco/enable/traits 4.3.0) is this:
However, that is bizarre, because it is a line plot, and there shouldn't be any filled areas in there? Especially since the data is sawtooth-ish signal? It's as if there is a line at y~=37XX, above which there is color filling?! But sure enough, if I zoom into an area, I get the rendering I expect - without the unexpected fill:
Is this a bug - or is there something I'm doing wrong? I tried to use use_downsampling, but it makes no difference...
The test code:
import numpy as np
import numpy.random as npr
from pprint import pprint
from traits.api import HasTraits, Instance
from chaco.api import Plot, ArrayPlotData, VPlotContainer
from traitsui.api import View, Item
from enable.component_editor import ComponentEditor
from chaco.tools.api import PanTool, BetterSelectingZoom
# Number of samples in the test signal.
tlen = 112607
# Random integer noise in [0, 4000) for every sample.
alr = npr.randint(0, 4000, size=tlen)
# Index axis in steps of 30/tlen starting at 0.
tx = np.arange(0.0, 30.0 - 0.00001, 30.0 / tlen)
# Sawtooth (sample number modulo 10000) plus the noise.
ty = alr + np.arange(tlen) % 10000
pprint(len(ty))
class ChacoTest(HasTraits):
    """Traits UI window showing the module-level ``tx``/``ty`` data as a line plot."""

    container = Instance(VPlotContainer)
    traits_view = View(
        Item('container', editor=ComponentEditor(), show_label=False),
        width=800, height=500, resizable=True,
        title="Chaco Test"
    )

    def __init__(self):
        """Build the plot, put it in the container and attach pan/zoom tools."""
        super(ChacoTest, self).__init__()
        pprint(ty)
        self.plotdata = ArrayPlotData(x=tx, y=ty)
        self.plotobj = Plot(self.plotdata)
        line_options = dict(
            type="line", color=(0, 0.99, 0), spacing=0, padding=0,
            alpha=0.7, use_downsampling=True,
        )
        self.plotA = self.plotobj.plot(("x", "y"), **line_options)
        self.container = VPlotContainer(
            self.plotobj, spacing=5, padding=5, bgcolor="lightgray"
        )
        plot = self.plotobj
        plot.tools.append(PanTool(plot))
        plot.overlays.append(BetterSelectingZoom(plot))


if __name__ == "__main__":
    ChacoTest().configure_traits()
I am able to reproduce the error and talking with John Wiggins (maintainer of Enable), it is a bug in kiva (which chaco uses to paint on the screen):
https://github.com/enthought/enable
The good news is that this bug is in only one of the kiva backends that you can use. So, to work around the issue, you can run your script with a different backend:
ETS_TOOLKIT=qt4.qpainter python <NAME OF YOUR SCRIPT>
If you use qpainter or quartz, the plot looks (on my machine) as expected. If you choose qt4.image (the Agg backend), you will reproduce the issue. Unfortunately, the Agg backend is the default one. To change the default, you can set the ETS_TOOLKIT environment variable to the backend you want:
export ETS_TOOLKIT=qt4.qpainter
The bad news is that fixing this isn't going to be an easy task. Please feel free to report the bug in github (again https://github.com/enthought/enable) if you want to be involved in this. If you don't, I will log it in the next couple of days. Thanks for reporting it!
Just a note - I found this:
[Enthought-Dev] is chaco faster than matplotlib
I recall reading somewhere that you are expected to implement the
_downsample method because the optimal algorithm depends on the type
of data you're collecting.
And as I couldn't find any examples with _downsample implementation other than decimated_plot.py referred in that post, which isn't standalone - I tried and built a standalone example, included below.
The example has messed-up drag and zoom (the plot disappears if you go out of range, or stretches during a drag move), and it starts zoomed in; but it is possible to zoom out to the range shown in the original post, and then it displays exactly the same rendering problem. So downsampling is not a solution in itself, which suggests this is a bug.
import numpy as np
import numpy.random as npr
from pprint import pprint
from traits.api import HasTraits, Instance
from chaco.api import Plot, ArrayPlotData, VPlotContainer
from traitsui.api import View, Item
from enable.component_editor import ComponentEditor
from chaco.tools.api import PanTool, BetterSelectingZoom
#
from chaco.api import BaseXYPlot, LinearMapper, AbstractPlotData
from enable.api import black_color_trait, LineStyle
from traits.api import Float, Enum, Int, Str, Trait, Event, Property, Array, cached_property, Bool, Dict
from chaco.abstract_mapper import AbstractMapper
from chaco.abstract_data_source import AbstractDataSource
from chaco.array_data_source import ArrayDataSource
from chaco.data_range_1d import DataRange1D
# Number of samples in the test signal.
tlen = 112607
# Random integer noise in [0, 4000) for every sample.
alr = npr.randint(0, 4000, size=tlen)
# Index axis in steps of 30/tlen starting at 0.
tx = np.arange(0.0, 30.0 - 0.00001, 30.0 / tlen)
# Sawtooth (sample number modulo 10000) plus the noise.
ty = alr + np.arange(tlen) % 10000
pprint(len(ty))
class ChacoTest(HasTraits):
    """Traits UI window showing the module-level ``tx``/``ty`` data through a
    ``TimeSeriesPlot``."""

    container = Instance(VPlotContainer)
    traits_view = View(
        Item('container', editor=ComponentEditor(), show_label=False),
        width=800, height=500, resizable=True,
        title="Chaco Test"
    )
    downsampling_cutoff = Int(4)

    def __init__(self):
        """Build the plot, set its ranges, and attach pan/zoom tools."""
        super(ChacoTest, self).__init__()
        pprint(ty)
        self.plotdata = ArrayPlotData(x=tx, y=ty)
        self.plotobj = TimeSeriesPlot(self.plotdata)
        plot = self.plotobj
        plot.setplotranges("x", "y")
        self.container = VPlotContainer(
            plot, spacing=5, padding=5, bgcolor="lightgray"
        )
        plot.tools.append(PanTool(plot))
        plot.overlays.append(BetterSelectingZoom(plot))
# decimate from:
# https://bitbucket.org/mjrosen/neurobehavior/raw/097ef3719d1263a8b303d29c31ab71b6e792ab04/cns/widgets/views/decimated_plot.py
def decimate(data, screen_width, downsampling_cutoff=4, mode='extremes'):
    """Decimate ``data`` along its last axis when it is much wider than the screen.

    Parameters
    ----------
    data : ndarray
        Samples; the last axis is the one compared against ``screen_width``.
    screen_width : number
        Width of the drawing area. A width of zero or less leaves ``data``
        unchanged instead of dividing by zero.
    downsampling_cutoff : int
        Decimation only happens when the computed factor exceeds this value.
    mode : str
        Suffix of the module-level ``decimate_<mode>`` function to call.

    Returns
    -------
    Either ``data`` itself (no decimation) or whatever ``decimate_<mode>``
    returns.

    Raises
    ------
    KeyError
        If decimation is needed and no ``decimate_<mode>`` function exists.
    """
    if screen_width <= 0:
        return data
    data_width = data.shape[-1]
    # True division even for two ints under Python 2, and an integer factor so
    # that it can be used for slicing and reshaping.
    downsample = int(data_width / float(screen_width) / 4.0)
    if downsample > downsampling_cutoff:
        return globals()['decimate_' + mode](data, downsample)
    return data
def decimate_extremes(data, downsample):
    """Return the per-chunk minimum and maximum along the last axis.

    The last axis is cut into chunks of ``downsample`` samples; leading samples
    that do not fill a whole chunk are dropped.

    Parameters
    ----------
    data : ndarray
        One- or two-dimensional array.
    downsample : number
        Chunk length; converted to ``int`` so a float factor (as produced by
        ``np.floor``) can be used for slicing and reshaping.

    Returns
    -------
    tuple of ndarray
        ``(data_min, data_max)``.
    """
    downsample = int(downsample)
    # After the reshape below the chunk axis is the new last axis, whose index
    # equals the original number of dimensions.
    last_dim = data.ndim
    offset = data.shape[-1] % downsample
    if data.ndim == 2:
        shape = (len(data), -1, downsample)
    else:
        shape = (-1, downsample)
    chunks = data[..., offset:].reshape(shape)
    return chunks.min(last_dim), chunks.max(last_dim)
def decimate_mean(data, downsample):
    """Return the mean of consecutive chunks along the first axis.

    The first axis is cut into chunks of ``downsample`` rows; leading rows that
    do not fill a whole chunk are dropped.

    Parameters
    ----------
    data : ndarray
        One- or two-dimensional array.
    downsample : number
        Chunk length; converted to ``int`` so a float factor (as produced by
        ``np.floor``) can be used for slicing and reshaping.

    Returns
    -------
    ndarray
        One mean per chunk (per column for two-dimensional input).
    """
    downsample = int(downsample)
    offset = len(data) % downsample
    if data.ndim == 2:
        shape = (-1, downsample, data.shape[-1])
    else:
        shape = (-1, downsample)
    return data[offset:].reshape(shape).mean(1)
# based on class from decimated_plot.py, also
# neurobehavior/cns/chaco_exts/timeseries_plot.py ;
# + some other code from chaco
class TimeSeriesPlot(BaseXYPlot):
    """Line renderer that decimates its values to the screen width before drawing.

    Points are gathered from ``self.data``, reduced with the module-level
    ``decimate`` function and cached in screen coordinates until the data
    changes.
    """

    color = black_color_trait
    line_width = Float(1.0)
    line_style = LineStyle
    reference = Enum('most_recent', 'trigger')
    traits_view = View("color#", "line_width")
    downsampling_cutoff = Int(100)
    signal_trait = "updated"
    decimate_mode = Str('extremes')
    ch_index = Trait(None, Int, None)
    # Mapping of data names from self.data to their respective datasources.
    datasources = Dict(Str, Instance(AbstractDataSource))
    index_mapper = Instance(AbstractMapper)
    value_mapper = Instance(AbstractMapper)

    def __init__(self, data=None, **kwargs):
        """Create the plot and, when ``data`` is given, attach it as plot data."""
        super(TimeSeriesPlot, self).__init__(**kwargs)
        self._index_mapper_changed(None, self.index_mapper)
        self.setplotdata(data)
        self._plot_ui_info = None

    def setplotdata(self, data):
        """Store ``data`` as ``self.data``, wrapping raw sequences in ``ArrayPlotData``.

        ``None`` is ignored.

        Raises
        ------
        ValueError
            If ``data`` is neither plot data, an ndarray, a tuple nor a list.
        """
        if data is None:
            return
        if isinstance(data, AbstractPlotData):
            self.data = data
        elif isinstance(data, (np.ndarray, tuple, list)):
            self.data = ArrayPlotData(data)
        else:
            raise ValueError("Don't know how to create PlotData for data "
                             "of type " + str(type(data)))

    def setplotranges(self, index_name, value_name):
        """Bind the named index and value arrays to fresh linear mappers.

        Data sources are created (or reused) for both names, added to the
        index and value ranges, and new ``LinearMapper`` objects are built on
        those ranges.
        """
        self.index_name = index_name
        self.value_name = value_name
        index = self._get_or_create_datasource(index_name)
        value = self._get_or_create_datasource(value_name)
        # Placeholder mappers, so that the range assignments below have a
        # mapper to work with (see the original "calls index_mapper" notes).
        if not self.index_mapper:
            self.index_mapper = LinearMapper()
        if not self.value_mapper:
            self.value_mapper = LinearMapper()
        if not self.index_range:
            self.index_range = DataRange1D()  # calls index_mapper
        if not self.value_range:
            self.value_range = DataRange1D()
        self.index_range.add(index)  # calls index_mapper!
        self.value_range.add(value)
        # Replace the placeholders with mappers bound to the populated ranges.
        self.index_mapper = LinearMapper(range=self.index_range)
        self.value_mapper = LinearMapper(range=self.value_range)

    def _get_or_create_datasource(self, name):
        """Return the data source for ``name``, creating and caching it if needed.

        Raises
        ------
        ValueError
            If the data has an unsupported type or array shape.
        """
        if name not in self.datasources:
            data = self.data.get_data(name)
            if type(data) in (list, tuple):
                data = np.array(data)
            if isinstance(data, np.ndarray):
                ndim = len(data.shape)
                if ndim == 1:
                    ds = ArrayDataSource(data, sort_order="none")
                elif ndim == 2 or (ndim == 3 and data.shape[2] in (3, 4)):
                    # Imported here because the file's import block does not
                    # bring ImageData into scope.
                    from chaco.api import ImageData
                    depth = 1 if ndim == 2 else int(data.shape[2])
                    ds = ImageData(data=data, value_depth=depth)
                else:
                    raise ValueError("Unhandled array shape in creating new plot: "
                                     + str(data.shape))
            elif isinstance(data, AbstractDataSource):
                ds = data
            else:
                raise ValueError("Couldn't create datasource for data of type " +
                                 str(type(data)))
            self.datasources[name] = ds
        return self.datasources[name]

    def get_screen_points(self):
        """Return ``[screen_index, screen_values]`` for the current view."""
        self._gather_points()
        return self._downsample()

    def _data_changed(self):
        """Invalidate both caches and request a redraw."""
        self.invalidate_draw()
        self._cache_valid = False
        self._screen_cache_valid = False
        self.request_redraw()

    def _gather_points(self):
        """Cache the slice of the value array selected by the index range."""
        if self._cache_valid:
            return
        index_range = self.index_mapper.range
        t_lb, t_ub = index_range.low, index_range.high
        # NOTE(review): the range bounds are used directly as slice positions
        # into the value array -- confirm that index units and array positions
        # are meant to coincide here.
        self._cached_data = self.data[self.value_name][t_lb:t_ub]
        self._cached_data_bounds = t_lb, t_ub
        self._cache_valid = True
        self._screen_cache_valid = False

    def _downsample(self):
        """Decimate the cached data to the screen width and map it to screen space.

        Returns ``[screen_index, screen_values]``; ``screen_values`` is a
        ``(min, max)`` tuple when the decimation returned a tuple, otherwise a
        single array.
        """
        if not self._screen_cache_valid:
            screen_min, screen_max = self.index_mapper.screen_bounds
            screen_width = screen_max - screen_min
            values = decimate(self._cached_data, screen_width,
                              self.downsampling_cutoff, self.decimate_mode)
            if isinstance(values, tuple):
                n = len(values[0])
                s_val_min = self.value_mapper.map_screen(values[0])
                s_val_max = self.value_mapper.map_screen(values[1])
                self._cached_screen_data = s_val_min, s_val_max
            else:
                self._cached_screen_data = self.value_mapper.map_screen(values)
                n = len(values)
            # Spread the n remaining points evenly over the cached index bounds.
            t = np.linspace(*self._cached_data_bounds, num=n)
            self._cached_screen_index = self.index_mapper.map_screen(t)
            self._screen_cache_valid = True
        return [self._cached_screen_index, self._cached_screen_data]

    def _render(self, gc, points):
        """Stroke ``points`` (as returned by ``get_screen_points``) on ``gc``."""
        idx, val = points
        if len(idx) == 0:
            return
        gc.save_state()
        try:
            gc.set_antialias(True)
            gc.clip_to_rect(self.x, self.y, self.width, self.height)
            gc.set_stroke_color(self.color_)
            gc.set_line_width(self.line_width)
            gc.begin_path()
            if isinstance(val, tuple):
                # (min, max) pair: one segment per index position.
                starts = np.column_stack((idx, val[0]))
                ends = np.column_stack((idx, val[1]))
                gc.line_set(starts, ends)
            else:
                gc.lines(np.column_stack((idx, val)))
            gc.stroke_path()
            self._draw_default_axes(gc)
        finally:
            # Restore the graphics state even if drawing raised.
            gc.restore_state()


if __name__ == "__main__":
    ChacoTest().configure_traits()