Legends are printing twice when calling matplotlib subplots - python

I'm writing code with matplotlib to draw multiple histograms in a subplot grid. However, when I call fig.legend() at the end, the legend entry for each plot is printed twice. Any guidance on how to resolve this issue would be greatly appreciated :)
Here is my code:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('darkgrid')
def get_cmap(n, name='hsv'):
    """Return the colormap ``name`` resampled to ``n`` discrete colors.

    The returned object is callable: ``cmap(i)`` gives the RGBA color for
    index ``i``.
    """
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap takes the same (name, lut) arguments and is still supported.
    return plt.get_cmap(name, n)
def isSqrt(n):
    """Return True when ``n`` is a perfect square (0, 1, 4, 9, ...).

    Negative values are never perfect squares, so they yield False instead of
    failing while converting the NaN square root to an int. The root is
    computed with exact integer arithmetic, so large values are not affected
    by floating-point rounding.
    """
    import math  # local import so the file's import block stays untouched

    if n < 0:
        return False
    root = math.isqrt(int(n))
    # Comparing against the original n keeps non-integral floats (e.g. 2.25) False.
    return root * root == n
df = pd.read_csv('mpg.csv')
df2 = pd.read_csv('dm_office_sales.csv')
df['miles'] = df2['salary']
numericClassifier = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
newdf = df.select_dtypes(numericClassifier)
columns = newdf.columns.tolist()
n = len(columns)
cmap = get_cmap(n)
if(isSqrt(n)):
nrows = ncols = int(np.sqrt(n))
else:
ncols = int(np.sqrt(n))
for i in range(ncols,50):
if ncols*i >= n:
nrows = i
break
else:
pass
fig,ax = plt.subplots(nrows,ncols)
count = 0
print(nrows,ncols)
for i in range(0,nrows,1):
for j in range(0,ncols,1):
print('ncols = {}'.format(j),'nrows = {}'.format(i),'count = {}'.format(count))
if count<=n-1:
plt_new = sns.histplot(df[columns[count]],ax=ax[i,j],facecolor=cmap(count),kde=True,edgecolor='black',label=df[columns[count]].name)
patches = plt_new.get_children()
for patch in patches:
patch.set_alpha(0.8)
color = patches[0].get_facecolor()
ax[i,j].set_xlabel('{}'.format(df[columns[count]].name))
ax[i,j].xaxis.label.set_fontsize(10)
ax[i,j].xaxis.label.set_fontname('ariel')
ax[i,j].set(xlabel=None)
ax[i,j].tick_params(axis='y', labelsize=8)
count+=1
else:
break
for i in range(0,nrows,1):
for j in range(0,ncols,1):
if not ax[i,j].has_data():
fig.delaxes(ax[i,j])
else:
pass
plt.suptitle('Histograms').set_fontname('ariel')
plt.tight_layout()
fig.legend(loc='upper right')
plt.show()
Here is the output:

sns.histplot seems to create two bar containers. First a dummy one, and then the real one. (Tested with seaborn 0.12.1; this might work different in other versions.) Therefore, the label gets assigned to both the dummy and the real bar container. A workaround would be to remove the label of the dummy bar container.
Here is the adapted code. Seaborn's mpg dataset is used to have an easily reproducible example. As the first and last colors of the hsv colormap are both red, get_cmap(n + 1) ensures n different colors are chosen. Some superfluous code has been removed.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
def get_cmap(n, name='hsv'):
    """Return the colormap ``name`` resampled to ``n`` discrete colors.

    The returned object is callable: ``cmap(i)`` gives the RGBA color for
    index ``i``.
    """
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap takes the same (name, lut) arguments and is still supported.
    return plt.get_cmap(name, n)
