Making a 3D model in VTK solid instead of hollow inside - python

I'm trying to create 3D model of the skull using VTK [example]:(https://kitware.github.io/vtk-examples/site/Python/VisualizationAlgorithms/HeadBone/)
#!/usr/bin/env python
# noinspection PyUnresolvedReferences
import vtkmodules.vtkInteractionStyle
# noinspection PyUnresolvedReferences
import vtkmodules.vtkRenderingOpenGL2
from vtkmodules.vtkCommonColor import vtkNamedColors
from vtkmodules.vtkCommonCore import (
VTK_VERSION_NUMBER,
vtkVersion
)
from vtkmodules.vtkCommonDataModel import vtkMergePoints
from vtkmodules.vtkFiltersCore import (
vtkFlyingEdges3D,
vtkMarchingCubes
)
from vtkmodules.vtkFiltersModeling import vtkOutlineFilter
from vtkmodules.vtkIOImage import vtkMetaImageReader
from vtkmodules.vtkRenderingCore import (
vtkActor,
vtkPolyDataMapper,
vtkRenderWindow,
vtkRenderWindowInteractor,
vtkRenderer
)
def main():
# vtkFlyingEdges3D was introduced in VTK >= 8.2
use_flying_edges = vtk_version_ok(8, 2, 0)
file_name = get_program_parameters()
colors = vtkNamedColors()
# Create the RenderWindow, Renderer and Interactor.
ren = vtkRenderer()
ren_win = vtkRenderWindow()
ren_win.AddRenderer(ren)
iren = vtkRenderWindowInteractor()
iren.SetRenderWindow(ren_win)
# Create the pipeline.
reader = vtkMetaImageReader()
reader.SetFileName(file_name)
reader.Update()
locator = vtkMergePoints()
locator.SetDivisions(64, 64, 92)
locator.SetNumberOfPointsPerBucket(2)
locator.AutomaticOff()
if use_flying_edges:
try:
using_marching_cubes = False
iso = vtkFlyingEdges3D()
except AttributeError:
using_marching_cubes = True
iso = vtkMarchingCubes()
else:
using_marching_cubes = True
iso = vtkMarchingCubes()
iso.SetInputConnection(reader.GetOutputPort())
iso.ComputeGradientsOn()
iso.ComputeScalarsOff()
iso.SetValue(0, 1150)
if using_marching_cubes:
iso.SetLocator(locator)
iso_mapper = vtkPolyDataMapper()
iso_mapper.SetInputConnection(iso.GetOutputPort())
iso_mapper.ScalarVisibilityOff()
iso_actor = vtkActor()
iso_actor.SetMapper(iso_mapper)
iso_actor.GetProperty().SetColor(colors.GetColor3d('Ivory'))
outline = vtkOutlineFilter()
outline.SetInputConnection(reader.GetOutputPort())
outline_mapper = vtkPolyDataMapper()
outline_mapper.SetInputConnection(outline.GetOutputPort())
outline_actor = vtkActor()
outline_actor.SetMapper(outline_mapper)
# Add the actors to the renderer, set the background and size.
#
ren.AddActor(outline_actor)
ren.AddActor(iso_actor)
ren.SetBackground(colors.GetColor3d('SlateGray'))
ren.GetActiveCamera().SetFocalPoint(0, 0, 0)
ren.GetActiveCamera().SetPosition(0, -1, 0)
ren.GetActiveCamera().SetViewUp(0, 0, -1)
ren.ResetCamera()
ren.GetActiveCamera().Dolly(1.5)
ren.ResetCameraClippingRange()
ren_win.SetSize(640, 480)
ren_win.SetWindowName('HeadBone')
ren_win.Render()
iren.Start()
def get_program_parameters():
import argparse
description = 'Marching cubes surface of human bone.'
epilogue = '''
'''
parser = argparse.ArgumentParser(description=description, epilog=epilogue,
formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument('filename', help='FullHead.mhd.')
args = parser.parse_args()
return args.filename
def vtk_version_ok(major, minor, build):
"""
Check the VTK version.
:param major: Major version.
:param minor: Minor version.
:param build: Build version.
:return: True if the requested VTK version is greater or equal to the actual VTK version.
"""
needed_version = 10000000000 * int(major) + 100000000 * int(minor) + int(build)
try:
vtk_version_number = VTK_VERSION_NUMBER
except AttributeError: # as error:
ver = vtkVersion()
vtk_version_number = 10000000000 * ver.GetVTKMajorVersion() + 100000000 * ver.GetVTKMinorVersion() \
+ ver.GetVTKBuildVersion()
if vtk_version_number >= needed_version:
return True
else:
return False
if __name__ == '__main__':
main()
It appears that model works just fine, but there is something that I want to know. The model of the skull works perfectly fine, but the model density is not solid (it hollow inside). It generates only the surface of the model.
I want to know how I can fill the gap in the surface to get the full solid model.

In this example, you are trying to extract the isosurface (hollow) from the data.
vtkFlyingEdges3D() and vtkMarchingCubes() algorithms will create contours based on the iso.SetValue(0, 1150) and extract the surface for you. If you want to keep it filled, remove the contouring functions, i'e, vtkFlyingEdges3D() and vtkMarchingCubes() from your script and use some image data mapper instead of vtkPolydataMapper, it will show the entire object.

Related

Cannot visualize the points around one of the selected points in point cloud. Not sure what I am missing

I used this code. What it does is just takes a pointcloud and convert it into polydata and visualize it. After visualizing the pointcloud, I use mouse to select one point and calculate points around that particular point in a given radius. I am able to calculate those points but not able to visualize those. Also moving further I select 1000 random points in data and calculate their clusters too, calculate Hausdorff distances of each with original cluster and visualize the minimum distance cluster. In all this I am not able to visualize any cluster.`
import vtk
import numpy as np
import open3d as o3d
import math
from numpy import random
from scipy.spatial.distance import directed_hausdorff
import vtkmodules.vtkRenderingOpenGL2
from vtkmodules.vtkCommonColor import vtkNamedColors
from vtkmodules.vtkCommonCore import vtkIdTypeArray
from vtkmodules.vtkCommonDataModel import (
vtkSelection,
vtkSelectionNode,
vtkUnstructuredGrid
)
from vtkmodules.vtkFiltersCore import vtkTriangleFilter
from vtkmodules.vtkFiltersExtraction import vtkExtractSelection
from vtkmodules.vtkFiltersSources import vtkPlaneSource
from vtkmodules.vtkInteractionStyle import vtkInteractorStyleTrackballCamera
from vtkmodules.vtkRenderingCore import (
vtkActor,
vtkCellPicker,
vtkDataSetMapper,
vtkPolyDataMapper,
vtkRenderWindow,
vtkRenderWindowInteractor,
vtkRenderer
)
class VtkPointCloud:
def __init__(self, zMin=-10.0, zMax=10.0, maxNumPoints=2e6):
# c = vtkNamedColors()
self.maxNumPoints = maxNumPoints
self.vtkPolyData = vtk.vtkPolyData()
self.clearPoints()
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputData(self.vtkPolyData)
mapper.SetColorModeToDefault()
mapper.SetScalarRange(zMin, zMax)
mapper.SetScalarVisibility(1)
self.vtkActor = vtk.vtkActor()
# self.vtkActor.GetProperty().SetColor(1,1,1)
# self.vtkActor.GetProperty().SetColor(c.GetColor3d('Yellow'))
self.vtkActor.SetMapper(mapper)
def addPoint(self, point):
if self.vtkPoints.GetNumberOfPoints() < self.maxNumPoints:
pointId = self.vtkPoints.InsertNextPoint(point[:])
self.vtkDepth.InsertNextValue(point[2])
self.vtkCells.InsertNextCell(1)
self.vtkCells.InsertCellPoint(pointId)
else:
r = random.randint(0, self.maxNumPoints)
self.vtkPoints.SetPoint(r, point[:])
self.vtkCells.Modified()
self.vtkPoints.Modified()
self.vtkDepth.Modified()
def clearPoints(self):
self.vtkPoints = vtk.vtkPoints()
self.vtkCells = vtk.vtkCellArray()
self.vtkDepth = vtk.vtkDoubleArray()
self.vtkDepth.SetName('DepthArray')
self.vtkPolyData.SetPoints(self.vtkPoints)
self.vtkPolyData.SetVerts(self.vtkCells)
self.vtkPolyData.GetPointData().SetScalars(self.vtkDepth)
self.vtkPolyData.GetPointData().SetActiveScalars('DepthArray')
# Catch mouse events
class MouseInteractorStyle(vtkInteractorStyleTrackballCamera):
def __init__(self, data):
self.AddObserver('LeftButtonPressEvent', self.left_button_press_event)
self.AddObserver('RightButtonPressEvent', self.right_button_press_event)
self.data = data
self.selected_mapper = vtkDataSetMapper()
self.selected_actor = vtkActor()
self.selected_mapper2 = vtkDataSetMapper()
self.selected_actor2 = vtkActor()
self.vtk_list = vtk.vtkIdList()
self.locator = vtk.vtkPointLocator()
self.locator.SetDataSet(self.data)
self.locator.BuildLocator()
self.colors = vtkNamedColors()
def left_button_press_event(self, obj, event):
pos = self.GetInteractor().GetEventPosition()
picker = vtkCellPicker()
picker.SetTolerance(0.001)
# Pick from this location.
picker.Pick(pos[0], pos[1], 0, self.GetDefaultRenderer())
world_position = picker.GetPickPosition()
print(f'Cell id is: {picker.GetCellId()}')
# print(world_position)
self.locator.FindPointsWithinRadius(0.02,world_position, self.vtk_list)
print(self.vtk_list)
if picker.GetCellId() != -1:
print(f'Pick position is: ({world_position[0]:.6g}, {world_position[1]:.6g}, {world_position[2]:.6g})')
ids = vtkIdTypeArray()
ids.SetNumberOfComponents(1)
ids.InsertNextValue(picker.GetCellId())
# print(ids,'\n')
selection_node = vtkSelectionNode()
selection_node.SetFieldType(vtkSelectionNode.CELL)
selection_node.SetContentType(vtkSelectionNode.INDICES)
selection_node.SetSelectionList(ids)
selection = vtkSelection()
selection.AddNode(selection_node)
extract_selection = vtkExtractSelection()
extract_selection.SetInputData(0, self.data)
extract_selection.SetInputData(1, selection)
extract_selection.Update()
# In selection
selected = vtkUnstructuredGrid()
selected.ShallowCopy(extract_selection.GetOutput())
print(f'Number of points in the selection: {selected.GetNumberOfPoints()}')
# print(f'Number of cells in the selection : {selected.GetNumberOfCells()}\n')
print('########################\n')
self.selected_mapper.SetInputData(selected)
self.selected_actor.SetMapper(self.selected_mapper)
# self.selected_actor.GetProperty().EdgeVisibilityOn()
self.selected_actor.GetProperty().SetColor(self.colors.GetColor3d('Black'))
self.selected_actor.GetProperty().SetPointSize(10)
self.selected_actor.GetProperty().SetLineWidth(3)
# print(self.selected_actor)
self.GetInteractor().GetRenderWindow().GetRenderers().GetFirstRenderer().AddActor(self.selected_actor)
# Forward events
self.OnLeftButtonDown()
def right_button_press_event(self, obj, event):
if self.vtk_list.GetNumberOfIds() == 0:
return
else:
ids2 = vtkIdTypeArray()
ids2.SetNumberOfComponents(1)
for i in range(self.vtk_list.GetNumberOfIds()):
# print(i)
ids2.InsertNextValue(i)
# print(ids2)
selection_node2 = vtkSelectionNode()
selection_node2.SetFieldType(vtkSelectionNode.CELL)
selection_node2.SetContentType(vtkSelectionNode.INDICES)
selection_node2.SetSelectionList(ids2)
selection2 = vtkSelection()
selection2.AddNode(selection_node2)
extract_selection2 = vtkExtractSelection()
extract_selection2.SetInputData(0, self.data)
extract_selection2.SetInputData(1, selection2)
extract_selection2.Update()
# # In selection
selected2 = vtkUnstructuredGrid()
selected2.ShallowCopy(extract_selection2.GetOutput())
print(f'Number of neighboring points: {selected2.GetNumberOfPoints()}')
# # print(f'Number of neighboring cells: {selected2.GetNumberOfCells()}\n')
print('########################\n')
self.selected_mapper2.SetInputData(selected2)
self.selected_actor2.SetMapper(self.selected_mapper2)
# self.selected_actor.GetProperty().EdgeVisibilityOn()
self.selected_actor2.GetProperty().SetColor(self.colors.GetColor3d("tan"))
self.selected_actor2.GetProperty().SetPointSize(10)
self.selected_actor2.GetProperty().SetLineWidth(3)
# print(self.selected_actor2)
self.GetInteractor().GetRenderWindow().GetRenderers().GetFirstRenderer().AddActor(self.selected_actor2)
print('Randomly Selecting 1000 points in the data........')
point_indices = []
cluster_points = np.zeros((self.vtk_list.GetNumberOfIds(),3))
# print(cluster_points)
print('Calculating the clusters around the centers ......')
for i in range(self.vtk_list.GetNumberOfIds()):
point_indices.append(self.vtk_list.GetId(i))
cluster_points[i]=pointCloud.vtkPolyData.GetPoint(self.vtk_list.GetId(i))
point_indices = np.asarray(point_indices)
new_points= np.delete(points,point_indices, axis=0)
random_array = np.random.randint(0,new_points.shape[0],(1000))
min_haus = 1000000000.0
for i in range(random_array.shape[0]):
new_list=vtk.vtkIdList()
new_center = new_points[random_array[i]]
# print('new center to find cluster:',new_center)
self.locator.FindPointsWithinRadius(0.02,new_center, new_list)
new_cluster_points = np.zeros((new_list.GetNumberOfIds(),3))
for x in range(new_list.GetNumberOfIds()):
new_cluster_points[x]=pointCloud.vtkPolyData.GetPoint(new_list.GetId(x))
haus = directed_hausdorff(cluster_points,new_cluster_points)[0]
if haus<min_haus:
min_haus = haus
idx = random_array[i]
min_center = new_points[random_array[i]]
# print('haus:',haus)
print('min haus:',min_haus)
print('idx of min haus:',idx)
print('center of the min haus cluster:',min_center)
min_list = vtk.vtkIdList()
self.locator.FindPointsWithinRadius(0.02,min_center, min_list)
print('Min list for min center', min_list)
# # Forward events
self.OnRightButtonDown()
# Initialize point clouds
pointCloud = VtkPointCloud()
pointCloud2 = VtkPointCloud()
# print(type(pointCloud))
# Loading Point cloud using open3d
pt_cloud = o3d.io.read_point_cloud('fused_cloud_normal.ply')
points = np.asarray(pt_cloud.points)
pt_cloud2 = o3d.io.read_point_cloud('all_projected_pts.ply')
points2 = np.asarray(pt_cloud2.points)
# Adding the points into polydata
for row in points:
pointCloud.addPoint(row)
for row in points2:
pointCloud2.addPoint(row)
# Intialize actor
c = vtkNamedColors()
actor = pointCloud.vtkActor
actor.GetProperty().SetPointSize(10)
actor.GetProperty().SetColor(c.GetColor3d('Yellow'))
# actor.GetProperty().SetColor(0,0,0)
# Renderer
renderer = vtk.vtkRenderer()
renderer.AddActor(actor)
# renderer.AddActor(pointCloud2.vtkActor)
renderer.SetBackground(.1, .1, .4)
renderer.ResetCamera()
# Render Window
renderWindow = vtk.vtkRenderWindow()
renderWindow.AddRenderer(renderer)
style = MouseInteractorStyle(pointCloud.vtkPolyData)
style.SetDefaultRenderer(renderer)
# Interactor
renderWindowInteractor = vtk.vtkRenderWindowInteractor()
renderWindowInteractor.SetRenderWindow(renderWindow)
renderWindowInteractor.SetInteractorStyle(style)
# Begin Interaction
renderWindow.Render()
renderWindowInteractor.Start()
`

I want to process a live stream from youtube with opencv

I am posting this while using a translation.
I would like to use a live stream on YouTube to action recognition.
Action recognition
https://github.com/felixchenfy/Realtime-Action-Recognition
Read Youtube live stream with opencv
How to read Youtube live stream using openCV python?
Youtube URL
https://www.youtube.com/watch?v=DjdUEyjx8GM
We have programmed it with reference to the above.The program runs normally for a few hours, but after 4 or 5 hours, I get an error saying that communication from the remote host has been lost.Checking the youtube video, the delivery does not seem to be interrupted.
A partial code for videocapture can be found here.
url = "https://www.youtube.com/watch?v=DjdUEyjx8GM"
video = pafy.new(url)
best = video.getbest(preftype="mp4")
class ReadFromVideo(object):
def __init__(self, video_path, sample_interval=1):
''' A video reader class for reading video frames from video.
Arguments:
video_path
sample_interval {int}: sample every kth image.
''' assert isinstance(sample_interval, int) and sample_interval >= 1
self.cnt_imgs = 0
self._is_stoped = False
self._video = cv2.VideoCapture(best.url)
ret, image = self._video.read()
ret, image2 = self._video2.read()
self._next_image = image
self._next_image2 = image2
self._sample_interval = sample_interval
self._fps = self.get_fps()
if not self._fps >= 0.0001:
import warnings
warnings.warn("Invalid fps of video: {}".format(video_path))
def has_image(self):
return self._next_image is not None
def get_curr_video_time(self):
return 1.0 / self._fps * self.cnt_imgs
def read_image(self):
image = self._next_image
image2 = self._next_image2
for i in range(self._sample_interval):
if self._video.isOpened():
ret, frame = self._video.read()
self._next_image = frame
ret, frame2 = self._video2.read()
self._next_image2 = frame2
else:
self._next_image = None
self._next_image2 = None
break
self.cnt_imgs += 1
return image, image2
def stop(self):
self._video.release()
self._video2.release()
self._is_stoped = True
def __del__(self):
if not self._is_stoped:
self.stop()
def get_fps(self):
# Find OpenCV version
(major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
# With webcam get(CV_CAP_PROP_FPS) does not work.
# Let's see for ourselves.
# Get video properties
if int(major_ver) < 3:
fps = self._video.get(cv2.cv.CV_CAP_PROP_FPS)
else:
fps = self._video.get(cv2.CAP_PROP_FPS)
return fps
#!/usr/bin/env python
# coding: utf-8
'''
Test action recognition on
(1) a video, (2) a folder of images, (3) or web camera.
Input:
model: model/trained_classifier.pickle
Output:
result video: output/${video_name}/video.avi
result skeleton: output/${video_name}/skeleton_res/XXXXX.txt
visualization by cv2.imshow() in img_displayer
'''
The entire source code is here
'''
Example of usage:
(1) Test on video file:
python src/s5_test.py \
--model_path model/trained_classifier.pickle \
--data_type video \
--data_path data_test/exercise.avi \
--output_folder output
(2) Test on a folder of images:
python src/s5_test.py \
--model_path model/trained_classifier.pickle \
--data_type folder \
--data_path data_test/apple/ \
--output_folder output
(3) Test on web camera:
python src/s5_test.py \
--model_path model/trained_classifier.pickle \
--data_type webcam \
--data_path 0 \
--output_folder output
if True: # Include project path
import sys
import os
ROOT = os.path.dirname(os.path.abspath(__file__))+"/../"
CURR_PATH = os.path.dirname(os.path.abspath(__file__))+"/"
sys.path.append(ROOT)
import utils.lib_images_io as lib_images_io
import utils.lib_plot as lib_plot
import utils.lib_commons as lib_commons
from utils.lib_openpose import SkeletonDetector
from utils.lib_tracker import Tracker
from utils.lib_tracker import Tracker
from utils.lib_classifier import ClassifierOnlineTest
from utils.lib_classifier import * # Import all sklearn related libraries
def par(path): # Pre-Append ROOT to the path if it's not absolute
return ROOT + path if (path and path[0] != "/") else path
# -- Command-line input
def get_command_line_arguments():
def parse_args():
parser = argparse.ArgumentParser(
description="Test action recognition on \n"
"(1) a video, (2) a folder of images, (3) or web camera.")
parser.add_argument("-m", "--model_path", required=False,
default='model/trained_classifier.pickle')
parser.add_argument("-t", "--data_type", required=False, default='webcam',
choices=["video", "folder", "webcam"])
parser.add_argument("-p", "--data_path", required=False, default="",
help="path to a video file, or images folder, or webcam. \n"
"For video and folder, the path should be "
"absolute or relative to this project's root. "
"For webcam, either input an index or device name. ")
parser.add_argument("-o", "--output_folder", required=False, default='output/',
help="Which folder to save result to.")
args = parser.parse_args()
return args
args = parse_args()
if args.data_type != "webcam" and args.data_path and args.data_path[0] != "/":
# If the path is not absolute, then its relative to the ROOT.
args.data_path = ROOT + args.data_path
return args
def get_dst_folder_name(src_data_type, src_data_path):
''' Compute a output folder name based on data_type and data_path.
The final output of this script looks like this:
DST_FOLDER/folder_name/vidoe.avi
DST_FOLDER/folder_name/skeletons/XXXXX.txt
'''
assert(src_data_type in ["video", "folder", "webcam"])
if src_data_type == "video": # /root/data/video.avi --> video
folder_name = str("video")
elif src_data_type == "folder": # /root/data/video/ --> video
folder_name = src_data_path.rstrip("/").split("/")[-1]
elif src_data_type == "webcam":
# month-day-hour-minute-seconds, e.g.: 02-26-15-51-12
folder_name = lib_commons.get_time_string()
return folder_name
args = get_command_line_arguments()
SRC_DATA_TYPE = args.data_type
SRC_DATA_PATH = args.data_path
SRC_MODEL_PATH = args.model_path
DST_FOLDER_NAME = get_dst_folder_name(SRC_DATA_TYPE, SRC_DATA_PATH)
# -- Settings
cfg_all = lib_commons.read_yaml(ROOT + "config/config.yaml")
cfg = cfg_all["s5_test.py"]
CLASSES = np.array(cfg_all["classes"])
SKELETON_FILENAME_FORMAT = cfg_all["skeleton_filename_format"]
# Action recognition: number of frames used to extract features.
WINDOW_SIZE = int(cfg_all["features"]["window_size"])
# Output folder
DST_FOLDER = args.output_folder + "/" + DST_FOLDER_NAME + "/"
DST_SKELETON_FOLDER_NAME = cfg["output"]["skeleton_folder_name"]
DST_VIDEO_NAME = cfg["output"]["video_name"]
# framerate of output video.avi
DST_VIDEO_FPS = float(cfg["output"]["video_fps"])
# writer
fmt = cv2.VideoWriter_fourcc(*'MP4V')
fps = 10.0
size = (800, 480)
writer = cv2.VideoWriter('outtest.mp4', fmt, fps, size)
backup_time = 3600 # sec
time = 0 # init
now = datetime.datetime.now()
s_now = now.strftime('%Y-%m-%d-%H-%M-%S')
back_up = cv2.VideoWriter(f'{s_now}.mp4',fmt, fps, size)
# Video setttings
# If data_type is webcam, set the max frame rate.
SRC_WEBCAM_MAX_FPS = float(cfg["settings"]["source"]
["webcam_max_framerate"])
# If data_type is video, set the sampling interval.
# For example, if it's 3, then the video will be read 3 times faster.
SRC_VIDEO_SAMPLE_INTERVAL = int(cfg["settings"]["source"]
["video_sample_interval"])
# Openpose settings
OPENPOSE_MODEL = cfg["settings"]["openpose"]["model"]
OPENPOSE_IMG_SIZE = cfg["settings"]["openpose"]["img_size"]
# Display settings
img_disp_desired_rows = int(cfg["settings"]["display"]["desired_rows"])
# -- Function
def select_images_loader(src_data_type, src_data_path):
if src_data_type == "video":
images_loader = lib_images_io.ReadFromVideo(
src_data_path,
sample_interval=SRC_VIDEO_SAMPLE_INTERVAL)
elif src_data_type == "folder":
images_loader = lib_images_io.ReadFromFolder(
folder_path=src_data_path)
elif src_data_type == "webcam":
if src_data_path == "":
webcam_idx = 0
elif src_data_path.isdigit():
webcam_idx = int(src_data_path)
else:
webcam_idx = src_data_path
images_loader = lib_images_io.ReadFromWebcam(
SRC_WEBCAM_MAX_FPS, webcam_idx)
return images_loader
class MultiPersonClassifier(object):
''' This is a wrapper around ClassifierOnlineTest
for recognizing actions of multiple people.
'''
def __init__(self, model_path, classes):
self.dict_id2clf = {} # human id -> classifier of this person
# Define a function for creating classifier for new people.
self._create_classifier = lambda human_id: ClassifierOnlineTest(
model_path, classes, WINDOW_SIZE, human_id)
def classify(self, dict_id2skeleton):
''' Classify the action type of each skeleton in dict_id2skeleton '''
# Clear people not in view
old_ids = set(self.dict_id2clf)
cur_ids = set(dict_id2skeleton)
humans_not_in_view = list(old_ids - cur_ids)
for human in humans_not_in_view:
del self.dict_id2clf[human]
# Predict each person's action
id2label = {}
for id, skeleton in dict_id2skeleton.items():
if id not in self.dict_id2clf: # add this new person
self.dict_id2clf[id] = self._create_classifier(id)
classifier = self.dict_id2clf[id]
id2label[id] = classifier.predict(skeleton) # predict label
# print("\n\nPredicting label for human{}".format(id))
# print(" skeleton: {}".format(skeleton))
# print(" label: {}".format(id2label[id]))
return id2label
def get_classifier(self, id):
''' Get the classifier based on the person id.
Arguments:
id {int or "min"}
'''
if len(self.dict_id2clf) == 0:
return None
if id == 'min':
id = min(self.dict_id2clf.keys())
return self.dict_id2clf[id]
def remove_skeletons_with_few_joints(skeletons):
''' Remove bad skeletons before sending to the tracker '''
good_skeletons = []
for skeleton in skeletons:
px = skeleton[2:2+13*2:2]
py = skeleton[3:2+13*2:2]
num_valid_joints = len([x for x in px if x != 0])
num_leg_joints = len([x for x in px[-6:] if x != 0])
total_size = max(py) - min(py)
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# IF JOINTS ARE MISSING, TRY CHANGING THESE VALUES:
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
if num_valid_joints >= 5 and total_size >= 0.1 and num_leg_joints >= 0:
# add this skeleton only when all requirements are satisfied
good_skeletons.append(skeleton)
return good_skeletons
def draw_result_img(img_disp, ith_img, humans, dict_id2skeleton,
skeleton_detector, multiperson_classifier):
''' Draw skeletons, labels, and prediction scores onto image for display '''
# Resize to a proper size for display
r, c = img_disp.shape[0:2]
desired_cols = int(1.0 * c * (img_disp_desired_rows / r))
img_disp = cv2.resize(img_disp,
dsize=(desired_cols, img_disp_desired_rows))
# Draw all people's skeleton
skeleton_detector.draw(img_disp, humans)
# Draw bounding box and label of each person
if len(dict_id2skeleton):
for id, label in dict_id2label.items():
skeleton = dict_id2skeleton[id]
# scale the y data back to original
skeleton[1::2] = skeleton[1::2] / scale_h
# print("Drawing skeleton: ", dict_id2skeleton[id], "with label:", label, ".")
lib_plot.draw_action_result(img_disp, id, skeleton, label)
# Add blank to the left for displaying prediction scores of each class
img_disp = lib_plot.add_white_region_to_left_of_image(img_disp)
cv2.putText(img_disp, "Frame:" + str(ith_img),
(20, 20), fontScale=1.5, fontFace=cv2.FONT_HERSHEY_PLAIN,
color=(0, 0, 0), thickness=2)
# Draw predicting score for only 1 person
if len(dict_id2skeleton):
classifier_of_a_person = multiperson_classifier.get_classifier(
id='min')
classifier_of_a_person.draw_scores_onto_image(img_disp)
return img_disp
# Time Display
cv2.putText(img2_disp, str(now_time.isoformat(timespec='seconds')),
(20, 20), fontScale=1.5, fontFace=cv2.FONT_HERSHEY_PLAIN,
color=(0, 0, 0), thickness=2)
# Draw predicting score for only 1 person
if len(dict2_id2skeleton):
classifier_of_a_person = multiperson_classifier.get_classifier(
id='min')
classifier_of_a_person.draw_scores_onto_image(img2_disp)
return img2_disp
def get_the_skeleton_data_to_save_to_disk(dict_id2skeleton):
'''
In each image, for each skeleton, save the:
human_id, label, and the skeleton positions of length 18*2.
So the total length per row is 2+36=38
'''
skels_to_save = []
for human_id in dict_id2skeleton.keys():
label = dict_id2label[human_id]
skeleton = dict_id2skeleton[human_id]
skels_to_save.append([[human_id, label] + skeleton.tolist()])
return skels_to_save
# -- Main
if __name__ == "__main__":
# -- Detector, tracker, classifier
skeleton_detector = SkeletonDetector(OPENPOSE_MODEL, OPENPOSE_IMG_SIZE)
multiperson_tracker = Tracker()
multiperson_classifier = MultiPersonClassifier(SRC_MODEL_PATH, CLASSES)
# -- Image reader and displayer
images_loader = select_images_loader(SRC_DATA_TYPE, SRC_DATA_PATH)
img_displayer = lib_images_io.ImageDisplayer()
# -- Init output
# output folder
os.makedirs(DST_FOLDER, exist_ok=True)
os.makedirs(DST_FOLDER + DST_SKELETON_FOLDER_NAME, exist_ok=True)
# -- Read images and process
try:
ith_img = -1
while images_loader.has_image():
# Timer
now = datetime.datetime.now()
now_time = now.time()
# -- Read image
#img = images_loader.read_image()
img = images_loader.read_image()
ith_img += 1
img_disp = img.copy()
# print(f"\nProcessing {ith_img}th image ...")
# -- Detect skeletons
humans = skeleton_detector.detect(img)
skeletons, scale_h = skeleton_detector.humans_to_skels_list(humans)
)
skeletons = remove_skeletons_with_few_joints(skeletons)
# -- Track people
dict_id2skeleton = multiperson_tracker.track(skeletons)
# -- Log & Recognize action of each person
log.append(str(now_time))
if len(dict_id2skeleton):
dict_id2label = multiperson_classifier.classify(
dict_id2skeleton)
# -- Draw
img_disp = draw_result_img(img_disp, ith_img, humans, dict_id2skeleton, skeleton_detector, multiperson_classifier)
# Print label of a person
if len(dict_id2skeleton):
min_id = min(dict_id2skeleton.keys())
# print("prediced label is :", dict_id2label[min_id])
# -- Display image, and write to video.avi
# video writer
if time >= backup_time*fps:
back_up.release()
now = datetime.datetime.now()
s_now = now.strftime('%Y-%m-%d-%H-%M-%S')
back_up = cv2.VideoWriter(f'{s_now}.mp4', fmt, fps, size)
time = 0
time += 1
writer.write(img_disp)
back_up.write(img_disp)
img_displayer.display(img_disp, wait_key_ms=1)
# -- Get skeleton data and save to file
skels_to_save = get_the_skeleton_data_to_save_to_disk(dict_id2skeleton)
lib_commons.save_listlist(DST_FOLDER + DST_SKELETON_FOLDER_NAME +SKELETON_FILENAME_FORMAT.format(ith_img),skels_to_save)
if cv2.waitKey(1) == 27:
break
finally:
writer.release()
back_up.release()
cv2.destroyAllWindows()
Thanks for your help.

Finding cells of a .stl file with negative mean curvature using VTK in python

I have a .stl file and i'm trying to find the coordinates of cells with negative mean curvature using VTK and python. I have wrote these codes which are working fine to change the colors of cells based on their mean curvature but what i'm willing to achieve is coordinates of exact cells and triangles with specific mean curvature, e.g. 3d coordinates of cells with most negative mean curvature.
Here are the codes:
import vtk
def gaussian_curve(fileNameSTL):
colors = vtk.vtkNamedColors()
reader = vtk.vtkSTLReader()
reader.SetFileName(fileNameSTL)
reader.Update()
curveGauss = vtk.vtkCurvatures()
curveGauss.SetInputConnection(reader.GetOutputPort())
curveGauss.SetCurvatureTypeToGaussian() # SetCurvatureTypeToMean() works better in the case of kidney.
ctf = vtk.vtkColorTransferFunction()
ctf.SetColorSpaceToDiverging()
p1 = [0.0] + list(colors.GetColor3d("MidnightBlue"))
p2 = [1.0] + list(colors.GetColor3d("DarkRed"))
ctf.AddRGBPoint(*p1)
ctf.AddRGBPoint(*p2)
cc = list()
for i in range(256):
cc.append(ctf.GetColor(float(i) / 255.0))
lut = vtk.vtkLookupTable()
lut.SetNumberOfColors(256)
for i, item in enumerate(cc):
lut.SetTableValue(i, item[0], item[1], item[2], 1.0)
lut.SetRange(0, 0) # In the case of kidney, the (0, 0) worked better.
lut.Build()
cmapper = vtk.vtkPolyDataMapper()
cmapper.SetInputConnection(curveGauss.GetOutputPort())
cmapper.SetLookupTable(lut)
cmapper.SetUseLookupTableScalarRange(1)
cActor = vtk.vtkActor()
cActor.SetMapper(cmapper)
return cActor
def render_scene(my_actor_list):
renderer = vtk.vtkRenderer()
for arg in my_actor_list:
renderer.AddActor(arg)
namedColors = vtk.vtkNamedColors()
renderer.SetBackground(namedColors.GetColor3d("SlateGray"))
window = vtk.vtkRenderWindow()
window.SetWindowName("Render Window")
window.AddRenderer(renderer)
interactor = vtk.vtkRenderWindowInteractor()
interactor.SetRenderWindow(window)
# Visualize
window.Render()
interactor.Start()
if __name__ == '__main__':
fileName = "400_tri.stl"
my_list = list()
my_list.append(gaussian_curve(fileName))
render_scene(my_list)
This code produce red cells for positive mean curvature and blue for negative ones.
I need the result(coordinates of cells) in the form of arrays or something like that.
I would appreciate any suggestion and help on this problem.
A possible solution with vtkplotter:
from vtkplotter import *
torus1 = Torus().addCurvatureScalars().addScalarBar()
print("list of scalars:", torus1.scalars())
torus2 = torus1.clone().addScalarBar()
torus2.threshold("Gauss_Curvature", vmin=-15, vmax=0)
show(torus1, torus2, N=2) # plot on 2 separate renderers
print("vertex coordinates:", len(torus2.coordinates()))
print("cell centers :", len(torus2.cellCenters()))
check out the resulting screenshot here
Additional example here.
Hope this helps.
So i found the answer from kitware weblog, here is the code that works fine using vtk.numpy_interface and vtk.util.numpy_support, but still it does not produce the normals_array and i don't know why??
import vtk
from vtk.numpy_interface import dataset_adapter as dsa
from vtk.util.numpy_support import vtk_to_numpy
def curvature_to_numpy(fileNameSTL, curve_type='Mean'):
colors = vtk.vtkNamedColors()
reader = vtk.vtkSTLReader()
reader.SetFileName(fileNameSTL)
reader.Update()
# Defining the curvature type.
curve = vtk.vtkCurvatures()
curve.SetInputConnection(reader.GetOutputPort())
if curve_type == "Mean":
curve.SetCurvatureTypeToMean()
else:
curve.SetCurvatureTypeToGaussian()
curve.Update()
# Applying color lookup table.
ctf = vtk.vtkColorTransferFunction()
ctf.SetColorSpaceToDiverging()
p1 = [0.0] + list(colors.GetColor3d("MidnightBlue"))
p2 = [1.0] + list(colors.GetColor3d("DarkOrange"))
ctf.AddRGBPoint(*p1)
ctf.AddRGBPoint(*p2)
cc = list()
for i in range(256):
cc.append(ctf.GetColor(float(i) / 255.0))
lut = vtk.vtkLookupTable()
lut.SetNumberOfColors(256)
for i, item in enumerate(cc):
lut.SetTableValue(i, item[0], item[1], item[2], 1.0)
lut.SetRange(0, 0) # In the case of kidney, the (0, 0) worked better.
lut.Build()
# Creating Mappers and Actors.
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputConnection(curve.GetOutputPort())
mapper.SetLookupTable(lut)
mapper.SetUseLookupTableScalarRange(1)
actor = vtk.vtkActor()
actor.SetMapper(mapper)
# Scalar values to numpy array. (Curvature).
dataObject = dsa.WrapDataObject(curve.GetOutput())
normals_array = dataObject.PointData['Normals'] # Output array.
curvature_array = dataObject.PointData['Mean_Curvature'] # output array.
# Node values to numpy array.
nodes = curve.GetOutput().GetPoints().GetData()
nodes_array = vtk_to_numpy(nodes)
# Creating a report file (.vtk file).
writer = vtk.vtkPolyDataWriter()
writer.SetFileName('vtk_file_generic.vtk')
writer.SetInputConnection(curve.GetOutputPort())
writer.Write()
# EDIT:
# Creating the point normal array using vtkPolyDataNormals().
normals = vtk.vtkPolyDataNormals()
normals.SetInputConnection(reader.GetOutputPort()) # Here "curve" could be replaced by "reader".
normals.ComputePointNormalsOn()
normals.SplittingOff()
normals.Update()
dataNormals = dsa.WrapDataObject(normals.GetOutput())
normals_array = dataNormals.PointData["Normals"]
return actor, normals_array, curvature_array, nodes_array
def render_scene(my_actor_list):
renderer = vtk.vtkRenderer()
for arg in my_actor_list:
renderer.AddActor(arg)
namedColors = vtk.vtkNamedColors()
renderer.SetBackground(namedColors.GetColor3d("SlateGray"))
window = vtk.vtkRenderWindow()
window.SetWindowName("Render Window")
window.AddRenderer(renderer)
interactor = vtk.vtkRenderWindowInteractor()
interactor.SetRenderWindow(window)
# Visualize
window.Render()
interactor.Start()
if __name__ == '__main__':
filename = "400_tri.stl"
my_list = list()
my_actor, my_normals, my_curve, my_nodes = curvature_to_numpy(filename, curve_type="Mean")
my_list.append(my_actor)
render_scene(my_list) # Visualization.
print(my_nodes) # Data points.
print(my_normals) # Normal vectors.
print(my_curve) # Mean curvatures.

How to make PyCollada output multiple meshes to the same scene?

So I am using pyCollada to try to export multiple meshes to the same scene. Alas, whenever I try to do so, I can only see one of the meshes I have loaded in. Am I doing something wrong when I create the file? Each individual mesh renders perfectly if I separate them into their own file, but they fail when I attempt to output them to the same file. I have looked through the API, but the documentation is very limited. Any help would be appreciated.
My code is listed shown below.
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 12 14:43:05 2015
#author: skylion
"""
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 11 11:01:48 2015
#author: danaukes
"""
import sys
import popupcad_deprecated
import popupcad_manufacturing_plugins
import popupcad
from popupcad.filetypes.design import Design
import PySide.QtGui as qg
#Draws Collada stuff
from collada import *
import numpy
geom_index = 0;
def exportBodyToMesh(output):
# csg = output.csg
generic = output.generic_laminate()
# layers = generic.layers()
layerdef = d.return_layer_definition()
layerdef.refreshzvalues()
# layers = layerdef.layers
mesh = Collada()
nodes = []
for layer in layerdef.layers:
shapes = generic.geoms[layer]#TODO Add it in for other shapes
zvalue = layerdef.zvalue[layer]
height = zvalue * 1/ popupcad.internal_argument_scaling
print zvalue
if (len(shapes) == 0) : #In case there are no shapes.
print "No shapes skipping"
continue
print shapes
for s in shapes:
geom = createMeshFromShape(s, height, mesh)
mesh.geometries.append(geom)
effect = material.Effect("effect" + str(geom_index), [], "phone", diffuse=(1,0,0), specular=(0,1,0))
mat = material.Material("material" + str(geom_index), "mymaterial", effect)
matnode = scene.MaterialNode("materialref" + str(geom_index), mat, inputs=[])
mesh.effects.append(effect)
mesh.materials.append(mat)
geomnode = scene.GeometryNode(geom, [matnode])
node = scene.Node("node" + str(geom_index), children=[geomnode])
nodes.append(node)
print nodes
myscene = scene.Scene("myscene", nodes)
mesh.scenes.append(myscene)
mesh.scene = myscene
# layer_num = layer_num + 1 #Add the layer thicknes instead of simply + 1
filename = str(output) + '.dae'
mesh.write(filename)
#TODO Add handling in case rigid body has already been selected.
print filename + " has been saved"
def createMeshFromShape(s,layer_num, mesh):
s.exteriorpoints()
a = s.triangles3()
vertices = []
global geom_index
for coord in a:
for dec in coord:
vertices.append(dec[0]) #x-axis
vertices.append(dec[1]) #y-axis
vertices.append(layer_num ) #z-axi
#This scales the verticies properly.
vert_floats = [x/popupcad.internal_argument_scaling for x in vertices]
vert_src = source.FloatSource("cubeverts-array" + str(geom_index), numpy.array(vert_floats), ('X', 'Y', 'Z'))
geom = geometry.Geometry(mesh, "geometry" + str(geom_index), "mycube", [vert_src])
input_list = source.InputList()
input_list.addInput(0, 'VERTEX', "#cubeverts-array" + str(geom_index))
indices = numpy.array(range(0,(len(vertices) / 3)));
triset = geom.createTriangleSet(indices, input_list, "materialref")
geom_index += 1
triset.generateNormals()
geom.primitives.append(triset)
return geom
#Start of actual script
print sys.argv
app = qg.QApplication('exporter.py')
d = Design.open()
print "Loading..."
d.reprocessoperations()
operation = d.operations[3] #Identify bodies
for output in operation.output:
exportBodyToMesh(output)
print "All objects printed"
#sys.exit(app.exec_())
Your code to add the geometry to the scene is outside your inner loop. You're only adding the last geometry to the scene, rather than all of them. You should be creating multiple GeometryNode and adding all of them to the Scene.

Python+Chaco+Traits - rendering bug: unexpected fills of line plot of large data?

Using the minimal example below, the line plot of a large (some 110k points) plot I get (with python 2.7, numpy 1.5.1, chaco/enable/traits 4.3.0) is this:
However, that is bizarre, because it is a line plot, and there shouldn't be any filled areas in there? Especially since the data is sawtooth-ish signal? It's as if there is a line at y~=37XX, above which there is color filling?! But sure enough, if I zoom into an area, I get the rendering I expect - without the unexpected fill:
Is this a bug - or is there something I'm doing wrong? I tried to use use_downsampling, but it makes no difference...
The test code:
import numpy as np
import numpy.random as npr
from pprint import pprint
from traits.api import HasTraits, Instance
from chaco.api import Plot, ArrayPlotData, VPlotContainer
from traitsui.api import View, Item
from enable.component_editor import ComponentEditor
from chaco.tools.api import PanTool, BetterSelectingZoom
tlen = 112607
alr = npr.randint(0, 4000, tlen)
tx = np.arange(0.0, 30.0-0.00001, 30.0/tlen)
ty = np.arange(0, tlen, 1) % 10000 + alr
pprint(len(ty))
class ChacoTest(HasTraits):
container = Instance(VPlotContainer)
traits_view = View(
Item('container', editor=ComponentEditor(), show_label=False),
width=800, height=500, resizable=True,
title="Chaco Test"
)
def __init__(self):
super(ChacoTest, self).__init__()
pprint(ty)
self.plotdata = ArrayPlotData(x = tx, y = ty)
self.plotobj = Plot(self.plotdata)
self.plotA = self.plotobj.plot(("x", "y"), type="line", color=(0,0.99,0), spacing=0, padding=0, alpha=0.7, use_downsampling=True)
self.container = VPlotContainer(self.plotobj, spacing=5, padding=5, bgcolor="lightgray")
#~ container.add(plot)
self.plotobj.tools.append(PanTool(self.plotobj))
self.plotobj.overlays.append(BetterSelectingZoom(self.plotobj))
if __name__ == "__main__":
ChacoTest().configure_traits()
I am able to reproduce the error and talking with John Wiggins (maintainer of Enable), it is a bug in kiva (which chaco uses to paint on the screen):
https://github.com/enthought/enable
The good news is that this is a bug in one of the kiva backend that you can use. So to go around the issue, you can run your script choosing a different backend:
ETS_TOOLKIT=qt4.qpainter python <NAME OF YOUR SCRIPT>
if you use qpainter or quartz, the plot looks (on my machine) as expected. If you choose qt4.image (the Agg backend), you will reproduce the issue. Unfortunately, the Agg backend is the default one. To change that, you can set the ETS_TOOLKIT environment variable to that value:
export ETS_TOOLKIT=qt4.qpainter
The bad news is that fixing this isn't going to be an easy task. Please feel free to report the bug in github (again https://github.com/enthought/enable) if you want to be involved in this. If you don't, I will log it in the next couple of days. Thanks for reporting it!
Just a note - I found this:
[Enthought-Dev] is chaco faster than matplotlib
I recall reading somewhere that you are expected to implement the
_downsample method because the optimal algorithm depends on the type
of data you're collecting.
And as I couldn't find any examples with _downsample implementation other than decimated_plot.py referred in that post, which isn't standalone - I tried and built a standalone example, included below.
The example basically has messed up drag and zoom, (plot disappears if you go out of range, or stretches upon a drag move) - and it starts zoomed in; but it is possible to zoom it out in the range shown in the OP - and then it displays the exact same plot rendering problem. So downsampling isn't the solution per se, so this is likely a bug?
import numpy as np
import numpy.random as npr
from pprint import pprint
from traits.api import HasTraits, Instance
from chaco.api import Plot, ArrayPlotData, VPlotContainer
from traitsui.api import View, Item
from enable.component_editor import ComponentEditor
from chaco.tools.api import PanTool, BetterSelectingZoom
#
from chaco.api import BaseXYPlot, LinearMapper, AbstractPlotData
from enable.api import black_color_trait, LineStyle
from traits.api import Float, Enum, Int, Str, Trait, Event, Property, Array, cached_property, Bool, Dict
from chaco.abstract_mapper import AbstractMapper
from chaco.abstract_data_source import AbstractDataSource
from chaco.array_data_source import ArrayDataSource
from chaco.data_range_1d import DataRange1D
tlen = 112607
alr = npr.randint(0, 4000, tlen)
tx = np.arange(0.0, 30.0-0.00001, 30.0/tlen)
ty = np.arange(0, tlen, 1) % 10000 + alr
pprint(len(ty))
class ChacoTest(HasTraits):
container = Instance(VPlotContainer)
traits_view = View(
Item('container', editor=ComponentEditor(), show_label=False),
width=800, height=500, resizable=True,
title="Chaco Test"
)
downsampling_cutoff = Int(4)
def __init__(self):
super(ChacoTest, self).__init__()
pprint(ty)
self.plotdata = ArrayPlotData(x = tx, y = ty)
self.plotobj = TimeSeriesPlot(self.plotdata)
self.plotobj.setplotranges("x", "y")
self.container = VPlotContainer(self.plotobj, spacing=5, padding=5, bgcolor="lightgray")
self.plotobj.tools.append(PanTool(self.plotobj))
self.plotobj.overlays.append(BetterSelectingZoom(self.plotobj))
# decimate from:
# https://bitbucket.org/mjrosen/neurobehavior/raw/097ef3719d1263a8b303d29c31ab71b6e792ab04/cns/widgets/views/decimated_plot.py
def decimate(data, screen_width, downsampling_cutoff=4, mode='extremes'):
data_width = data.shape[-1]
downsample = np.floor((data_width/screen_width)/4.)
if downsample > downsampling_cutoff:
return globals()['decimate_'+mode](data, downsample)
else:
return data
def decimate_extremes(data, downsample):
last_dim = data.ndim
offset = data.shape[-1] % downsample
if data.ndim == 2:
shape = (len(data), -1, downsample)
else:
shape = (-1, downsample)
data = data[..., offset:].reshape(shape).copy()
data_min = data.min(last_dim)
data_max = data.max(last_dim)
return data_min, data_max
def decimate_mean(data, downsample):
offset = len(data) % downsample
if data.ndim == 2:
shape = (-1, downsample, data.shape[-1])
else:
shape = (-1, downsample)
data = data[offset:].reshape(shape).copy()
return data.mean(1)
# based on class from decimated_plot.py, also
# neurobehavior/cns/chaco_exts/timeseries_plot.py ;
# + some other code from chaco
class TimeSeriesPlot(BaseXYPlot):
color = black_color_trait
line_width = Float(1.0)
line_style = LineStyle
reference = Enum('most_recent', 'trigger')
traits_view = View("color#", "line_width")
downsampling_cutoff = Int(100)
signal_trait = "updated"
decimate_mode = Str('extremes')
ch_index = Trait(None, Int, None)
# Mapping of data names from self.data to their respective datasources.
datasources = Dict(Str, Instance(AbstractDataSource))
index_mapper = Instance(AbstractMapper)
value_mapper = Instance(AbstractMapper)
def __init__(self, data=None, **kwargs):
super(TimeSeriesPlot, self).__init__(**kwargs)
self._index_mapper_changed(None, self.index_mapper)
self.setplotdata(data)
self._plot_ui_info = None
return
def setplotdata(self, data):
if data is not None:
if isinstance(data, AbstractPlotData):
self.data = data
elif type(data) in (ndarray, tuple, list):
self.data = ArrayPlotData(data)
else:
raise ValueError, "Don't know how to create PlotData for data" \
"of type " + str(type(data))
def setplotranges(self, index_name, value_name):
self.index_name = index_name
self.value_name = value_name
index = self._get_or_create_datasource(index_name)
value = self._get_or_create_datasource(value_name)
if not(self.index_mapper):
imap = LinearMapper()#(range=self.index_range)
self.index_mapper = imap
if not(self.value_mapper):
vmap = LinearMapper()#(range=self.value_range)
self.value_mapper = vmap
if not(self.index_range): self.index_range = DataRange1D() # calls index_mapper
if not(self.value_range): self.value_range = DataRange1D()
self.index_range.add(index) # calls index_mapper!
self.value_range.add(value)
# now do it (right?):
self.index_mapper = LinearMapper(range=self.index_range)
self.value_mapper = LinearMapper(range=self.value_range)
def _get_or_create_datasource(self, name):
if name not in self.datasources:
data = self.data.get_data(name)
if type(data) in (list, tuple):
data = array(data)
if isinstance(data, np.ndarray):
if len(data.shape) == 1:
ds = ArrayDataSource(data, sort_order="none")
elif len(data.shape) == 2:
ds = ImageData(data=data, value_depth=1)
elif len(data.shape) == 3:
if data.shape[2] in (3,4):
ds = ImageData(data=data, value_depth=int(data.shape[2]))
else:
raise ValueError("Unhandled array shape in creating new plot: " \
+ str(data.shape))
elif isinstance(data, AbstractDataSource):
ds = data
else:
raise ValueError("Couldn't create datasource for data of type " + \
str(type(data)))
self.datasources[name] = ds
return self.datasources[name]
def get_screen_points(self):
self._gather_points()
return self._downsample()
def _data_changed(self):
self.invalidate_draw()
self._cache_valid = False
self._screen_cache_valid = False
self.request_redraw()
def _gather_points(self):
if not self._cache_valid:
range = self.index_mapper.range
#if self.reference == 'most_recent':
# values, t_lb, t_ub = self.get_recent_range(range.low, range.high)
#else:
# values, t_lb, t_ub = self.get_range(range.low, range.high, -1)
values, t_lb, t_ub = self.data[self.value_name][range.low:range.high], range.low, range.high
#if self.ch_index is None:
# self._cached_data = values
#else:
# #self._cached_data = values[:,self.ch_index]
self._cached_data = values
self._cached_data_bounds = t_lb, t_ub
self._cache_valid = True
self._screen_cache_valid = False
def _downsample(self):
if not self._screen_cache_valid:
val_pts = self._cached_data
screen_min, screen_max = self.index_mapper.screen_bounds
screen_width = screen_max-screen_min
values = decimate(val_pts, screen_width, self.downsampling_cutoff,
self.decimate_mode)
if type(values) == type(()):
n = len(values[0])
s_val_min = self.value_mapper.map_screen(values[0])
s_val_max = self.value_mapper.map_screen(values[1])
self._cached_screen_data = s_val_min, s_val_max
else:
s_val_pts = self.value_mapper.map_screen(values)
self._cached_screen_data = s_val_pts
n = len(values)
t = np.linspace(*self._cached_data_bounds, num=n)
t_screen = self.index_mapper.map_screen(t)
self._cached_screen_index = t_screen
self._screen_cache_valid = True
return [self._cached_screen_index, self._cached_screen_data]
def _render(self, gc, points):
idx, val = points
if len(idx) == 0:
return
gc.save_state()
gc.set_antialias(True)
gc.clip_to_rect(self.x, self.y, self.width, self.height)
gc.set_stroke_color(self.color_)
gc.set_line_width(self.line_width)
#gc.set_line_width(5)
gc.begin_path()
#if len(val) == 2:
if type(val) == type(()):
starts = np.column_stack((idx, val[0]))
ends = np.column_stack((idx, val[1]))
gc.line_set(starts, ends)
else:
gc.lines(np.column_stack((idx, val)))
gc.stroke_path()
self._draw_default_axes(gc)
gc.restore_state()
if __name__ == "__main__":
ChacoTest().configure_traits()

Categories