Today I wrote a simple Python class to get a plot style suitable for my purposes.
Here is the class.
This is the base class, in which I define the colors. In the `colors` method I am particularly interested in the color cycle named 'soft1':
from abc import ABCMeta, abstractmethod
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
from matplotlib.axes import Axes
import matplotlib.pylab as pylab
import matplotlib
from cycler import cycler
class PlotBase(metaclass=ABCMeta):
    """Shared state and helpers for the concrete plot classes (e.g. TimesPlot).

    Parameters
    ----------
    title : str
        Figure title.
    filename : str, optional
        Output file name, including its extension.  The default ``' '``
        (a single space) is the "do not save" marker that TimesPlot.plot
        checks for.
    """

    # Named colour cycles, written as ``cycler(...)`` strings so they can be
    # assigned directly to ``rcParams['axes.prop_cycle']``.
    _COLOR_CYCLES = {
        'vega': "cycler('color', ['1F77B4', 'FF7F0E', '2CA02C', 'D62728', '9467BD', '8C564B', 'E377C2', '7F7F7F', 'BCBD22', '17BECF'] )",
        'gg': "cycler('color', ['E24A33', '348ABD', '988ED5', '777777', 'FBC15E', '8EBA42', 'FFB5B8'])",
        'brewer': "cycler('color', ['66C2A5', 'FC8D62', '8DA0CB', 'E78AC3', 'A6D854', 'FFD92F', 'E5C494', 'B3B3B3'] )",
        'soft1': "cycler('color', ['8DA0CB', 'E78AC3', 'A6D854', 'FFD92F', 'E5C494', 'B3B3B3', '66C2A5', 'FC8D62'] )",
        'tthmod': "cycler('color', ['30a2da', 'fc4f30', 'e5ae38', '6d904f', '8b8b8b'])",
        'grayscale': "cycler('color', ['252525', '525252', '737373', '969696', 'BDBDBD', 'D9D9D9', 'F0F0F0', 'F0F0FF' ])",
        'grayscale2': "cycler('color', ['525252', '969696', 'BDBDBD' ,'D9D9D9', 'F0F0F0', 'F0F0FF' ])",
    }

    # Named line-style cycles, in the same string format as _COLOR_CYCLES.
    _LINESTYLE_CYCLES = {
        'linetype1': "cycler('linestyle', ['-', '--', ':', '-.'])",
        'linetype2': "cycler('linestyle', ['-', ':', '-.', '--'])",
    }

    def __init__(self, title, filename: str = ' '):
        self.title = title
        self.filename = filename

    # ------------------------------------------------------------------
    def findLimits(self):
        '''
        Return ``[xmin, xmax, ymin, ymax]`` over every curve, each rounded
        to 2 decimals.

        ``self.variable`` must be a flat list of ``x, y, label`` triples
        (as built by TimesPlot.plot); only the x and y entries are read.

        The running extrema start at -/+infinity.  The previous start
        values (0.0 for the maxima, 10E20 for the minima) silently reported
        a maximum of 0 for curves lying entirely below zero.
        '''
        x_min = y_min = float('inf')
        x_max = y_max = float('-inf')
        for i in range(0, len(self.variable) - 1, 3):
            xs, ys = self.variable[i], self.variable[i + 1]
            x_min = min(x_min, min(xs))
            x_max = max(x_max, max(xs))
            y_min = min(y_min, min(ys))
            y_max = max(y_max, max(ys))
        return [round(x_min, 2), round(x_max, 2), round(y_min, 2), round(y_max, 2)]

    # ------------------------------------------------------------------
    def save(self, filename: str):
        """Save the current pyplot figure to *filename* (name with extension)."""
        plt.savefig(filename)

    # ------------------------------------------------------------------
    def colors(self, style: str):
        """Return the ``cycler(...)`` string for the colour palette *style*.

        Known names: 'vega', 'gg', 'brewer', 'soft1', 'tthmod', 'grayscale',
        'grayscale2'.  Any other value is returned unchanged.
        """
        return self._COLOR_CYCLES.get(style, style)

    # ------------------------------------------------------------------
    def linestyle(self, linestyle: str):
        """Return the ``cycler(...)`` string for the line-style cycle *linestyle*.

        Known names: 'linetype1', 'linetype2'.  Any other value is returned
        unchanged.
        """
        return self._LINESTYLE_CYCLES.get(linestyle, linestyle)

    # ------------------------------------------------------------------
    def plot(self, *args, **kwargs):
        """Draw the figure; derived classes must implement their own version.

        This method is intentionally not decorated with ``@abstractmethod``
        (the original only had it as a comment).  PlotBase can therefore
        still be instantiated, and this default does nothing.
        """
        pass
Then I define the concrete class that is actually used:
class TimesPlot(PlotBase):
    '''
    mathpazo style (LaTeX-class) plot suitable for:
    - palatino fonts template / beamer
    - classic thesis style
    - mathpazo package
    '''

    def __init__(self, title, filename: str = ' '):
        super().__init__(title, filename)

    # ------------------------------------------------------------------
    def plot(self, *args, **kwargs):
        """Draw every ``(x, y, label)`` triple in *args* on one "light" figure.

        The style is applied with ``plt.rcParams.update``.  The previous code
        called ``plt.rc_context({...})`` without a ``with`` block.  On old
        matplotlib releases that call happened to apply the settings
        immediately (and permanently).  Since matplotlib 3.0, ``rc_context``
        is a real context manager, so a bare call does nothing.  That is why
        the 'soft1' colour cycle and the grey axes box were ignored on newer
        installations.

        Parameters
        ----------
        *args
            ``x1, y1, label1, x2, y2, label2, ...``
        **kwargs
            Accepted for interface compatibility; currently unused.

        Raises
        ------
        AttributeError
            If the number of positional arguments is not a multiple of 3.

        TODO: make figure size / font size / line width configurable.
        """
        self.variable = list(args)
        if len(self.variable) % 3 != 0:
            raise AttributeError('you must give 2 array x,y followed by string label for each plot')

        plt.rcParams.update({
            'font.family': 'serif',
            'font.size': 14,
            'font.style': 'italic',
            'axes.edgecolor': '#999999',   # box colour
            'axes.linewidth': 1,           # box width
            'axes.xmargin': 0,
            'axes.ymargin': 0,
            'axes.labelcolor': '#555555',
            'axes.axisbelow': True,
            'xtick.color': '#555555',
            'ytick.color': '#555555',
            'axes.prop_cycle': self.colors('soft1'),
        })

        fig, ax = plt.subplots(1, figsize=(10, 6))
        ax.set_title(self.title, color='#555555', fontsize=18)
        ax.set_xlabel('time [s]', fontsize=16)
        ax.set_ylabel('y(t)', fontsize=16)
        ax.grid(linestyle='--')

        for x, y, label in zip(self.variable[0::3], self.variable[1::3], self.variable[2::3]):
            ax.plot(x, y, linewidth=3, label=label)

        # Compute the data limits once instead of once per bound.
        x_min, x_max, y_min, y_max = self.findLimits()
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max + 0.02)

        ax.xaxis.set_minor_locator(MultipleLocator(0.05))
        ax.yaxis.set_minor_locator(MultipleLocator(0.05))
        ax.yaxis.set_ticks_position('both')
        ax.xaxis.set_ticks_position('both')

        leg = ax.legend(framealpha=1, loc='best', fontsize=14, fancybox=False, borderpad=0.4)
        leg.get_frame().set_edgecolor('#dddddd')
        leg.get_frame().set_linewidth(1.2)
        plt.setp(leg.get_texts(), color='#555555')

        # pad must be passed by keyword in recent matplotlib.
        plt.tight_layout(pad=0.5)
        if self.filename != ' ':
            self.save(self.filename)
        plt.show()
The problem: on my university PC this class gives me a plot that uses the 'soft1' color scheme, but at home, with the same version of Python, it uses the default color set (the one I defined as 'vega'). Not only do the line colors change; the box of the plot also has a different color contrast. Could somebody help me?
I call this class by passing it 6 vectors (to obtain 3 curves), each x/y pair followed by its label, as follows:
fig1 = makeplot.TimesPlot('Solution of Differential Equations', 'p1.pdf')
# One x, y, label triple per curve.
fig1.plot(
    fet, feu, 'Explicit Euler',
    bet, beu, 'Implicit Euler',
    x, y, 'Analytical',
)
Could you please help me understand why this happens? At home I have tried on Arch Linux and Gentoo, while on my desktop PC, where the class works correctly, I use the latest Debian.
Related
I am combining several functions into one class and using if/elif to choose which one runs.
Let me explain.
First, I have 3 types of plot: combo, line, and bar.
I know how to define a separate function for each of these three plots.
Second, I want to combine these 3 plots in one module and select between them with if.
The code I tried is:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
class AP(object):
    """Read a CSV file and draw one of three quarterly charts.

    The chart is chosen with ``TYPE``.  It is drawn, saved (to a file named
    after ``TYPE``) and shown as soon as the object is created.  Only the
    columns the chosen chart needs have to be passed:

    * ``"combo"`` -- ``value``: a line of each group's value relative to its
      first quarter, plus bars of the growth versus the same quarter of the
      previous year.
    * ``"line"``  -- ``value`` and ``value2``: the ratio ``value2 / value``
      per quarter.
    * ``"bar"``   -- ``value3`` .. ``value6``: each column's share over the
      last four quarters (this CSV is read with ``sep='|'``).

    Parameters
    ----------
    dt : path or buffer of the CSV file.
    date : name of the date column.
    group : name of the column that splits the data into series.
    value, value2, value3, value4, value5, value6 : column names (see above).
    TYPE : ``"combo"``, ``"line"`` or ``"bar"``.

    Raises
    ------
    ValueError
        If ``TYPE`` is not one of the supported chart types.
    """

    def __init__(self, dt, date, group, value=None, value2=None, value3=None,
                 value4=None, value5=None, value6=None, TYPE=None):
        self.dt = dt
        self.date = date
        self.group = group
        self.value = value
        self.value2 = value2
        self.value3 = value3
        self.value4 = value4
        self.value5 = value5
        self.value6 = value6
        self.TYPE = TYPE
        # The drawing functions are real methods of the class.  Previously
        # they were only defined inside __init__ (and never called), which
        # is why nothing was plotted.
        charts = {"combo": self.ComboChart, "line": self.line, "bar": self.Bar}
        try:
            draw = charts[TYPE]
        except KeyError:
            raise ValueError(
                "TYPE must be one of {}, got {!r}".format(sorted(charts), TYPE)
            ) from None
        draw()

    # ------------------------------------------------------------------ helpers
    def _read_quarters(self, **read_csv_kwargs):
        """Read the CSV and add a 'date' (datetime) and a 'yq' (quarter period) column."""
        dataset = pd.read_csv(self.dt, **read_csv_kwargs)
        dataset['date'] = pd.to_datetime(dataset[self.date])
        dataset['yq'] = dataset['date'].dt.to_period('Q')
        return dataset

    @staticmethod
    def _hide_spines(*axes):
        """Hide all four spines of every given Axes."""
        for ax in axes:
            for side in ("right", "left", "top", "bottom"):
                ax.spines[side].set_visible(False)

    def _finish(self, fig, ax):
        """Title the chart, save it (file named after TYPE), then show it."""
        ax.set_title(self.TYPE, fontsize=20)
        # Save before show(): the file then matches what is displayed and
        # does not depend on the figure surviving its window being closed.
        fig.savefig(self.TYPE, bbox_inches='tight', dpi=600)
        plt.show()

    # ------------------------------------------------------------------- charts
    def ComboChart(self):
        """Relative-value line (left) overlaid with year-over-year growth bars."""
        group, value = self.group, self.value
        dataset = self._read_quarters()
        dataset['qtr'] = dataset['date'].dt.quarter
        dataset = dataset.groupby([group, 'yq', 'qtr'])[value].sum().reset_index()
        # Growth versus the same quarter of the previous year.
        previous_year = dataset.groupby(['qtr', group])[value].transform('shift')
        dataset['total.YQGR'] = dataset[value] / previous_year - 1
        dataset = dataset[np.isfinite(dataset['total.YQGR'])].copy()
        dataset['total.R'] = dataset[value] / dataset.groupby(group)[value].transform('first')
        dataset.yq = dataset.yq.astype(str)

        fig, ax1 = plt.subplots(figsize=(12, 7))
        ax2 = ax1.twinx()
        sns.lineplot(x='yq', y='total.R', data=dataset, hue=group, ax=ax1, legend=None,
                     palette=('navy', 'r'), linewidth=5)
        ax1.set_xticklabels(ax1.get_xticks(), rotation=45, fontsize=15, weight='heavy')
        ax1.set_xlabel("", fontsize=15)
        ax1.set_ylabel("")
        ax1.set_ylim((0, max(dataset['total.R']) + 0.05))

        sns.barplot(x='yq', y='total.YQGR', data=dataset, hue=group, ax=ax2,
                    palette=('navy', 'r'))
        ax2.set_yticklabels(['{:.1f}%'.format(a * 100) for a in ax2.get_yticks()])
        ax2.set_ylabel("")
        ax2.set_ylim((min(dataset['total.YQGR']) - 0.01, max(dataset['total.YQGR']) + 0.2))
        ax2.get_legend().remove()
        ax2.legend(bbox_to_anchor=(-0.35, 0.5), loc=2, borderaxespad=0., fontsize='xx-large')
        for container in ax2.containers:
            for bar in container:
                height = bar.get_height()
                # Labels sit just above positive bars and just below negative ones.
                offset = 0.003 if height >= 0 else -0.008
                ax2.text(bar.get_xy()[0] + bar.get_width() / 1.5, height + offset,
                         '{:.1f}%'.format(round(100 * height, 2)), color='black',
                         horizontalalignment='center', fontsize=12, weight='heavy')

        ax1.yaxis.set_visible(False)
        ax2.yaxis.set_visible(False)
        ax2.xaxis.set_visible(False)
        self._hide_spines(ax1, ax2)
        self._finish(fig, ax1)

    def line(self):
        """One line per group of the ratio value2 / value per quarter."""
        group, value, value2 = self.group, self.value, self.value2
        dataset = self._read_quarters()
        # Several columns must be selected with a list: tuple-style selection
        # (gb[value, value2]) was removed in pandas 2.0.
        dataset = dataset.groupby([group, 'yq'])[[value, value2]].sum().reset_index()
        dataset['Arate'] = dataset[value2] / dataset[value]
        dataset.yq = dataset.yq.astype(str)

        fig, ax1 = plt.subplots(figsize=(12, 7))
        sns.lineplot(x='yq', y='Arate', data=dataset, hue=group, ax=ax1, linewidth=5)
        ax1.set_xticklabels(dataset['yq'], rotation=45, fontsize=15)
        ax1.set_xlabel("")
        ax1.set_ylabel("")
        ax1.set_ylim((min(dataset['Arate']) - 0.05, max(dataset['Arate']) + 0.05))
        ax1.set_yticklabels(['{:.1f}%'.format(a * 100) for a in ax1.get_yticks()],
                            fontsize=18, weight='heavy')
        ax1.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=2, borderaxespad=0., ncol=6)
        ax1.yaxis.grid(True)
        self._hide_spines(ax1)
        self._finish(fig, ax1)

    def Bar(self):
        """Share of value3..value6 per group over the last four quarters."""
        group = self.group
        columns = [self.value3, self.value4, self.value5, self.value6]
        dataset = self._read_quarters(sep='|')
        dataset = dataset.groupby([group, 'yq'])[columns].sum().reset_index()
        dataset = dataset.groupby([group]).tail(4).copy()  # last four quarters per group
        dataset.yq = dataset.yq.astype(str)
        dataset = pd.melt(dataset, id_vars=[group, 'yq'], value_vars=columns)
        dataset = dataset.groupby(['variable', group])['value'].sum().reset_index()
        dataset['L4Qtr'] = dataset['value'] / dataset.groupby([group])['value'].transform('sum')

        fig, ax1 = plt.subplots(figsize=(12, 7))
        sns.barplot(x='variable', y='L4Qtr', data=dataset, hue=group, ax=ax1)
        ax1.set_xticklabels(ax1.get_xticklabels(), fontsize=17.5, weight='heavy')
        ax1.set_xlabel("", fontsize=15)
        ax1.set_ylabel("")
        # Pass minor by keyword: in matplotlib >= 3.5 the second positional
        # parameter of set_ticks() is ``labels``.
        ax1.yaxis.set_ticks(np.arange(0, max(dataset['L4Qtr']) + 0.1, 0.05), minor=False)
        ax1.set_yticklabels(['{:.1f}%'.format(a * 100) for a in ax1.get_yticks()],
                            fontsize=18, weight='heavy')
        ax1.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=2, borderaxespad=0., ncol=6)
        for container in ax1.containers:
            for bar in container:
                ax1.text(bar.get_xy()[0] + bar.get_width() / 2, bar.get_height() + 0.005,
                         '{:.1f}%'.format(round(100 * bar.get_height(), 2)),
                         color=bar.get_facecolor(), horizontalalignment='center',
                         fontsize=16, weight='heavy')
        self._hide_spines(ax1)
        self._finish(fig, ax1)
Third, I want others to be able to use this module simply, as below:
import sys
# Make the folder that contains AP.py importable (path taken from the question; adjust as needed).
sys.path.append(r'\\users\desktop\module')
from AP import AP as ap
Finally, when someone sets TYPE, the chart should be plotted and saved automatically.
# This will plot combo chart
ap(
    r'\\users\desktop\dataset.csv',
    date='DATEVALUE',
    group='GRPS',
    value='total',
    TYPE='combo',
)
That is the ideal usage: I should not need to pass value2 to value6, since the combo chart does not use them.
When I want a bar chart:
# This will plot bar chart
# TYPE must be 'bar' here; 'combo' would draw the combo chart instead.
ap(r'\\users\desktop\dataset.csv', date = 'DATEVALUE', group = 'GRPS', value3 = 'col1', value4 = 'col2', value5 = 'col3', value6 = 'col4', TYPE = 'bar')
My code is wrong, since an error is raised: it seems I need to pass all the parameters.
However, even when I pass all the parameters, there is no error, but there is also no output.
Any suggestions?
Could you explain why you don't just create a subclass for each type? Wouldn't that be more straightforward?
1.) One way would be to make the subclasses visible to the user and if you don't like this,
2.) you could create a kind of interface class (e.g. AP) that hides the class used behind the scenes and, for example, instantiates it as soon as the type is set.
3.) you can continue as you began, but then you would have to make the methods visible to the user. The way you implemented it, the functions are only visible inside the __init__ method (maybe your indentation is not quite right). They are also only defined there and never called, which is why you get no output. For example, if your if statements are executed by __init__, you could assign the methods to instance attributes like self.ComboChart = ComboChart to be able to call them from outside. But IMHO that would not be very Pythonic, and it is a bit hacky and less object-oriented.
So I'd suggest 1.) and if that is not possible for some reason, then I'd go for solution 2. Both solutions also allow you to form a clean class structure and reuse code that way, while you are still able to build your simplified interface class if you like.
An example (pseudo code) for method 1 would look like the one below. Please note that I haven't tested it; it is only meant to give you an idea of how to split the logic in an object-oriented way. I didn't check your whole solution, so I don't know, for example, whether you always group your data in the same way. I'd probably also separate the presentation logic from the data logic. That would be an especially good idea if you plan to display the same data in several ways, because with the current logic you would reread the CSV file and reprocess the data each time you want another representation. To avoid complicating the explanation of the basic principle, I ignored this and gave an example with a base class "Chart" and a subclass "ComboChart". The "ComboChart" class knows how to read and group the data because it inherits those methods from "Chart". You only have to implement them once, so if you find a bug or want to enhance them later, you only need to do it in one place. The draw_chart method then only needs to do what differs for the chosen representation. A user would create an instance of the subclass for the chart type they want to display and call draw_chart().
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
class Chart(object):
    """Base class: reads the CSV once and groups it by quarter.

    Subclasses implement :meth:`draw_chart`; users call :meth:`display_chart`
    (or ``draw_chart`` directly).

    Parameters
    ----------
    dt : path or buffer of the CSV file.
    date : name of the date column.
    group : name of the column that splits the data into series.
    value : name of the value column aggregated by :meth:`group_data`.
    value2 .. value6 : optional extra column names for other chart types.
    TYPE : chart title; subclasses also use it as the output file name.
    """

    def __init__(self, dt, date, group, value, value2=None, value3=None,
                 value4=None, value5=None, value6=None, TYPE='chart'):
        self.dt = dt
        self.date = date
        self.group = group
        self.value = value
        self.value2 = value2
        self.value3 = value3
        self.value4 = value4
        self.value5 = value5
        self.value6 = value6
        self.TYPE = TYPE
        self.dataset = None  # raw data, filled lazily by get_data()

    def _read_data_(self):
        """Read the CSV, add 'date', 'yq' (quarter) and 'qtr' columns, cache and return it.

        Aggregation lives only in :meth:`group_data`.  Doing it here as well
        applied the year-over-year filter twice and dropped a further year.
        """
        dataset = pd.read_csv(self.dt)
        dataset['date'] = pd.to_datetime(dataset[self.date])
        dataset['yq'] = dataset['date'].dt.to_period('Q')
        dataset['qtr'] = dataset['date'].dt.quarter
        self.dataset = dataset
        return dataset

    def get_data(self):
        """Return the raw dataset, reading the CSV only on first use."""
        if self.dataset is None:
            self._read_data_()
        return self.dataset

    def group_data(self):
        """Return per-quarter sums per group, with two derived columns.

        'total.YQGR' is the growth versus the same quarter of the previous
        year.  'total.R' is the value relative to the group's first
        remaining quarter.  Quarters without a finite growth value are
        dropped, and 'yq' is returned as strings such as '2020Q1'.
        """
        group, value = self.group, self.value
        dataset = self.get_data()
        dataset = dataset.groupby([group, 'yq', 'qtr'])[value].sum().reset_index()
        previous_year = dataset.groupby(['qtr', group])[value].transform('shift')
        dataset['total.YQGR'] = dataset[value] / previous_year - 1
        # NaN (no previous year) or inf (zero base) rows are removed.
        dataset = dataset[np.isfinite(dataset['total.YQGR'])].copy()
        dataset['total.R'] = dataset[value] / dataset.groupby(group)[value].transform('first')
        dataset.yq = dataset.yq.astype(str)
        return dataset

    def draw_chart(self):
        """Draw the chart; every subclass must implement this."""
        raise NotImplementedError('subclasses must implement draw_chart()')

    def display_chart(self):
        """Public entry point: draw the chart via :meth:`draw_chart`."""
        self.draw_chart()
class ComboChart(Chart):
    """Line of each group's value relative to its first quarter, overlaid with YoY growth bars."""

    def draw_chart(self):
        """Draw the combo chart, save it to a file named ``self.TYPE`` and show it."""
        dataset = self.group_data()
        group = self.group

        fig, ax1 = plt.subplots(figsize=(12, 7))
        ax2 = ax1.twinx()
        sns.lineplot(x='yq', y='total.R', data=dataset, hue=group, ax=ax1, legend=None,
                     palette=('navy', 'r'), linewidth=5)
        ax1.set_xticklabels(ax1.get_xticks(), rotation=45, fontsize=15, weight='heavy')
        ax1.set_xlabel("", fontsize=15)
        ax1.set_ylabel("")
        ax1.set_ylim((0, max(dataset['total.R']) + 0.05))

        sns.barplot(x='yq', y='total.YQGR', data=dataset, hue=group, ax=ax2,
                    palette=('navy', 'r'))
        ax2.set_yticklabels(['{:.1f}%'.format(a * 100) for a in ax2.get_yticks()])
        ax2.set_ylabel("")
        ax2.set_ylim((min(dataset['total.YQGR']) - 0.01, max(dataset['total.YQGR']) + 0.2))
        ax2.get_legend().remove()
        ax2.legend(bbox_to_anchor=(-0.35, 0.5), loc=2, borderaxespad=0., fontsize='xx-large')
        for container in ax2.containers:
            for bar in container:
                height = bar.get_height()
                # Labels sit just above positive bars and just below negative ones.
                offset = 0.003 if height >= 0 else -0.008
                ax2.text(bar.get_xy()[0] + bar.get_width() / 1.5, height + offset,
                         '{:.1f}%'.format(round(100 * height, 2)), color='black',
                         horizontalalignment='center', fontsize=12, weight='heavy')

        ax1.yaxis.set_visible(False)
        ax2.yaxis.set_visible(False)
        ax2.xaxis.set_visible(False)
        for ax in (ax1, ax2):
            for side in ("right", "left", "top", "bottom"):
                ax.spines[side].set_visible(False)

        ax1.set_title(self.TYPE, fontsize=20)
        # Save before show() so the file matches what is displayed.
        fig.savefig(self.TYPE, bbox_inches='tight', dpi=600)
        plt.show()
The second approach (with the interface class) would look the same, except that you have a fourth class that is known to the user and knows how to call the real implementation, like this:
class YourInterface:
    """Facade that picks the concrete chart class from ``TYPE``.

    Users only see this class, so the chart class hierarchy behind it can
    change without breaking their code.

    Parameters
    ----------
    your_arguments : dict or sequence
        Constructor arguments for the chart class.  A dict is passed as
        keyword arguments; anything else is unpacked positionally.
    TYPE : str
        Name of the chart class to use (currently only ``'ComboChart'``).

    Raises
    ------
    ValueError
        If ``TYPE`` does not name a known chart class.
    """

    def __init__(self, your_arguments, TYPE):
        chart_classes = {'ComboChart': ComboChart}
        try:
            chart_class = chart_classes[TYPE]
        except KeyError:
            raise ValueError('unknown chart TYPE: {!r}'.format(TYPE)) from None
        if isinstance(your_arguments, dict):
            self.client = chart_class(**your_arguments)
        else:
            self.client = chart_class(*your_arguments)

    def display_chart(self):
        """Draw the chart held by this facade."""
        # draw_chart() is the method every Chart subclass implements.
        self.client.draw_chart()
But it's a pretty boring class, isn't it?
I'd only do this if your class hierarchy is very technical and could change over time, and you want to avoid users of your library building up dependencies on the real class hierarchy that would break as soon as you change it. In most cases, I guess, class hierarchies stay relatively stable, so you don't need the extra level of abstraction that an interface class creates.
I'm using matplotlib and PyQt5 in a GUI application. To plot my data I use "FigureCanvasQTAgg" and add "NavigationToolbar2QT" so that I can modify and save my plots. It works, but I was wondering whether there are more advanced toolbars that, for example, allow changing the font size of the title and/or the labels. Here is what I use at the moment:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
# Fragment: presumably runs inside a method of the Qt widget (``self``) — TODO confirm.
# NOTE(review): matplotlib's Qt-embedding examples create the figure with
# matplotlib.figure.Figure rather than plt.figure(), so pyplot does not also manage it.
self.figure = plt.figure()
self.ax = self.figure.add_subplot(111)
self.canvas = FigureCanvas(self.figure)
# The toolbar's "Figure options" button opens the dialog built by figureoptions.figure_edit.
self.toolbar = NavigationToolbar(self.canvas)
The available "Figure options" look like this:
Options I'm looking for are:
font size of the title
font size of axis-label
options for the legend like position, font size, style
Probably I'm not the first one looking for these options, so I guess that somebody coded such an advanced toolbar already, but I couldn't find anything and thought it's worth asking here before I try to code it on my own and (probably) waste a lot of time.
The figure options qt dialog is defined in
https://github.com/matplotlib/matplotlib/blob/master/lib/matplotlib/backends/qt_editor/figureoptions.py
You may copy that code to a new file, say myfigureoptions.py and make the changes you want to. Then monkey-patch it into the original.
The following would add a title fontsize field.
# Copyright © 2009 Pierre Raybaut
# Licensed under the terms of the MIT License
# see the mpl licenses directory for a copy of the license
# Modified to add a title fontsize
"""Module that provides a GUI-based editor for matplotlib's figure options."""
import os.path
import re
import matplotlib
from matplotlib import cm, colors as mcolors, markers, image as mimage
import matplotlib.backends.qt_editor.formlayout as formlayout
from matplotlib.backends.qt_compat import QtGui
def get_icon(name):
    """Return a QIcon for the image file *name* shipped in matplotlib's data directory.

    Uses ``matplotlib.get_data_path()``.  The ``rcParams['datapath']`` entry
    used previously is deprecated and does not exist in current matplotlib
    releases.
    """
    basedir = os.path.join(matplotlib.get_data_path(), 'images')
    return QtGui.QIcon(os.path.join(basedir, name))
# Display names for the line-style shorthands offered in the dialog.
LINESTYLES = dict(zip(
    ('-', '--', '-.', ':', 'None'),
    ('Solid', 'Dashed', 'DashDot', 'Dotted', 'None'),
))

# Display names for the draw styles.  'steps' is an alias of 'steps-pre', so
# both carry the same label; insertion order matters when prepare_data()
# collapses duplicate labels.
DRAWSTYLES = dict(zip(
    ('default', 'steps-pre', 'steps', 'steps-mid', 'steps-post'),
    ('Default', 'Steps (Pre)', 'Steps (Pre)', 'Steps (Mid)', 'Steps (Post)'),
))
# matplotlib's marker shorthand -> name mapping; fed to prepare_data() for the marker combobox.
MARKERS = markers.MarkerStyle.markers
def figure_edit(axes, parent=None):
    """Edit matplotlib figure options.

    Builds the "Figure options" form for *axes*: general axes settings, one
    entry per curve and one per image.  The form is shown with
    ``formlayout.fedit`` and the edited values are applied back to the axes.
    Compared with matplotlib's original dialog, this version adds a
    "Title Fontsize" field.

    Parameters
    ----------
    axes : matplotlib.axes.Axes
        The axes whose options are edited.
    parent : optional
        Qt parent widget of the dialog.
    """
    sep = (None, None)  # separator

    # Get / General
    # Cast to builtin floats as they have nicer reprs.
    xmin, xmax = map(float, axes.get_xlim())
    ymin, ymax = map(float, axes.get_ylim())
    general = [('Title', axes.get_title()),
               ('Title Fontsize', axes.title.get_fontsize()),  # <-- added field
               sep,
               (None, "<b>X-Axis</b>"),
               ('Left', xmin), ('Right', xmax),
               ('Label', axes.get_xlabel()),
               ('Scale', [axes.get_xscale(), 'linear', 'log', 'logit']),
               sep,
               (None, "<b>Y-Axis</b>"),
               ('Bottom', ymin), ('Top', ymax),
               ('Label', axes.get_ylabel()),
               ('Scale', [axes.get_yscale(), 'linear', 'log', 'logit']),
               sep,
               ('(Re-)Generate automatic legend', False),
               ]

    # Save the unit data
    xconverter = axes.xaxis.converter
    yconverter = axes.yaxis.converter
    xunits = axes.xaxis.get_units()
    yunits = axes.yaxis.get_units()

    # Sorting for default labels (_lineXXX, _imageXXX).
    def cmp_key(label):
        match = re.match(r"(_line|_image)(\d+)", label)
        if match:
            return match.group(1), int(match.group(2))
        else:
            return label, 0

    # Get / Curves
    linedict = {}
    for line in axes.get_lines():
        label = line.get_label()
        if label == '_nolegend_':
            continue
        linedict[label] = line
    curves = []

    def prepare_data(d, init):
        """Prepare entry for FormLayout.

        `d` is a mapping of shorthands to style names (a single style may
        have multiple shorthands, in particular the shorthands `None`,
        `"None"`, `"none"` and `""` are synonyms); `init` is one shorthand
        of the initial style.

        This function returns an list suitable for initializing a
        FormLayout combobox, namely `[initial_name, (shorthand,
        style_name), (shorthand, style_name), ...]`.
        """
        if init not in d:
            d = {**d, init: str(init)}
        # Drop duplicate shorthands from dict (by overwriting them during
        # the dict comprehension).
        name2short = {name: short for short, name in d.items()}
        # Convert back to {shorthand: name}.
        short2name = {short: name for name, short in name2short.items()}
        # Find the kept shorthand for the style specified by init.
        canonical_init = name2short[d[init]]
        # Sort by representation and prepend the initial value.
        return ([canonical_init] +
                sorted(short2name.items(),
                       key=lambda short_and_name: short_and_name[1]))

    curvelabels = sorted(linedict, key=cmp_key)
    for label in curvelabels:
        line = linedict[label]
        color = mcolors.to_hex(
            mcolors.to_rgba(line.get_color(), line.get_alpha()),
            keep_alpha=True)
        ec = mcolors.to_hex(
            mcolors.to_rgba(line.get_markeredgecolor(), line.get_alpha()),
            keep_alpha=True)
        fc = mcolors.to_hex(
            mcolors.to_rgba(line.get_markerfacecolor(), line.get_alpha()),
            keep_alpha=True)
        curvedata = [
            ('Label', label),
            sep,
            (None, '<b>Line</b>'),
            ('Line style', prepare_data(LINESTYLES, line.get_linestyle())),
            ('Draw style', prepare_data(DRAWSTYLES, line.get_drawstyle())),
            ('Width', line.get_linewidth()),
            ('Color (RGBA)', color),
            sep,
            (None, '<b>Marker</b>'),
            ('Style', prepare_data(MARKERS, line.get_marker())),
            ('Size', line.get_markersize()),
            ('Face color (RGBA)', fc),
            ('Edge color (RGBA)', ec)]
        curves.append([curvedata, label, ""])
    # Is there a curve displayed?
    has_curve = bool(curves)

    # Get / Images
    imagedict = {}
    for image in axes.get_images():
        label = image.get_label()
        if label == '_nolegend_':
            continue
        imagedict[label] = image
    imagelabels = sorted(imagedict, key=cmp_key)
    images = []
    # matplotlib.cm.cmap_d was deprecated (3.3) and later removed; prefer the
    # colormap registry matplotlib.colormaps when it exists.
    cmap_registry = getattr(matplotlib, 'colormaps', None)
    if cmap_registry is None:
        cmap_registry = cm.cmap_d
    cmaps = [(cmap, name) for name, cmap in sorted(cmap_registry.items())]
    for label in imagelabels:
        image = imagedict[label]
        cmap = image.get_cmap()
        if cmap not in cmap_registry.values():
            cmaps = [(cmap, cmap.name)] + cmaps
        low, high = image.get_clim()
        imagedata = [
            ('Label', label),
            ('Colormap', [cmap.name] + cmaps),
            ('Min. value', low),
            ('Max. value', high),
            ('Interpolation',
             [image.get_interpolation()]
             + [(name, name) for name in sorted(mimage.interpolations_names)])]
        images.append([imagedata, label, ""])
    # Is there an image displayed?
    has_image = bool(images)

    datalist = [(general, "Axes", "")]
    if curves:
        datalist.append((curves, "Curves", ""))
    if images:
        datalist.append((images, "Images", ""))

    def apply_callback(data):
        """This function will be called to apply changes"""
        orig_xlim = axes.get_xlim()
        orig_ylim = axes.get_ylim()

        general = data.pop(0)
        curves = data.pop(0) if has_curve else []
        images = data.pop(0) if has_image else []
        if data:
            raise ValueError("Unexpected field")

        # Set / General
        (title, titlefontsize, xmin, xmax, xlabel, xscale,  # <-- titlefontsize added
         ymin, ymax, ylabel, yscale, generate_legend) = general

        if axes.get_xscale() != xscale:
            axes.set_xscale(xscale)
        if axes.get_yscale() != yscale:
            axes.set_yscale(yscale)

        axes.set_title(title)
        axes.title.set_fontsize(titlefontsize)  # <-- apply the added field
        axes.set_xlim(xmin, xmax)
        axes.set_xlabel(xlabel)
        axes.set_ylim(ymin, ymax)
        axes.set_ylabel(ylabel)

        # Restore the unit data
        axes.xaxis.converter = xconverter
        axes.yaxis.converter = yconverter
        axes.xaxis.set_units(xunits)
        axes.yaxis.set_units(yunits)
        axes.xaxis._update_axisinfo()
        axes.yaxis._update_axisinfo()

        # Set / Curves
        for index, curve in enumerate(curves):
            line = linedict[curvelabels[index]]
            (label, linestyle, drawstyle, linewidth, color, marker, markersize,
             markerfacecolor, markeredgecolor) = curve
            line.set_label(label)
            line.set_linestyle(linestyle)
            line.set_drawstyle(drawstyle)
            line.set_linewidth(linewidth)
            rgba = mcolors.to_rgba(color)
            line.set_alpha(None)
            line.set_color(rgba)
            # Compare by value: ``is not`` tested object identity, which is
            # not guaranteed for equal strings (SyntaxWarning since Python 3.8).
            if marker != 'none':
                line.set_marker(marker)
                line.set_markersize(markersize)
                line.set_markerfacecolor(markerfacecolor)
                line.set_markeredgecolor(markeredgecolor)

        # Set / Images
        for index, image_settings in enumerate(images):
            image = imagedict[imagelabels[index]]
            label, cmap, low, high, interpolation = image_settings
            image.set_label(label)
            # set_cmap accepts a Colormap or a colormap name directly;
            # cm.get_cmap has been removed in recent matplotlib.
            image.set_cmap(cmap)
            image.set_clim(*sorted([low, high]))
            image.set_interpolation(interpolation)

        # re-generate legend, if checkbox is checked
        if generate_legend:
            draggable = None
            ncol = 1
            if axes.legend_ is not None:
                old_legend = axes.get_legend()
                draggable = old_legend._draggable is not None
                # Private attribute: '_ncols' in newer matplotlib, '_ncol' before.
                ncol = getattr(old_legend, '_ncols', getattr(old_legend, '_ncol', 1))
            new_legend = axes.legend(ncol=ncol)
            if new_legend:
                new_legend.set_draggable(draggable)

        # Redraw
        figure = axes.get_figure()
        figure.canvas.draw()
        if not (axes.get_xlim() == orig_xlim and axes.get_ylim() == orig_ylim):
            figure.canvas.toolbar.push_current()

    data = formlayout.fedit(datalist, title="Figure options", parent=parent,
                            icon=get_icon('qt4_editor_options.svg'),
                            apply=apply_callback)
    if data is not None:
        apply_callback(data)
# Monkey-patch: swap matplotlib's own dialog builder for the figure_edit
# defined in this module, so the toolbar's "Figure options" button opens
# the extended dialog.
import matplotlib.backends.qt_editor.figureoptions as figureoptions
setattr(figureoptions, 'figure_edit', figure_edit)
Use it as
import matplotlib.pyplot as plt
# Imported only for its side effect: it replaces matplotlib's figure-options
# dialog builder (figureoptions.figure_edit) with the extended version.
import myfigureoptions
fig, ax = plt.subplots()
ax.plot([1,2])
ax.set_title("My Title")
# Opening "Figure options" from the toolbar now shows the extra title-fontsize field.
plt.show()
When clicking the figure options dialog you now have a title font size field.
Based on ImportanceOfBeingErnest's answer, I started to modify the matplotlib figure options to suit my app. I removed the left/right and bottom/top limits for the X-axis and Y-axis, because these options are already implemented in my GUI, and I added another tab with various legend options. This is how the figure options look now:
And here is the code of the current version. I had to access some private variables because I couldn't find the corresponding getter functions, and the variable naming might not be the best, so please feel free to revise my code:
# Copyright © 2009 Pierre Raybaut
# Licensed under the terms of the MIT License
# see the mpl licenses directory for a copy of the license
# Modified to add a title fontsize
"""Module that provides a GUI-based editor for matplotlib's figure options."""
import os.path
import re
import matplotlib
from matplotlib import cm, colors as mcolors, markers, image as mimage
import matplotlib.backends.qt_editor.formlayout as formlayout
from matplotlib.backends.qt_compat import QtGui
def get_icon(name):
    """Return a QIcon for the image file *name* shipped in matplotlib's data directory.

    Uses ``matplotlib.get_data_path()``.  The ``rcParams['datapath']`` entry
    used previously is deprecated and does not exist in current matplotlib
    releases.
    """
    basedir = os.path.join(matplotlib.get_data_path(), 'images')
    return QtGui.QIcon(os.path.join(basedir, name))
# Display names for the line-style shorthands offered in the dialog.
LINESTYLES = dict(zip(
    ('-', '--', '-.', ':', 'None'),
    ('Solid', 'Dashed', 'DashDot', 'Dotted', 'None'),
))

# Display names for the draw styles.  'steps' is an alias of 'steps-pre', so
# both carry the same label; insertion order matters when prepare_data()
# collapses duplicate labels.
DRAWSTYLES = dict(zip(
    ('default', 'steps-pre', 'steps', 'steps-mid', 'steps-post'),
    ('Default', 'Steps (Pre)', 'Steps (Pre)', 'Steps (Mid)', 'Steps (Post)'),
))
# matplotlib's marker shorthand -> name mapping; fed to prepare_data() for the marker combobox.
MARKERS = markers.MarkerStyle.markers
def figure_edit(axes, parent=None):
    """Edit matplotlib figure options.

    Customised version of matplotlib's Qt "Figure options" dialog. Besides
    the axis labels/scales it exposes the title, label and tick font sizes,
    a grid toggle, and a "Legend" tab.

    Parameters
    ----------
    axes : matplotlib.axes.Axes
        The axes whose options are edited.
    parent : QWidget, optional
        Parent widget of the dialog.

    Notes
    -----
    The "Axes" and "Legend" tabs are always shown. A "Curves" tab is only
    added when *axes* holds at least one line whose label is not
    ``'_nolegend_'``; an "Images" tab only when it holds at least one such
    image. An axes with neither therefore shows no Curves/Images tabs.
    """
    sep = (None, None)  # separator

    # Get / General
    if 'labelsize' in axes.xaxis._major_tick_kw:
        _ticksize = int(axes.xaxis._major_tick_kw['labelsize'])
    else:
        _ticksize = 15
    general = [(None, "<b>Figure Title</b>"),
               ('Title', axes.get_title()),
               ('Font Size', int(axes.title.get_fontsize())),
               sep,
               (None, "<b>Axes settings</b>"),
               ('Label Size', int(axes.xaxis.label.get_fontsize())),
               ('Tick Size', _ticksize),
               ('Show grid', axes.xaxis._gridOnMajor),
               sep,
               (None, "<b>X-Axis</b>"),
               ('Label', axes.get_xlabel()),
               ('Scale', [axes.get_xscale(), 'linear', 'log', 'logit']),
               sep,
               (None, "<b>Y-Axis</b>"),
               ('Label', axes.get_ylabel()),
               ('Scale', [axes.get_yscale(), 'linear', 'log', 'logit'])
               ]

    # Get / Legend: seed the form from the existing legend, if any.
    if axes.legend_ is not None:
        old_legend = axes.get_legend()
        _draggable = old_legend._draggable is not None
        _ncol = old_legend._ncol
        _fontsize = int(old_legend._fontsize)
        _frameon = old_legend._drawFrame
        _shadow = old_legend.shadow
        _fancybox = isinstance(old_legend.legendPatch.get_boxstyle(),
                               matplotlib.patches.BoxStyle.Round)
        _framealpha = old_legend.get_frame().get_alpha()
    else:
        _draggable = False
        _ncol = 1
        _fontsize = 15
        _frameon = True
        _shadow = True
        _fancybox = True
        _framealpha = 0.5
    legend = [('Draggable', _draggable),
              ('columns', _ncol),
              ('Font Size', _fontsize),
              ('Frame', _frameon),
              ('Shadow', _shadow),
              ('FancyBox', _fancybox),
              ('Alpha', _framealpha)
              ]

    # Save the unit data; it is restored after the scales are changed.
    xconverter = axes.xaxis.converter
    yconverter = axes.yaxis.converter
    xunits = axes.xaxis.get_units()
    yunits = axes.yaxis.get_units()

    # Sorting for default labels (_lineXXX, _imageXXX).
    def cmp_key(label):
        match = re.match(r"(_line|_image)(\d+)", label)
        if match:
            return match.group(1), int(match.group(2))
        else:
            return label, 0

    # Get / Curves (lines labelled '_nolegend_' are not editable)
    linedict = {}
    for line in axes.get_lines():
        label = line.get_label()
        if label == '_nolegend_':
            continue
        linedict[label] = line
    curves = []

    def prepare_data(d, init):
        """Prepare entry for FormLayout.

        `d` is a mapping of shorthands to style names (a single style may
        have multiple shorthands, in particular the shorthands `None`,
        `"None"`, `"none"` and `""` are synonyms); `init` is one shorthand
        of the initial style.

        This function returns a list suitable for initializing a
        FormLayout combobox, namely `[initial_name, (shorthand,
        style_name), (shorthand, style_name), ...]`.
        """
        if init not in d:
            d = {**d, init: str(init)}
        # Drop duplicate shorthands from dict (by overwriting them during
        # the dict comprehension).
        name2short = {name: short for short, name in d.items()}
        # Convert back to {shorthand: name}.
        short2name = {short: name for name, short in name2short.items()}
        # Find the kept shorthand for the style specified by init.
        canonical_init = name2short[d[init]]
        # Sort by representation and prepend the initial value.
        return ([canonical_init] +
                sorted(short2name.items(),
                       key=lambda short_and_name: short_and_name[1]))

    curvelabels = sorted(linedict, key=cmp_key)
    for label in curvelabels:
        line = linedict[label]
        color = mcolors.to_hex(
            mcolors.to_rgba(line.get_color(), line.get_alpha()),
            keep_alpha=True)
        ec = mcolors.to_hex(
            mcolors.to_rgba(line.get_markeredgecolor(), line.get_alpha()),
            keep_alpha=True)
        fc = mcolors.to_hex(
            mcolors.to_rgba(line.get_markerfacecolor(), line.get_alpha()),
            keep_alpha=True)
        curvedata = [
            ('Label', label),
            sep,
            (None, '<b>Line</b>'),
            ('Line style', prepare_data(LINESTYLES, line.get_linestyle())),
            ('Draw style', prepare_data(DRAWSTYLES, line.get_drawstyle())),
            ('Width', line.get_linewidth()),
            ('Color (RGBA)', color),
            sep,
            (None, '<b>Marker</b>'),
            ('Style', prepare_data(MARKERS, line.get_marker())),
            ('Size', line.get_markersize()),
            ('Face color (RGBA)', fc),
            ('Edge color (RGBA)', ec)]
        curves.append([curvedata, label, ""])
    # Is there a curve displayed?
    has_curve = bool(curves)

    # Get / Images (images labelled '_nolegend_' are not editable)
    imagedict = {}
    for image in axes.get_images():
        label = image.get_label()
        if label == '_nolegend_':
            continue
        imagedict[label] = image
    imagelabels = sorted(imagedict, key=cmp_key)
    images = []
    cmaps = [(cmap, name) for name, cmap in sorted(cm.cmap_d.items())]
    for label in imagelabels:
        image = imagedict[label]
        cmap = image.get_cmap()
        if cmap not in cm.cmap_d.values():
            cmaps = [(cmap, cmap.name)] + cmaps
        low, high = image.get_clim()
        imagedata = [
            ('Label', label),
            ('Colormap', [cmap.name] + cmaps),
            ('Min. value', low),
            ('Max. value', high),
            ('Interpolation',
             [image.get_interpolation()]
             + [(name, name) for name in sorted(mimage.interpolations_names)])]
        images.append([imagedata, label, ""])
    # Is there an image displayed?
    has_image = bool(images)

    # "Curves" / "Images" tabs only exist when there is something to edit.
    datalist = [(general, "Axes", ""), (legend, "Legend", "")]
    if curves:
        datalist.append((curves, "Curves", ""))
    if images:
        datalist.append((images, "Images", ""))

    def apply_callback(data):
        """This function will be called to apply changes"""
        general = data.pop(0)
        legend = data.pop(0)
        curves = data.pop(0) if has_curve else []
        images = data.pop(0) if has_image else []
        if data:
            raise ValueError("Unexpected field")

        # Set / General
        (title, titlesize, labelsize, ticksize, grid, xlabel, xscale,
         ylabel, yscale) = general

        if axes.get_xscale() != xscale:
            axes.set_xscale(xscale)
        if axes.get_yscale() != yscale:
            axes.set_yscale(yscale)

        axes.set_title(title)
        axes.title.set_fontsize(titlesize)

        axes.set_xlabel(xlabel)
        axes.xaxis.label.set_size(labelsize)
        axes.xaxis.set_tick_params(labelsize=ticksize)

        axes.set_ylabel(ylabel)
        axes.yaxis.label.set_size(labelsize)
        axes.yaxis.set_tick_params(labelsize=ticksize)

        axes.grid(grid)

        # Restore the unit data
        axes.xaxis.converter = xconverter
        axes.yaxis.converter = yconverter
        axes.xaxis.set_units(xunits)
        axes.yaxis.set_units(yunits)
        axes.xaxis._update_axisinfo()
        axes.yaxis._update_axisinfo()

        # Set / Legend: the legend is always (re)created with these options.
        (leg_draggable, leg_ncol, leg_fontsize, leg_frameon, leg_shadow,
         leg_fancybox, leg_framealpha, ) = legend

        new_legend = axes.legend(ncol=leg_ncol,
                                 fontsize=float(leg_fontsize),
                                 frameon=leg_frameon,
                                 shadow=leg_shadow,
                                 framealpha=leg_framealpha,
                                 fancybox=leg_fancybox)

        new_legend.set_draggable(leg_draggable)

        # Set / Curves
        for index, curve in enumerate(curves):
            line = linedict[curvelabels[index]]
            (label, linestyle, drawstyle, linewidth, color, marker, markersize,
             markerfacecolor, markeredgecolor) = curve
            line.set_label(label)
            line.set_linestyle(linestyle)
            line.set_drawstyle(drawstyle)
            line.set_linewidth(linewidth)
            rgba = mcolors.to_rgba(color)
            line.set_alpha(None)
            line.set_color(rgba)
            # Apply marker settings unconditionally: selecting the "nothing"
            # marker must take effect too (it hides the markers).
            line.set_marker(marker)
            line.set_markersize(markersize)
            line.set_markerfacecolor(markerfacecolor)
            line.set_markeredgecolor(markeredgecolor)

        # Set / Images
        for index, image_settings in enumerate(images):
            image = imagedict[imagelabels[index]]
            label, cmap, low, high, interpolation = image_settings
            image.set_label(label)
            image.set_cmap(cm.get_cmap(cmap))
            image.set_clim(*sorted([low, high]))
            image.set_interpolation(interpolation)

        # Redraw
        figure = axes.get_figure()
        figure.canvas.draw()

    data = formlayout.fedit(datalist, title="Figure options", parent=parent,
                            icon=get_icon('qt4_editor_options.svg'),
                            apply=apply_callback)
    if data is not None:
        apply_callback(data)
# Monkey-patch matplotlib's Qt figure-options module so that its
# figure_edit entry point is the customised function defined in this file.
import matplotlib.backends.qt_editor.figureoptions as figureoptions
figureoptions.figure_edit = figure_edit
Here is my issue: I have a matplotlib figure embedded in a Qt5 application. When I press the "Edit axis, curve and image parameters" button and select the relevant subplot, only the "Axes" options tab appears; the "Curves" and "Images" tabs are missing.
actual picture
whereas I should have had something like this:
targeted picture
Does anyone know why?
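The answer is probably simple: `figure_edit` only adds the "Curves" and "Images" tabs when it finds something to put in them.
If there is no curve (line) in the plot, there will be no "Curves" tab.
If there is no image in the plot, there will be no "Images" tab.
Lines and images labelled `_nolegend_` are skipped as well, so they do not count.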
class View2D(MapView):
    """Component-plane view of a SOM codebook, one subplot per component."""

    def show(self, som, what='codebook', which_dim='all', cmap=None,
             col_sz=None, desnormalize=False):
        """Plot the selected codebook components as colour-mapped images.

        Parameters
        ----------
        som
            Trained SOM object; its ``codebook.matrix``, ``codebook.mapsize``
            and ``_component_names`` are read here.
        what : str
            Not used by this method.
        which_dim : {'all'} or int or list
            Which components to show. Any other value raises ``ValueError``.
        cmap : str or Colormap, optional
            Colormap for the images; defaults to ``'jet'``.
        col_sz : optional
            Forwarded to ``_calculate_figure_params``.
        desnormalize : bool
            If true, plot the codebook mapped back to the raw data scale via
            ``som._normalizer.denormalize_by``.

        Each image's colour scale spans mean +/- one standard deviation of
        its component, clamped to the component's actual min/max.
        """
        (self.width, self.height, indtoshow, no_row_in_plot, no_col_in_plot,
         axis_num) = self._calculate_figure_params(som, which_dim, col_sz)
        self.prepare()

        if cmap is None:
            cmap = 'jet'

        if not desnormalize:
            codebook = som.codebook.matrix
        else:
            codebook = som._normalizer.denormalize_by(som.data_raw,
                                                     som.codebook.matrix)

        if which_dim == 'all':
            names = som._component_names[0]
        elif isinstance(which_dim, int):
            names = [som._component_names[0][which_dim]]
        elif isinstance(which_dim, list):
            names = som._component_names[0][which_dim]
        else:
            # Fail early instead of with an unbound `names` inside the loop.
            raise ValueError(
                "which_dim must be 'all', an int or a list, got %r"
                % (which_dim,))

        rows, cols = som.codebook.mapsize[0], som.codebook.mapsize[1]
        while axis_num < len(indtoshow):
            axis_num += 1
            ax = plt.subplot(no_row_in_plot, no_col_in_plot, axis_num)
            ind = int(indtoshow[axis_num - 1])

            # Statistics of this component, computed once.
            component = codebook[:, ind].flatten()
            mean, std = np.mean(component), np.std(component)
            min_color_scale = max(mean - std, component.min())
            max_color_scale = min(mean + std, component.max())
            norm = matplotlib.colors.Normalize(vmin=min_color_scale,
                                               vmax=max_color_scale,
                                               clip=True)

            mp = codebook[:, ind].reshape(rows, cols)
            pl = plt.imshow(mp[::-1], norm=norm, interpolation='nearest',
                            origin='lower', cmap=cmap)
            plt.axis([0, cols, 0, rows])
            plt.title(names[axis_num - 1])
            ax.set_yticklabels([])
            ax.set_xticklabels([])
            plt.colorbar(pl)

        plt.show()
I'd like to create a barplot in matplotlib:
# Assumes ind, y and width were defined earlier (e.g. NumPy arrays and a float).
fig, ax = plt.subplots()
oldbar = ax.bar(x=ind, height=y, width=width)
I'd then like to pickle this barplot to a file (either the dictionary or the axes - I'm not sure which is correct):
# NOTE: file() exists only in Python 2; on Python 3 use open(..., 'wb'), since pickle writes bytes.
pickle.dump(oldbar, file('oldbar.pkl', 'w'))
I'd then like to reload this file and plot the old bars alongside a new bar plot, so I can compare them on a single axes:
fig, ax = plt.subplots()
newbar = ax.bar(x=ind, height=y, width=width)
# NOTE: file() is Python 2 only; pickle files should be opened in binary mode ('rb').
oldbar = pickle.load(file('oldbar.pkl'))
# I realise the line below doesn't work
# (ax.bar expects x/height data, not the object returned by an earlier ax.bar call)
ax.bar(oldbar)
plt.show()
Ideally, I'd then like to present them as below. Any suggestions on how I might go about this?
You would pickle the figure instead of the artists in it.
import matplotlib.pyplot as plt
import numpy as np
import pickle

ind = np.linspace(1, 5, 5)
y = np.linspace(9, 1, 5)
width = 0.3

fig, ax = plt.subplots()
ax.bar(x=ind, height=y, width=width)
ax.set_xlabel("x label")

# Pickle the whole figure rather than the bar artists. pickle writes bytes,
# so the file is opened in binary mode; ``with`` closes (and flushes) it.
with open('oldbar.pkl', 'wb') as f:
    pickle.dump(fig, f)

plt.close("all")

ind2 = np.linspace(1, 5, 5)
y2 = np.linspace(8, 2, 5)
width2 = 0.3

# Only unpickle files you created yourself: unpickling can run arbitrary code.
with open('oldbar.pkl', 'rb') as f:
    fig2 = pickle.load(f)
# Take the axes from the restored figure explicitly instead of relying on
# it having become pyplot's "current" axes.
ax2 = fig2.axes[0]
ax2.bar(x=ind2 + width, height=y2, width=width2, color="C1")

plt.show()
However, pickling the data itself may make more sense here.
import matplotlib.pyplot as plt
import numpy as np
import pickle

ind = np.linspace(1, 5, 5)
y = np.linspace(9, 1, 5)
width = 0.3

dic = {"ind": ind, "y": y, "width": width}
# pickle writes bytes, so open the file in binary mode; ``with`` closes
# (and flushes) it before it is read back below.
with open('olddata.pkl', 'wb') as f:
    pickle.dump(dic, f)

### new data
ind2 = np.linspace(1, 5, 5)
y2 = np.linspace(8, 2, 5)
width2 = 0.3

# Only unpickle files you created yourself: unpickling can run arbitrary code.
with open('olddata.pkl', 'rb') as f:
    olddic = pickle.load(f)

fig, ax = plt.subplots()
ax.bar(x=olddic["ind"], height=olddic["y"], width=olddic["width"])
ax.bar(x=ind2 + olddic["width"], height=y2, width=width2)
ax.set_xlabel("x label")

plt.show()
Maybe this will help:
import pickle as pkl
import matplotlib.pyplot as plt
import numpy as np
class Data_set(object):
    """A named (x, y) bar-chart data set that can be pickled and plotted."""

    def __init__(self, x=(), y=(), name='data', pklfile=None,
                 figure=None, axes=None):
        """
        Parameters
        ----------
        x, y : array-like
            Bar positions and heights, converted with ``np.asarray``.
        name : str
            Name of the data set; also the default pickle file stem.
        pklfile : str, optional
            If given, ``name``, ``x`` and ``y`` are loaded from
            ``pklfile + '.pkl'`` instead of taken from the arguments.
        figure, axes : optional
            Figure / axes to draw on by default in :meth:`plot`.
        """
        # Immutable defaults: np.asarray(()) gives the same empty array as
        # np.asarray([]) without sharing a mutable default object.
        self.name = str(name)
        if pklfile is None:
            self.x = np.asarray(x)
            self.y = np.asarray(y)
        else:
            self.unpickle(pklfile)
        self.fig = figure
        self.ax = axes
        self.bar = None

    def plot(self, width=0, offset=0, figure=None, axes=None):
        """Draw the data as bars shifted by *offset* along x.

        If *figure* is None, the data set's own figure/axes are used,
        creating them on first use. Otherwise *axes* (or one new subplot of
        *figure* when *axes* is None) is used, and remembered as the data
        set's own figure/axes if it had none yet.

        A falsy *width* defaults to half the spacing of the first two x
        values, or 0.8 (matplotlib's default bar width) with fewer than two
        points.

        Returns
        -------
        (figure, axes, bars) actually used for drawing.
        """
        if figure is None:
            if self.fig is None:
                self.fig = plt.figure()
            if self.ax is None:
                self.ax = self.fig.subplots(1, 1)
            fig_to_use, ax_to_use = self.fig, self.ax
        else:
            fig_to_use = figure
            # Create at most one new subplot (previously two overlapping
            # subplots were made when no axes was given).
            ax_to_use = axes if axes is not None else figure.subplots(1, 1)
            if self.fig is None:
                self.fig, self.ax = fig_to_use, ax_to_use

        if not width:
            if self.x.size > 1:
                width = (self.x[1] - self.x[0]) / 2.
            else:
                width = 0.8
        self.bar = ax_to_use.bar(x=self.x + offset, height=self.y, width=width)
        return fig_to_use, ax_to_use, self.bar

    def pickle(self, filename='', ext='.pkl'):
        """Save ``(name, x, y)`` to ``filename + ext`` (default: the name)."""
        if filename == '':
            filename = self.name
        # pickle writes bytes: the file must be opened in binary mode.
        with open(filename + ext, 'wb') as output_file:
            pkl.dump((self.name, self.x, self.y), output_file)

    def unpickle(self, filename='', ext='.pkl'):
        """Load ``(name, x, y)`` from ``filename + ext`` (default: the name).

        Only load files you created yourself: unpickling can execute
        arbitrary code.
        """
        if filename == '':
            filename = self.name
        with open(filename + ext, 'rb') as input_file:
            # the name should really come from the filename, but then the
            # above would be confusing?
            self.name, self.x, self.y = pkl.load(input_file)
class Data_set_manager(object):
    """A collection of named Data_set objects drawn side by side as bars."""

    def __init__(self, datasets=None):
        """
        Parameters
        ----------
        datasets : dict, optional
            Mapping name -> Data_set. A given dict is used as-is (not copied);
            by default a fresh dict is created per manager, instead of one
            ``{}`` default shared by every instance.
        """
        self.datasets = {} if datasets is None else datasets

    def add_dataset(self, data_set):
        """Register *data_set* under its ``name``."""
        self.datasets[data_set.name] = data_set

    def add_dataset_from_file(self, filename, ext='.pkl'):
        """Load a pickled data set from ``filename + ext``, keyed by filename."""
        self.datasets[filename] = Data_set(name=filename)
        self.datasets[filename].unpickle(filename=filename, ext=ext)

    def compare(self, width=0, offset=0, *args):
        """Plot the named data sets (all, sorted by name, if none given) as
        grouped bars on a new figure.

        With ``width == 0`` the bar width and offset are derived from the
        smallest x spacing over all stored data sets (1.0 if none has two
        points), so that n groups fit between neighbouring x values.

        Returns
        -------
        (figure, axes), or (None, None) if there is nothing to plot.
        """
        self.fig = plt.figure()
        self.ax = self.fig.subplots(1, 1)
        if len(args) == 0:
            args = sorted(self.datasets)
        n = len(args)
        if n == 0:
            return None, None
        if width == 0:
            min_dx = None
            for dataset in self.datasets.values():
                # Spacing must be measured on sorted positions; otherwise
                # unsorted x could give a negative "minimum spacing".
                sorted_x = np.sort(dataset.x)
                if sorted_x.size < 2:
                    continue  # no spacing defined for 0 or 1 points
                new_min_dx = np.min(np.diff(sorted_x))
                if min_dx is None or new_min_dx < min_dx:
                    min_dx = new_min_dx
            if min_dx is None:
                min_dx = 1.
            width = float(min_dx) / (n + 1)
            offset = float(min_dx) / (n + 1)
        # Offsets centred on 0 for both even and odd n, e.g.
        # n=3 -> [-1, 0, 1] * offset, n=4 -> [-1.5, -0.5, 0.5, 1.5] * offset.
        offsets = offset * (np.arange(n) - (n - 1) / 2.)
        for i, name in enumerate(args):
            self.datasets.get(name, Data_set()).plot(width=width,
                                                     offset=offsets[i],
                                                     figure=self.fig,
                                                     axes=self.ax)
        self.ax.legend(args)
        return self.fig, self.ax
if __name__ == "__main__":
    # test saving/loading
    name = 'test'
    to_pickle = Data_set(x=np.arange(10),
                         y=np.random.rand(10),
                         name=name)
    to_pickle.pickle()
    unpickled = Data_set(pklfile=name)
    # print() as a function works on both Python 2 and 3.
    print(unpickled.name == to_pickle.name)

    # test comparison
    blorg = Data_set_manager({})
    x_step = 1.
    n_bars = 4  # also try an odd number
    for n in range(n_bars):
        blorg.add_dataset(Data_set(x=x_step * np.arange(n_bars),
                                   y=np.random.rand(n_bars),
                                   name='teste' + str(n)))
    fig, ax = blorg.compare()
    # plt.show() blocks until the window is closed; fig.show() returns
    # immediately, so the window would vanish when the script ends.
    plt.show()
It should work with both even and odd numbers of bars:
And as long as you keep a record of the names you've used (tip: look in the folder where you are saving them), you can reload the data and compare it with the new one.
More checks could be made (to make sure the file exists, that the x axis is something that can be subtracted before trying to do so, etc.), and it could also use some documentation and proper testing - but this should do in a hurry.
Matplotlib axes have major and minor ticks. How do I add a third level of ticks below the minor ones?
For example
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.ticker
# Damped sine wave used as sample data.
t = np.arange(0.0, 100.0, 0.1)
s = np.sin(0.1*np.pi*t)*np.exp(-t*0.01)
fig, ax = plt.subplots()
plt.plot(t, s)
# Second axes created with twiny(); it exists only to carry the third tick level.
ax1 = ax.twiny()
ax1.plot(t, s)
# Draw the twin's ticks at the bottom so they overlap the main x axis.
ax1.xaxis.set_ticks_position('bottom')
# Tick positions: every 20 units (major), every 10 (minor), every 1 (third level).
majors = np.linspace(0, 100, 6)
minors = np.linspace(0, 100, 11)
thirds = np.linspace(0, 100, 101)
ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(majors))
ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(minors))
# The twin shows no major ticks; its minor ticks act as the third level.
ax1.xaxis.set_major_locator(matplotlib.ticker.FixedLocator([]))
ax1.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(thirds))
# Decreasing tick lengths distinguish the three levels (6 > 4 > 2).
ax1.tick_params(which='minor', length=2)
ax.tick_params(which='minor', length=4)
ax.tick_params(which='major', length=6)
ax.grid(which='both',axis='x',linestyle='--')
# plt.* draws on pyplot's current axes (presumably ax1 after twiny() -- confirm).
plt.axhline(color='gray')
plt.show()
produces the effect I want using twinned x-axes.
Is there a better way?
As I stated, you can achieve what you want by deriving from some key classes, so I decided to do so (though, as I said, it's probably not worth the effort). Anyway, here is what I've got:
from matplotlib import pyplot as plt
from matplotlib import axes as maxes
from matplotlib import axis as maxis
import matplotlib.ticker as mticker
import matplotlib.cbook as cbook
from matplotlib.projections import register_projection
from matplotlib import ticker
import numpy as np
class SubMinorXAxis(maxis.XAxis):
    """X axis with a third ("subminor") level of ticks below the minor ones.

    The subminor level mirrors matplotlib's major/minor machinery: a
    ``maxis.Ticker`` holding a locator and a formatter, a list of ``XTick``
    instances grown on demand, and tick keyword arguments kept in
    ``_subminor_tick_kw``. It overrides private ``Axis`` hooks
    (``reset_ticks``, ``cla``, ``iter_ticks``, ``set_tick_params``), so it
    only works with matplotlib versions that still use those hooks.
    Subminor ticks never draw grid lines.
    """

    def __init__(self, *args, **kwargs):
        # Created before XAxis.__init__ because the overridden cla() and
        # reset_ticks() below need them (presumably the base initialiser
        # calls those hooks).
        self.subminor = maxis.Ticker()
        self.subminorTicks = []
        self._subminor_tick_kw = dict()
        super(SubMinorXAxis, self).__init__(*args, **kwargs)

    def reset_ticks(self):
        """Discard all subminor ticks but one prototype, then reset the rest."""
        # Empty the list in place so every reference to it sees the change.
        del self.subminorTicks[:]
        self.subminorTicks.append(
            maxis.XTick(self.axes, 0, '', major=False,
                        **self._subminor_tick_kw))
        self._lastNumSubminorTicks = 1
        super(SubMinorXAxis, self).reset_ticks()

    def set_subminor_locator(self, locator):
        """
        Set the locator of the subminor ticker

        ACCEPTS: a :class:`~matplotlib.ticker.Locator` instance
        """
        # Own flag: isDefault_minloc belongs to the *minor* ticker and must
        # not be changed when the subminor locator is set.
        self.isDefault_subminloc = False
        self.subminor.locator = locator
        locator.set_axis(self)
        self.stale = True

    def set_subminor_formatter(self, formatter):
        """
        Set the formatter of the subminor ticker

        ACCEPTS: A :class:`~matplotlib.ticker.Formatter` instance
        """
        # Own flag: isDefault_minfmt belongs to the *minor* ticker.
        self.isDefault_subminfmt = False
        self.subminor.formatter = formatter
        formatter.set_axis(self)
        self.stale = True

    def get_subminor_locator(self):
        """Return the locator of the subminor ticker."""
        return self.subminor.locator

    def get_subminor_formatter(self):
        """Return the formatter of the subminor ticker."""
        return self.subminor.formatter

    def get_subminor_ticks(self, numticks=None):
        """Get the subminor tick instances; grow the list as necessary.

        With *numticks* None, one tick per location of the subminor locator
        is returned.
        """
        if numticks is None:
            numticks = len(self.get_subminor_locator()())
        if len(self.subminorTicks) < numticks:
            # update the new tick label properties from the old
            for _ in range(numticks - len(self.subminorTicks)):
                tick = maxis.XTick(self.axes, 0, '', major=False,
                                   **self._subminor_tick_kw)
                self.subminorTicks.append(tick)

        if self._lastNumSubminorTicks < numticks:
            # Copy the prototype's properties onto newly created ticks; the
            # subminor level never draws grid lines.
            protoTick = self.subminorTicks[0]
            for i in range(self._lastNumSubminorTicks, len(self.subminorTicks)):
                tick = self.subminorTicks[i]
                tick.gridOn = False
                self._copy_tick_props(protoTick, tick)

        self._lastNumSubminorTicks = numticks
        ticks = self.subminorTicks[:numticks]

        return ticks

    def set_tick_params(self, which='major', reset=False, **kwargs):
        """Handle ``which='subminor'`` here; delegate everything else."""
        if which == 'subminor':
            kwtrans = self._translate_tick_kw(kwargs, to_init_kw=True)
            if reset:
                self.reset_ticks()
                self._subminor_tick_kw.clear()
            self._subminor_tick_kw.update(kwtrans)

            for tick in self.subminorTicks:
                tick._apply_params(**self._subminor_tick_kw)
        else:
            super(SubMinorXAxis, self).set_tick_params(which=which,
                                                       reset=reset, **kwargs)

    def cla(self):
        'clear the current axis'
        self.set_subminor_locator(mticker.NullLocator())
        self.set_subminor_formatter(mticker.NullFormatter())

        super(SubMinorXAxis, self).cla()

    def iter_ticks(self):
        """
        Iterate through all of the major and minor ticks.
        ...and through the subminors

        Yields ``(tick, location, label)`` triples, level by level.
        """
        majorLocs = self.major.locator()
        majorTicks = self.get_major_ticks(len(majorLocs))
        self.major.formatter.set_locs(majorLocs)
        majorLabels = [self.major.formatter(val, i)
                       for i, val in enumerate(majorLocs)]

        minorLocs = self.minor.locator()
        minorTicks = self.get_minor_ticks(len(minorLocs))
        self.minor.formatter.set_locs(minorLocs)
        minorLabels = [self.minor.formatter(val, i)
                       for i, val in enumerate(minorLocs)]

        subminorLocs = self.subminor.locator()
        subminorTicks = self.get_subminor_ticks(len(subminorLocs))
        self.subminor.formatter.set_locs(subminorLocs)
        subminorLabels = [self.subminor.formatter(val, i)
                          for i, val in enumerate(subminorLocs)]

        major_minor = [
            (majorTicks, majorLocs, majorLabels),
            (minorTicks, minorLocs, minorLabels),
            (subminorTicks, subminorLocs, subminorLabels),
        ]

        for group in major_minor:
            for tick in zip(*group):
                yield tick
class SubMinorAxes(maxes.Axes):
    """Axes whose x axis is a SubMinorXAxis (selected via
    ``projection='subminor'`` once registered)."""

    name = 'subminor'

    def _init_axis(self):
        # Same wiring as the stock Axes, but with SubMinorXAxis for x.
        # Each axis is attached to its two spines right after creation.
        x_axis = SubMinorXAxis(self)
        self.xaxis = x_axis
        for side in ('top', 'bottom'):
            self.spines[side].register_axis(x_axis)

        y_axis = maxis.YAxis(self)
        self.yaxis = y_axis
        for side in ('left', 'right'):
            self.spines[side].register_axis(y_axis)


# Makes projection='subminor' usable in add_subplot().
register_projection(SubMinorAxes)
if __name__ == '__main__':
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='subminor')

    t = np.arange(0.0, 100.0, 0.1)
    s = np.sin(0.1 * np.pi * t) * np.exp(-t * 0.01)

    # Tick positions: every 20 (major), 10 (minor) and 1 (subminor) units.
    majors = np.linspace(0, 100, 6)
    minors = np.linspace(0, 100, 11)
    thirds = np.linspace(0, 100, 101)

    ax.plot(t, s)
    ax.xaxis.set_ticks_position('bottom')

    ax.xaxis.set_major_locator(ticker.FixedLocator(majors))
    ax.xaxis.set_minor_locator(ticker.FixedLocator(minors))
    ax.xaxis.set_subminor_locator(ticker.FixedLocator(thirds))

    # Only the x axis has a subminor level (the y axis is a plain YAxis),
    # so restrict this call to it.
    # Some settings are not applied correctly by default; for instance
    # 'top=False' must be stated explicitly.
    ax.tick_params(axis='x', which='subminor', length=2, top=False)
    ax.tick_params(which='minor', length=4)
    ax.tick_params(which='major', length=6)
    ax.grid(which='both', axis='x', linestyle='--')

    plt.show()
It's not perfect, but it works fine for the use case you provided. I drew some ideas from this matplotlib example and from going through the source code directly. The result looks like this:
I tested the code on both Python 2.7 and Python 3.5.
EDIT:
I noticed that the subminor gridlines were always drawn when the grid was turned on (whereas I had intended them not to be drawn at all). I fixed this in the code above, so the subminor ticks should never produce grid lines. Implementing subminor gridlines properly would need more work.