sns.set_style('darkgrid')
df = sns.load_dataset('mpg')
numericClassifier = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
newdf = df.select_dtypes(numericClassifier)
columns = newdf.columns.tolist()
n = len(columns)
cmap = get_cmap(n + 1)
ncols = int(np.sqrt(n))
nrows = int(np.ceil(n / ncols))
fig, ax = plt.subplots(nrows, ncols)
count = 0
print(nrows, ncols)
for i in range(0, nrows):
for j in range(0, ncols):
if count < n:
# print('ncols = {j}; nrows = {i}; count = {count}')
sns.histplot(df[columns[count]], ax=ax[i, j], facecolor=cmap(count), kde=True, edgecolor='black',
label=df[columns[count]].name)
ax[i, j].containers[0].set_label('') # seaborn seems to create a dummy bar container, remove its label
for patch in ax[i, j].get_children():
patch.set_alpha(0.8)
ax[i, j].tick_params(axis='y', labelsize=8)
count += 1
for i in range(0, nrows):
for j in range(0, ncols):
if not ax[i, j].has_data():
fig.delaxes(ax[i, j])
plt.suptitle('Histograms').set_fontname('ariel')
fig.legend(loc='upper right')
plt.tight_layout()
plt.subplots_adjust(right=0.75) # make extra space for the legend
plt.show()
Upon further investigation, it seems the dummy bar container isn't created when sns.histplot is called with color= instead of facecolor=.
The code could also be written in a more "pythonic" way. Among other things, this means avoiding repeated code and explicit indices. To achieve this, zip is an important helper. Besides avoiding repetition, the code becomes shorter and easier to modify. Once you get used to it, it also becomes easier to read and to reason about.
The main part could look like e.g.:
fig, axs = plt.subplots(nrows=nrows, ncols=ncols)
for column, ax, color in zip(columns, axs.flat, cmap(range(n))):
# using `color=` instead of `facecolor=` seems to avoid the creating of dummy bars
sns.histplot(df[column], ax=ax, color=color, kde=True, edgecolor='black', label=column)
for patch in ax.get_children():
patch.set_alpha(0.8)
ax.tick_params(axis='y', labelsize=8)
for ax in axs.flat:
if not ax.has_data():
fig.delaxes(ax)

Related

Seaborn boxplot : set median color and set tick label colors to boxes color

I'm using this nice boxplot graph, taken from an answer by @Parfait.
I got an out of bound error on j and had to use range(i*5,i*5+5). Why?
I'd like to set the median to a particular color, let's say red. medianprops=dict(color="red") won't work. How to do it?
How to set the y-axis tick labels to the same color as the boxes?
Disclaimer: I don't know what I'm doing.
Here's the code using random data :
# import the required library
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import string
import matplotlib.colors as mc
import colorsys
# data
df = pd.DataFrame(np.random.normal(np.random.randint(5,15),np.random.randint(1,5),size=(100, 16)), columns=list(string.ascii_uppercase)[:16])
# Boxplot
fig, ax = plt.subplots(figsize=(9, 10))
medianprops=dict(color="red")
ax = sns.boxplot(data=df, orient="h", showfliers=False, palette = "husl")
ax = sns.stripplot(data=df, orient="h", jitter=True, size=7, alpha=0.5, palette = "husl") # show data points
ax.set_title("Title")
plt.xlabel("X label")
def lighten_color(color, amount=0.5):
# --------------------- SOURCE: #IanHincks ---------------------
try:
c = mc.cnames[color]
except:
c = color
c = colorsys.rgb_to_hls(*mc.to_rgb(c))
return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])
for i,artist in enumerate(ax.artists):
# Set the linecolor on the artist to the facecolor, and set the facecolor to None
col = lighten_color(artist.get_facecolor(), 1.2)
artist.set_edgecolor(col)
# Each box has 6 associated Line2D objects (to make the whiskers, fliers, etc.)
# Loop over them here, and use the same colour as above
for j in range(i*5,i*5+5):
line = ax.lines[j]
line.set_color(col)
line.set_mfc(col)
line.set_mec(col)
#line.set_linewidth(0.5)
To change the color of the median, you can use the medianprops in sns.boxplot(..., medianprops=...). If you also set a unique label, that label can be tested again when iterating through the lines.
To know how many lines belong to each boxplot, you can divide the number of lines by the number of artists (just after the boxplot has been created, before other elements have been added to the plot). Note that a line potentially has 3 colors: the line color, the marker face color and the marker edge color. Matplotlib creates the fliers as an invisible line with markers. The code below thus also changes these colors to make it more robust to different options and possible future changes.
Looping simultaneously through the boxes and the y tick labels allows copying the color. Making them a bit larger and darker helps for readability.
import matplotlib.pyplot as plt
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb, to_rgb
import seaborn as sns
import pandas as pd
import numpy as np
def enlighten(color, factor=0.5):
h, s, v = rgb_to_hsv(to_rgb(color))
return hsv_to_rgb((h, s, 1 - factor * (1 - v)))
def endarken(color, factor=0.5):
h, s, v = rgb_to_hsv(to_rgb(color))
return hsv_to_rgb((h, s, factor * v))
df = pd.DataFrame(np.random.normal(1, 5, size=(100, 16)).cumsum(axis=0),
columns=['Hydrogen', 'Helium', 'Lithium', 'Beryllium', 'Boron', 'Carbon', 'Nitrogen', 'Oxygen',
'Fluorine', 'Neon', 'Sodium', 'Magnesium', 'Aluminum', 'Silicon', 'Phosphorus', 'Sulfur'])
sns.set_style('white')
fig, ax = plt.subplots(figsize=(9, 10))
colors = sns.color_palette("husl", len(df.columns))
sns.boxplot(data=df, orient="h", showfliers=False, palette='husl',
medianprops=dict(color="yellow", label='median'), ax=ax)
lines_per_boxplot = len(ax.lines) // len(ax.artists)
for i, (box, ytick) in enumerate(zip(ax.artists, ax.get_yticklabels())):
ytick.set_color(endarken(box.get_facecolor()))
ytick.set_fontsize(20)
color = enlighten(box.get_facecolor())
box.set_color(color)
for lin in ax.lines[i * lines_per_boxplot: (i + 1) * lines_per_boxplot]:
if lin.get_label() != 'median':
lin.set_color(color)
lin.set_markerfacecolor(color)
lin.set_markeredgecolor(color)
sns.stripplot(data=df, orient="h", jitter=True, size=7, alpha=0.5, palette='husl', ax=ax)
sns.despine(ax=ax)
ax.set_title("Title")
ax.set_xlabel("X label")
plt.tight_layout()
plt.show()
I'll just answer point 2 of my own question.
After tinkering, I found this to work :
# Each box has 5 associated Line2D objects (the whiskers and median)
# Loop over them here, and use the same colour as above
n=5 # this was for tinkering
for j in range(i*n,i*n+n):
if j != i*n+4 : line = ax.lines[j] # not the median
line.set_color(col)
Again, I don't know what I'm doing. So someone more knowledgeable may provide a more valuable answer.
I removed the stripplot for better clarity.

Matplotlib how to move axis along data in a real-time animation

I'm trying to plot data that is generated in runtime. In order to do so I'm using matplotlib.animation.FuncAnimation.
While the data is displayed correctly, the axis values are not updating accordingly to the values that are being displayed:
The x axis displays values from 0 to 10 even though I update them in every iteration in the update_line function (see code below).
DataSource contains the data vector and appends values at runtime, and also returns the indexes of the values being returned:
import numpy as np
class DataSource:
    """Rolling source of random samples for the animation.

    Each call to ``getData`` appends one random number to the history and
    returns the most recent ``display`` values; ``getIndexVector`` returns the
    positions of those values within the full history.
    """

    display = 10  # how many of the most recent samples are exposed

    def __init__(self):
        # Per-instance history: a class-level list would be shared (and grown)
        # by every DataSource instance.
        self.data = []

    # Append one random number and return last 10 values
    def getData(self):
        self.data.append(np.random.rand(1)[0])
        # A negative slice on a short list simply returns the whole list, so
        # no length check is needed.
        return self.data[-self.display:]

    # Return the index of the last 10 values
    def getIndexVector(self):
        # Slicing a range is O(1); only the visible window is materialized.
        return list(range(len(self.data))[-self.display:])
I've obtained the plot_animation function from the matplotlib docs.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from datasource import DataSource
def update_line(num, source, line):
data = source.getData()
indexs = source.getIndexVector()
if indexs[0] != 0:
plt.xlim(indexs[0], indexs[-1])
dim=np.arange(indexs[0],indexs[-1],1)
plt.xticks(dim)
line.set_data(indexs,data)
return line,
def plot_animation():
fig1 = plt.figure()
source = DataSource()
l, = plt.plot([], [], 'r-')
plt.xlim(0, 10)
plt.ylim(0, 1)
plt.xlabel('x')
plt.title('test')
line_ani = animation.FuncAnimation(fig1, update_line, fargs=(source, l),
interval=150, blit=True)
# To save the animation, use the command: line_ani.save('lines.mp4')
plt.show()
if __name__ == "__main__":
plot_animation()
How can I update the x axis values in every iteration of the animation?
(I appreciate suggestions to improve the code if you see any mistakes, even though they might not be related to the question.)
Here is a simple case of how you can achieve this.
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
%matplotlib notebook
#data generator
data = np.random.random((100,))
#setup figure
fig = plt.figure(figsize=(5,4))
ax = fig.add_subplot(1,1,1)
#rolling window size
repeat_length = 25
ax.set_xlim([0,repeat_length])
ax.set_ylim([-2,2])
#set figure to be modified
im, = ax.plot([], [])
def func(n):
    """Draw the first ``n`` samples and slide the x-axis window along with them.

    Called by FuncAnimation once per frame with the frame number ``n``.
    Returns the modified artists as a tuple, which is the form FuncAnimation
    expects when blitting is enabled (the return value is ignored otherwise).
    """
    im.set_xdata(np.arange(n))
    im.set_ydata(data[0:n])
    if n>repeat_length:
        # Past the initial window: keep the newest repeat_length samples in view.
        ax.set_xlim(n-repeat_length, n)
    else:
        ax.set_xlim(0,repeat_length)
    return im,
ani = animation.FuncAnimation(fig, func, frames=data.shape[0], interval=30, blit=False)
plt.show()
#ani.save('animation.gif',writer='pillow', fps=30)
Solution
My problem was in the following line:
line_ani = animation.FuncAnimation(fig1, update_line, fargs=(source, l),
interval=150, blit=True)
What I had to do is change the blit parameter to False and the x axis started to move as desired.

python matplotlib shared xlabel description / title on multiple subplots for animation

I'm using the following code to produce an animation with matplotlib that is intended to visualize my experiments.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation, PillowWriter
plt.rcParams['animation.html'] = 'jshtml'
def make_grid(X, description=None, labels=None, title_fmt="label: {}", cmap='gray', ncols=3, colors=None):
L = len(X)
nrows = -(-L // ncols)
frame_plot = []
for i in range(L):
plt.subplot(nrows, ncols, i + 1)
im = plt.imshow(X[i].squeeze(), cmap=cmap, interpolation='none')
if labels is not None:
color = 'k' if colors is None else colors[i]
plt.title(title_fmt.format(labels[i]), color=color)
plt.xticks([])
plt.yticks([])
frame_plot.append(im)
return frame_plot
def animate_step(X):
return X ** 2
n_splots = 6
X = np.random.random((n_splots,32,32,3))
Y = X
X_t = []
for i in range(10):
Y = animate_step(Y)
X_t.append((Y, i))
frames = []
for X, step in X_t:
frame = make_grid(X,
description="step={}".format(step),
labels=range(n_splots),
title_fmt="target: {}")
frames.append(frame)
anim = ArtistAnimation(plt.gcf(), frames,
interval=300, repeat_delay=8000, blit=True)
plt.close()
anim.save("test.gif", writer=PillowWriter())
anim
The result can be seen here:
https://i.stack.imgur.com/OaOsf.gif
It works fine so far, but I'm having trouble getting a shared xlabel to add a description for all of the 6 subplots in the animation. It is supposed to show what step the image is on, i.e. "step=5".
Since it is an animation, I cannot use xlabel or set_title (since it would be constant over the whole animation) and have to draw the text myself.
I've tried something along the lines of..
def make_grid(X, description=None, labels=None, title_fmt="label: {}", cmap='gray', ncols=3, colors=None):
L = len(X)
nrows = -(-L // ncols)
frame_plot = []
desc = plt.text(0.5, .04, description,
size=plt.rcparams["axes.titlesize"],
ha="center",
transform=plt.gca().transAxes
)
frame_plot.append(desc)
...
This, of course, won't work, because the axes are not yet created. I tried using the axis of another subplot(nrows, 1, nrows), but then the existing images are drawn over..
Does anyone have a solution to this?
Edit:
unclean, hacky solution for now:
Wait for the axes of the middle image of the last row to be created and use that for plotting the text.
In the for loop:
...
if i == int((nrows - 0.5) * ncols):
title = ax.text(0.25, -.3, description,
size=plt.rcParams["axes.titlesize"],
# ha="center",
transform=ax.transAxes
)
frame_plot.append(title)
...
To me, your case is easier to solve with FuncAnimation instead of ArtistAnimation, even if you already have access to the full list of data you want to show animated (see this thread for a discussion about the difference between the two functions).
Inspired from this FuncAnimation example, I wrote the code below that does what you needed (using the same code with ArtistAnimation and correct list of arguments does not work).
The main idea is to initialize all elements to be animated at the beginning, and to update them over the animation frames. This can be done for the text object (step_txt = fig.text(...)) in charge of displaying the current step, and for the images out from ax.imshow. You can then update whatever object you would like to see animated with this recipe.
Note that the technique works if you want the text to be an x_label or any text you choose to show. See the commented line in the code.
#!/Users/seydoux/anaconda3/envs/jupyter/bin/python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
# parameters
n_frames = 10
n_splots = 6
n_cols = 3
n_rows = n_splots // n_cols
def update_data(x):
return x ** 2
# create all snapshots
snapshots = [np.random.rand(n_splots, 32, 32, 3)]
for _ in range(n_frames):
snapshots.append(update_data(snapshots[-1]))
# initialize figure and static elements
# Build the grid from the n_rows / n_cols parameters defined above instead of
# hard-coded 2 x 3, so changing n_splots or n_cols resizes the figure.
fig, axes = plt.subplots(n_rows, n_cols)
axes = axes.ravel() # so we can access all axes with a single index
for i, ax in enumerate(axes):
    # Static decoration: no ticks, fixed per-panel title.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("target: {}".format(i))
# initialize elements to be animated
step_txt = fig.text(0.5, 0.95, "step: 0", ha="center", weight="bold")
# step_txt = axes[4].set_xlabel("step: 0") # also works with x_label
imgs = list()
for a, s in zip(axes, snapshots[0]):
imgs.append(a.imshow(s, interpolation="none", cmap="gray"))
# animation function
def animate(i):
# update images
for img, s in zip(imgs, snapshots[i]):
img.set_data(s)
# update text
step_txt.set_text("step: {}".format(i))
# etc
anim = FuncAnimation(fig, animate, frames=n_frames, interval=300)
anim.save("test.gif", writer=PillowWriter())
Here is the output I got from the above code:

Plot multiple Y axes

I know pandas supports a secondary Y axis, but I'm curious if anyone knows a way to put a tertiary Y axis on plots. Currently I am achieving this with numpy+pyplot, but it is slow with large data sets.
This is to plot different measurements with distinct units on the same graph for easy comparison (eg: Relative Humidity/Temperature/ and Electrical Conductivity).
So really just curious if anyone knows if this is possible in pandas without too much work.
[Edit] I doubt that there is a way to do this(without too much overhead) however I hope to be proven wrong, as this may be a limitation of matplotlib.
I think this might work:
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame
df = DataFrame(np.random.randn(5, 3), columns=['A', 'B', 'C'])
fig, ax = plt.subplots()
ax3 = ax.twinx()
rspine = ax3.spines['right']
rspine.set_position(('axes', 1.15))
ax3.set_frame_on(True)
ax3.patch.set_visible(False)
fig.subplots_adjust(right=0.7)
df.A.plot(ax=ax, style='b-')
# same ax as above since it's automatically added on the right
df.B.plot(ax=ax, style='r-', secondary_y=True)
df.C.plot(ax=ax3, style='g-')
# add legend --> take advantage of pandas providing us access
# to the line associated with the right part of the axis
ax3.legend([ax.get_lines()[0], ax.right_ax.get_lines()[0], ax3.get_lines()[0]],\
['A','B','C'], bbox_to_anchor=(1.5, 0.5))
Output:
A simpler solution without plt:
ax1 = df1.plot()
ax2 = ax1.twinx()
ax2.spines['right'].set_position(('axes', 1.0))
df2.plot(ax=ax2)
ax3 = ax1.twinx()
ax3.spines['right'].set_position(('axes', 1.1))
df3.plot(ax=ax3)
....
Using function to achieve this:
def plot_multi(data, cols=None, spacing=.1, **kwargs):
    """Plot several columns of ``data``, each against its own y-axis.

    The first column is drawn on the primary axis; every further column gets
    a twinned axis whose right spine is shifted outwards by ``spacing``.
    ``cols`` defaults to all columns of ``data``; extra keyword arguments are
    forwarded to each ``plot`` call. Returns the primary axis, or None when
    there are no columns to plot.
    """
    from pandas.plotting._matplotlib.style import get_standard_colors

    # Get default color style from pandas - can be changed to any other color list
    if cols is None:
        cols = data.columns
    if len(cols) == 0:
        return
    palette = get_standard_colors(num_colors=len(cols))

    # First axis
    first = cols[0]
    ax = data.loc[:, first].plot(label=first, color=palette[0], **kwargs)
    ax.set_ylabel(ylabel=first)
    lines, labels = ax.get_legend_handles_labels()

    for position, name in enumerate(cols[1:], start=1):
        # Multiple y-axes: each twin's right spine sits further out than the last
        twin = ax.twinx()
        offset = 1 + spacing * (position - 1)
        twin.spines['right'].set_position(('axes', offset))
        data.loc[:, name].plot(
            ax=twin, label=name, color=palette[position % len(palette)], **kwargs
        )
        twin.set_ylabel(ylabel=name)

        # Collect handles so a single combined legend can be drawn
        twin_lines, twin_labels = twin.get_legend_handles_labels()
        lines.extend(twin_lines)
        labels.extend(twin_labels)

    ax.legend(lines, labels, loc=0)
    return ax
Example:
from random import randrange
data = pd.DataFrame(dict(
s1=[randrange(-1000, 1000) for _ in range(100)],
s2=[randrange(-100, 100) for _ in range(100)],
s3=[randrange(-10, 10) for _ in range(100)],
))
plot_multi(data.cumsum(), figsize=(10, 5))
Output:
I modified the above answer a bit to make it accept custom x column, well-documented, and more flexible.
You can copy this snippet and use it as a function:
from typing import List, Union
import matplotlib.axes
import pandas as pd
def plot_multi(
    data: pd.DataFrame,
    x: Union[str, None] = None,
    y: Union[List[str], None] = None,
    spacing: float = 0.1,
    **kwargs
) -> Union[matplotlib.axes.Axes, None]:
    """Plot multiple Y axes on the same chart with same x axis.

    Args:
        data: dataframe which contains x and y columns
        x: column to use as x axis. If None, use index.
        y: list of columns to use as Y axes. If None, all columns are used
            except x column.
        spacing: spacing between the plots
        **kwargs: keyword arguments to pass to data.plot()

    Returns:
        a matplotlib.axes.Axes object returned from data.plot(), or None when
        there is no y column left to plot.

    Example:
    >>> plot_multi(df, figsize=(22, 10))
    >>> plot_multi(df, x='time', figsize=(22, 10))
    >>> plot_multi(df, y='price qty value'.split(), figsize=(22, 10))
    >>> plot_multi(df, x='time', y='price qty value'.split(), figsize=(22, 10))
    >>> plot_multi(df['time price qty'.split()], x='time', figsize=(22, 10))

    See Also:
        This code is mentioned in https://stackoverflow.com/q/11640243/2593810
    """
    from pandas.plotting._matplotlib.style import get_standard_colors

    # Get default color style from pandas - can be changed to any other color list
    if y is None:
        y = data.columns

    # remove x_col from y_cols
    if x:
        y = [col for col in y if col != x]

    if len(y) == 0:
        return None
    colors = get_standard_colors(num_colors=len(y))

    # prevent multiple legends unless the caller asked for them explicitly
    kwargs.setdefault("legend", False)

    # First axis
    ax = data.plot(x=x, y=y[0], color=colors[0], **kwargs)
    ax.set_ylabel(ylabel=y[0])
    lines, labels = ax.get_legend_handles_labels()

    for i in range(1, len(y)):
        # Multiple y-axes: shift each additional right spine further outwards
        ax_new = ax.twinx()
        ax_new.spines["right"].set_position(("axes", 1 + spacing * (i - 1)))
        data.plot(
            ax=ax_new, x=x, y=y[i], color=colors[i % len(colors)], **kwargs
        )
        ax_new.set_ylabel(ylabel=y[i])

        # Proper legend position: gather handles for one combined legend
        line, label = ax_new.get_legend_handles_labels()
        lines += line
        labels += label

    ax.legend(lines, labels, loc=0)
    return ax
Here's one way to use it:
plot_multi(df, x='time', y='price qty value'.split(), figsize=(22, 10))

Multiple figures in a single window

I want to create a function which plot on screen a set of figures in a single window. By now I write this code:
import pylab as pl
def plot_figures(figures):
"""Plot a dictionary of figures.
Parameters
----------
figures : <title, figure> dictionary
"""
for title in figures:
pl.figure()
pl.imshow(figures[title])
pl.gray()
pl.title(title)
pl.axis('off')
It works perfectly but I would like to have the option for plotting all the figures in single window. And this code doesn't. I read something about subplot but it looks quite tricky.
You can define a function based on the subplots command (note the s at the end, different from the subplot command pointed by urinieto) of matplotlib.pyplot.
Below is an example of such a function, based on yours, allowing to plot multiples axes in a figure. You can define the number of rows and columns you want in the figure layout.
def plot_figures(figures, nrows = 1, ncols=1):
    """Plot a dictionary of figures.

    Parameters
    ----------
    figures : <title, figure> dictionary
    ncols : number of columns of subplots wanted in the display
    nrows : number of rows of subplots wanted in the figure

    Raises
    ------
    ValueError
        If the nrows x ncols grid has fewer cells than there are figures.
    """
    # Fail before creating a figure rather than part-way through plotting.
    if len(figures) > nrows * ncols:
        raise ValueError(
            "Too few subplots ({}) specified for {} figures.".format(
                nrows * ncols, len(figures)))

    # squeeze=False always yields a 2-D array of axes; without it the default
    # 1x1 layout returns a bare Axes, which has no ravel().
    fig, axeslist = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False)
    axes = axeslist.ravel()
    for ax, (title, image) in zip(axes, figures.items()):
        # plt.gray() returns None, so name the colormap explicitly instead.
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.set_axis_off()
    # Hide grid cells that received no image.
    for ax in axes[len(figures):]:
        ax.set_axis_off()
    plt.tight_layout() # optional
Basically, the function creates a number of axes in the figures, according to the number of rows (nrows) and columns (ncols) you want, and then iterates over the list of axis to plot your images and adds the title for each of them.
Note that if you only have one image in your dictionary, your previous syntax plot_figures(figures) will work since nrows and ncols are set to 1 by default.
An example of what you can obtain:
import matplotlib.pyplot as plt
import numpy as np
# generation of a dictionary of (title, images)
number_of_im = 6
figures = {'im'+str(i): np.random.randn(100, 100) for i in range(number_of_im)}
# plot of the images in a figure, with 2 rows and 3 columns
plot_figures(figures, 2, 3)
You should use subplot.
In your case, it would be something like this (if you want them one on top of the other):
# Stack the images vertically: one subplot row per entry of `figures`.
fig = pl.figure(1)
for k, title in enumerate(figures, start=1):
    ax = fig.add_subplot(len(figures),1,k)
    # Axes objects have no gray() method and ax.title is a Text object, not a
    # function: pass the colormap to imshow and use set_title instead.
    ax.imshow(figures[title], cmap='gray')
    ax.set_title(title)
    ax.axis('off')
Check out the documentation for other options.
If you want to group multiple figures in one window, you can do something like this:
import matplotlib.pyplot as plt
import numpy as np
img = plt.imread('C:/.../Download.jpg') # Path to image
img = img[0:150,50:200,0] # Crop a 150x150 region and keep the first channel
fig = plt.figure()
nrows = 10 # Number of tile rows
ncols = 10 # Number of tile columns
# Size of one tile, taken from the cropped image so the two cannot disagree
image_height, image_width = img.shape
pixels = np.zeros((nrows*image_height, ncols*image_width)) # Canvas holding all tiles
for a in range(nrows):
    for b in range(ncols):
        top = a*image_height
        left = b*image_width  # columns advance by the tile width, not its height
        pixels[top:top+image_height, left:left+image_width] = img
plt.imshow(pixels,cmap='jet')
plt.axis('off')
plt.show()
As a result you get:
Building on the answer from: How to display multiple images in one figure correctly?, here is another method:
import math
import numpy as np
import matplotlib.pyplot as plt
def plot_images(np_images, titles=(), columns=5, figure_size=(24, 18)):
    """Show every image of ``np_images`` in one figure laid out as a grid.

    Args:
        np_images: array whose first axis indexes the images.
        titles: optional sequence of titles; image ``i`` gets ``titles[i]``.
            Images beyond the end of ``titles`` are left untitled.
        columns: number of grid columns; rows are added as needed.
        figure_size: figure size passed to ``plt.figure`` as ``figsize``.
    """
    count = np_images.shape[0]
    rows = math.ceil(count / columns)

    fig = plt.figure(figsize=figure_size)
    for index in range(count):
        ax = fig.add_subplot(rows, columns, index + 1)
        # Guard against a titles sequence shorter than the image stack.
        if index < len(titles):
            ax.set_title(str(titles[index]))
        # Draw on the axes just created rather than the implicit current axes.
        ax.imshow(np_images[index])

    plt.show()
You can also do this:
import matplotlib.pyplot as plt
# squeeze=False keeps the result an array even for a single image; flatten it
# so axarr can still be indexed with one index.
f, axarr = plt.subplots(1, len(imgs), squeeze=False)
axarr = axarr.ravel()
for i, img in enumerate(imgs):
    axarr[i].imshow(img)
plt.suptitle("Your title!")
plt.show()
def plot_figures(figures, nrows=None, ncols=None):
    """Show every image in ``figures`` (title -> image) in a single window.

    When both ``nrows`` and ``ncols`` are given they define the subplot grid,
    which must have at least ``len(figures)`` cells (ValueError otherwise).
    If either is missing, all images are placed in one row.
    """
    grid_given = bool(nrows) and bool(ncols)
    if grid_given:
        # check minimum grid configured
        if nrows * ncols < len(figures):
            raise ValueError(f"Too few subplots ({nrows*ncols}) specified for ({len(figures)}) figures.")
    else:
        # Plot figures in a single row if grid not specified
        nrows, ncols = 1, len(figures)

    fig = plt.figure()

    # optional spacing between figures
    fig.subplots_adjust(hspace=0.4, wspace=0.4)

    for position, title in enumerate(figures, start=1):
        plt.subplot(nrows, ncols, position)
        plt.title(title)
        plt.imshow(figures[title])
    plt.show()
Any grid configuration (or none) can be specified as long as the product of the number of rows and the number of columns is equal to or greater than the number of figures.
For example, for len(figures) == 10, these are acceptable
plot_figures(figures)
plot_figures(figures, 2, 5)
plot_figures(figures, 3, 4)
plot_figures(figures, 4, 3)
plot_figures(figures, 5, 2)
import numpy as np
def save_image(data, ws=0.1, hs=0.1, sn='save_name'):
    """Save all images in ``data`` as one grid figure named ``<sn>.png``.

    Args:
        data: array read as ``data[i, 0, :, :]`` for image ``i``,
            e.g. shape (36, 1, 32, 32).
        ws: horizontal spacing between subplots (``wspace``).
        hs: vertical spacing between subplots (``hspace``).
        sn: output file name without the ``.png`` extension.
    """
    import matplotlib.pyplot as plt

    count = data.shape[0]
    # Near-square grid with enough cells for any count; the original
    # int(sqrt(count)) grid was too small unless count was a perfect square.
    ncols = max(1, int(np.ceil(np.sqrt(count))))
    nrows = max(1, int(np.ceil(count / ncols)))

    # figsize is (width, height); squeeze=False keeps a 1x1 grid an array.
    fig, ax = plt.subplots(nrows, ncols, figsize=(ncols*6, nrows*6), squeeze=False)
    ax = ax.ravel()
    for i in range(count):
        ax[i].matshow(data[i,0,:,:])
        ax[i].set_xticks([])
        ax[i].set_yticks([])
    # Hide grid cells that have no image.
    for unused in ax[count:]:
        unused.set_axis_off()
    # No tight_layout() afterwards: it recomputes the layout and would discard
    # the ws / hs spacing requested here.
    fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9,
                        top=0.9, wspace=ws, hspace=hs)
    fig.savefig('{}.png'.format(sn))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
data = np.load('img_test.npy')
save_image(data, ws=0.1, hs=0.1, sn='multiple_plot')

Categories