Plotting images side by side using matplotlib - python

I was wondering how I am able to plot images side by side using matplotlib for example something like this:
The closest I got is this:
This was produced by using this code:
f, axarr = plt.subplots(2,2)
axarr[0,0] = plt.imshow(image_datas[0])
axarr[0,1] = plt.imshow(image_datas[1])
axarr[1,0] = plt.imshow(image_datas[2])
axarr[1,1] = plt.imshow(image_datas[3])
But I can't seem to get the other images to show. I'm thinking that there must be a better way to do this, as I would imagine trying to manage the indexes would be a pain. I have looked through the documentation, although I have a feeling I may be looking at the wrong one. Would anyone be able to provide me with an example or point me in the right direction?
EDIT:
See the answer from #duhaime if you want a function to automatically determine the grid size.

The problem you face is that you try to assign the return value of imshow (which is a matplotlib.image.AxesImage) to an existing axes object.
The correct way of plotting image data to the different axes in axarr would be
f, axarr = plt.subplots(2,2)
axarr[0,0].imshow(image_datas[0])
axarr[0,1].imshow(image_datas[1])
axarr[1,0].imshow(image_datas[2])
axarr[1,1].imshow(image_datas[3])
The concept is the same for all subplots, and in most cases the axes instance provides the same methods as the pyplot (plt) interface.
E.g. if ax is one of your subplot axes, for plotting a normal line plot you'd use ax.plot(..) instead of plt.plot(). This can actually be found exactly in the source from the page you link to.

One thing that I found quite helpful to use to print all images :
_, axs = plt.subplots(n_row, n_col, figsize=(12, 12))
axs = axs.flatten()
for img, ax in zip(imgs, axs):
ax.imshow(img)
plt.show()

You are plotting all your images on one axis. What you want is to get a handle for each axis individually and plot your images there. Like so:
fig = plt.figure()
ax1 = fig.add_subplot(2,2,1)
ax1.imshow(...)
ax2 = fig.add_subplot(2,2,2)
ax2.imshow(...)
ax3 = fig.add_subplot(2,2,3)
ax3.imshow(...)
ax4 = fig.add_subplot(2,2,4)
ax4.imshow(...)
For more info have a look here: http://matplotlib.org/examples/pylab_examples/subplots_demo.html
For complex layouts, you should consider using gridspec: http://matplotlib.org/users/gridspec.html

If the images are in an array and you want to iterate through each element and print it, you can write the code as follows:
plt.figure(figsize=(10,10)) # specifying the overall grid size
for i in range(25):
plt.subplot(5,5,i+1) # the number of images in the grid is 5*5 (25)
plt.imshow(the_array[i])
plt.show()
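Also note that I used subplot and not subplots. They are different functions: plt.subplot(nrows, ncols, index) adds (or selects) a single axes in the current figure, whereas plt.subplots(nrows, ncols) creates a new figure together with the whole grid of axes.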

Below is a complete function show_image_list() that displays images side-by-side in a grid. You can invoke the function with different arguments.
Pass in a list of images, where each image is a Numpy array. It will create a grid with 2 columns by default. It will also infer if each image is color or grayscale.
list_images = [img, gradx, grady, mag_binary, dir_binary]
show_image_list(list_images, figsize=(10, 10))
Pass in a list of images, a list of titles for each image, and other arguments.
show_image_list(list_images=[img, gradx, grady, mag_binary, dir_binary],
list_titles=['original', 'gradx', 'grady', 'mag_binary', 'dir_binary'],
num_cols=3,
figsize=(20, 10),
grid=False,
title_fontsize=20)
Here's the code:
import matplotlib.pyplot as plt
import numpy as np
def img_is_color(img):
    '''
    Tell whether an image array carries actual color information.

    Parameters:
    ----------
    img: numpy.ndarray
        Image data, either 2-D (height, width) or 3-D (height, width, channels).

    Returns:
    -------
    bool
        True only for a 3-D array with at least three channels whose first
        three channels are not all identical. 2-D arrays, arrays with fewer
        than three channels, and arrays whose channels are all equal (a
        grayscale image stored as RGB) return False.
    '''
    if img.ndim != 3 or img.shape[2] < 3:
        return False
    c1, c2, c3 = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    # Identical channels mean every pixel is a shade of gray, i.e. NOT color.
    channels_identical = (c1 == c2).all() and (c2 == c3).all()
    return not channels_identical
def show_image_list(list_images, list_titles=None, list_cmaps=None, grid=True, num_cols=2, figsize=(20, 10), title_fontsize=30):
    '''
    Shows a grid of images, where each image is a Numpy array. The images can be either
    RGB or grayscale.

    Parameters:
    ----------
    list_images: list
        List of the images to be displayed (must be non-empty).
    list_titles: list or None
        Optional list of titles to be shown for each image. If None, each
        image is titled 'Image <index>'.
    list_cmaps: list or None
        Optional list of cmap values for each image. If None, then cmap will be
        automatically inferred via img_is_color(): None for color images,
        'gray' otherwise.
    grid: boolean
        If True, show a grid over each image
    num_cols: int
        Maximum number of columns to show; capped at the number of images.
    figsize: tuple of width, height
        Value to be passed to pyplot.subplots()
    title_fontsize: int
        Value to be passed to set_title().
    '''
    assert isinstance(list_images, list)
    assert len(list_images) > 0
    assert isinstance(list_images[0], np.ndarray)
    if list_titles is not None:
        assert isinstance(list_titles, list)
        assert len(list_images) == len(list_titles), '%d imgs != %d titles' % (len(list_images), len(list_titles))
    if list_cmaps is not None:
        assert isinstance(list_cmaps, list)
        assert len(list_images) == len(list_cmaps), '%d imgs != %d cmaps' % (len(list_images), len(list_cmaps))

    num_images = len(list_images)
    num_cols = min(num_images, num_cols)
    num_rows = -(-num_images // num_cols)  # ceiling division

    # squeeze=False always yields a 2-D array of axes, so a single image
    # needs no special handling.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    list_axes = list(axes.flat)

    for i, (img, ax) in enumerate(zip(list_images, list_axes)):
        title = list_titles[i] if list_titles is not None else 'Image %d' % (i)
        if list_cmaps is not None:
            cmap = list_cmaps[i]
        else:
            cmap = None if img_is_color(img) else 'gray'
        ax.imshow(img, cmap=cmap)
        ax.set_title(title, fontsize=title_fontsize)
        ax.grid(grid)

    # Hide the cells of the last row that have no image.
    for ax in list_axes[num_images:]:
        ax.set_visible(False)

    fig.tight_layout()
    plt.show()

As per matplotlib's suggestion for image grids:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
fig = plt.figure(figsize=(4., 4.))
grid = ImageGrid(fig, 111, # similar to subplot(111)
nrows_ncols=(2, 2), # creates 2x2 grid of axes
axes_pad=0.1, # pad between axes in inch.
)
for ax, im in zip(grid, image_data):
# Iterating over the grid returns the Axes.
ax.imshow(im)
plt.show()

I end up at this url about once a week. For those who want a little function that just plots a grid of images without hassle, here we go:
import matplotlib.pyplot as plt
import numpy as np
def plot_image_grid(images, ncols=None, cmap='gray'):
    '''Plot a grid of images.

    images: sequence of image arrays (anything supporting len() and iteration).
    ncols: number of columns; if falsy, the middle divisor of len(images) is
        used so that the grid is as close to square as the count allows.
    cmap: colormap handed to imshow for every image.

    Nothing is drawn for an empty sequence. Grid cells beyond the last image
    are hidden.
    '''
    n_images = len(images)
    if n_images == 0:
        return
    if not ncols:
        factors = [i for i in range(1, n_images + 1) if n_images % i == 0]
        ncols = factors[len(factors) // 2]
    # Ceiling division: one extra row only when the last row is partly filled.
    nrows = -(-n_images // ncols)
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid.
    f, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 2 * nrows), squeeze=False)
    flat_axes = axes.ravel()
    for img, ax in zip(images, flat_axes):
        img = np.asarray(img)
        if img.ndim > 2 and img.shape[2] == 1:
            # imshow rejects (H, W, 1); drop the singleton channel axis.
            img = img.squeeze()
        ax.imshow(img, cmap=cmap)
    for ax in flat_axes[n_images:]:
        ax.set_visible(False)
# make 16 images with 60 height, 80 width, 3 color channels
images = np.random.rand(16, 60, 80, 3)
# plot them
plot_image_grid(images)

Sample code to visualize one random image from the dataset
def get_random_image(num):
    """Load the file at index ``num`` of the module-level ``images`` listing.

    The file name is joined onto the dataset directory and read with
    ``cv2.imread``, so the returned array is in OpenCV's BGR channel order.

    Raises ValueError if OpenCV cannot read the file (``cv2.imread`` signals
    that by returning ``None`` instead of raising).
    """
    path = os.path.join("/content/gdrive/MyDrive/dataset/", images[num])
    image = cv2.imread(path)
    if image is None:
        raise ValueError(f"could not read image file: {path}")
    return image
Call the function
images = os.listdir("/content/gdrive/MyDrive/dataset")
# random.randint includes BOTH endpoints, so randint(0, len(images)) could
# return len(images) and index past the end; randrange excludes the stop value.
random_num = random.randrange(len(images))
img = get_random_image(random_num)
plt.figure(figsize=(8, 8))
# OpenCV loads images as BGR, matplotlib expects RGB.
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
Display cluster of random images from the given dataset
# Making a figure containing 16 distinct, randomly chosen images (4x4 grid)
lst = random.sample(range(len(images)), 16)
plt.figure(figsize=(12, 12))
for index, value in enumerate(lst):
    img = get_random_image(value)
    img_resized = cv2.resize(img, (400, 400))
    plt.subplot(4, 4, index + 1)
    # OpenCV loads images as BGR; convert so matplotlib shows true colors.
    plt.imshow(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB))
    plt.axis('off')
# Lay out once, after every subplot exists, then remove the gaps between tiles.
plt.tight_layout()
plt.subplots_adjust(wspace=0, hspace=0)
# To save the grid, call plt.savefig(f"Images/{lst[0]}.png") before plt.show().
plt.show()

Plotting images present in a dataset
Here rand gives a random index value which is used to select a random image present in the dataset and labels has the integer representation for every image type and labels_dict is a dictionary holding key val information
fig,ax = plt.subplots(5,5,figsize = (15,15))
ax = ax.ravel()
for i in range(25):
rand = np.random.randint(0,len(image_dataset))
image = image_dataset[rand]
ax[i].imshow(image,cmap = 'gray')
ax[i].set_title(labels_dict[labels[rand]])
plt.show()

Related

python matplotlib shared xlabel description / title on multiple subplots for animation

I'm using the following code to produce an animation with matplotlib that is intended to visualize my experiments.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation, PillowWriter
plt.rcParams['animation.html'] = 'jshtml'
def make_grid(X, description=None, labels=None, title_fmt="label: {}", cmap='gray', ncols=3, colors=None):
L = len(X)
nrows = -(-L // ncols)
frame_plot = []
for i in range(L):
plt.subplot(nrows, ncols, i + 1)
im = plt.imshow(X[i].squeeze(), cmap=cmap, interpolation='none')
if labels is not None:
color = 'k' if colors is None else colors[i]
plt.title(title_fmt.format(labels[i]), color=color)
plt.xticks([])
plt.yticks([])
frame_plot.append(im)
return frame_plot
def animate_step(X):
return X ** 2
n_splots = 6
X = np.random.random((n_splots,32,32,3))
Y = X
X_t = []
for i in range(10):
Y = animate_step(Y)
X_t.append((Y, i))
frames = []
for X, step in X_t:
frame = make_grid(X,
description="step={}".format(step),
labels=range(n_splots),
title_fmt="target: {}")
frames.append(frame)
anim = ArtistAnimation(plt.gcf(), frames,
interval=300, repeat_delay=8000, blit=True)
plt.close()
anim.save("test.gif", writer=PillowWriter())
anim
The result can be seen here:
https://i.stack.imgur.com/OaOsf.gif
It works fine so far, but I'm having trouble getting a shared xlabel to add a description for all of the 6 subplots in the animation. It is supposed to show what step the image is on, i.e. "step=5".
Since it is an animation, I cannot use xlabel or set_title (since it would be constant over the whole animation) and have to draw the text myself.
I've tried something along the lines of..
def make_grid(X, description=None, labels=None, title_fmt="label: {}", cmap='gray', ncols=3, colors=None):
L = len(X)
nrows = -(-L // ncols)
frame_plot = []
desc = plt.text(0.5, .04, description,
size=plt.rcparams["axes.titlesize"],
ha="center",
transform=plt.gca().transAxes
)
frame_plot.append(desc)
...
This, of course, won't work, because the axes are not yet created. I tried using the axis of another subplot(nrows, 1, nrows), but then the existing images are drawn over..
Does anyone have a solution to this?
Edit:
unclean, hacky solution for now:
Wait for the axes of the middle image of the last row to be created and use that for plotting the text.
In the for loop:
...
if i == int((nrows - 0.5) * ncols):
title = ax.text(0.25, -.3, description,
size=plt.rcParams["axes.titlesize"],
# ha="center",
transform=ax.transAxes
)
frame_plot.append(title)
...
To me, your case is easier to solve with FuncAnimation instead of ArtistAnimation, even if you already have access to the full list of data you want to show animated (see this thread for a discussion about the difference between the two functions).
Inspired from this FuncAnimation example, I wrote the code below that does what you needed (using the same code with ArtistAnimation and correct list of arguments does not work).
The main idea is to initialize all elements to be animated at the beginning, and to update them over the animation frames. This can be done for the text object (step_txt = fig.text(...)) in charge of displaying the current step, and for the images out from ax.imshow. You can then update whatever object you would like to see animated with this recipe.
Note that the technique works if you want the text to be an x_label or any text you choose to show. See the commented line in the code.
#!/Users/seydoux/anaconda3/envs/jupyter/bin/python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

# parameters
n_frames = 10
n_splots = 6
n_cols = 3
n_rows = -(-n_splots // n_cols)  # ceiling division, so a partly filled last row still fits


def update_data(x):
    """Return the next snapshot: the element-wise square of ``x``."""
    return x ** 2


# create all snapshots (the initial one plus one per update step)
snapshots = [np.random.rand(n_splots, 32, 32, 3)]
for _ in range(n_frames):
    snapshots.append(update_data(snapshots[-1]))

# initialize figure and static elements; the grid follows the parameters above
fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
axes = axes.ravel()  # so we can access all axes with a single index
for i, ax in enumerate(axes):
    ax.set_xticks([])
    ax.set_yticks([])
    if i < n_splots:
        ax.set_title("target: {}".format(i))
    else:
        ax.set_visible(False)  # grid cell without an image

# initialize elements to be animated
step_txt = fig.text(0.5, 0.95, "step: 0", ha="center", weight="bold")
# step_txt = axes[4].set_xlabel("step: 0")  # also works with x_label
imgs = [
    a.imshow(s, interpolation="none", cmap="gray")
    for a, s in zip(axes, snapshots[0])
]


def animate(i):
    """Show snapshot ``i``: refresh every image and the shared step label."""
    # update images
    for img, s in zip(imgs, snapshots[i]):
        img.set_data(s)
    # update text
    step_txt.set_text("step: {}".format(i))
    # etc


anim = FuncAnimation(fig, animate, frames=n_frames, interval=300)
anim.save("test.gif", writer=PillowWriter())
Here is the output I got from the above code:

Adding a colorbar whose color corresponds to the different lines in an existing plot

My dataset is in the form of :
Data[0] = [headValue,x0,x1,..xN]
Data[1] = [headValue_ya,ya0,ya1,..yaN]
Data[2] = [headValue_yb,yb0,yb1,..ybN]
...
Data[n] = [headvalue_yz,yz0,yz1,..yzN]
I want to plot f(y*) = x, so I can visualize all Lineplots in the same figure with different colors, each color determined by the headervalue_y*.
I also want to add a colorbar whose color matching the lines and therefore the header values, so we can link visually which header value leads to which behaviour.
Here is what I am aiming for :(Plot from Lacroix B, Letort G, Pitayu L, et al. Microtubule Dynamics Scale with Cell Size to Set Spindle Length and Assembly Timing. Dev Cell. 2018;45(4):496–511.e6. doi:10.1016/j.devcel.2018.04.022)
I have trouble adding the colorbar, I have tried to extract N colors from a colormap (N is my number of different headValues, or column -1) and then adding for each line plot the color corresponding here is my code to clarify:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
Data = [['Time',0,0.33,..200],[0.269,4,4.005,...11],[0.362,4,3.999,...16.21],...[0.347,4,3.84,...15.8]]
headValues = [0.269,0.362,0.335,0.323,0.161,0.338,0.341,0.428,0.245,0.305,0.305,0.314,0.299,0.395,0.32,0.437,0.203,0.41,0.392,0.347]
# the differents headValues_y* of each column here in a list but also in Data
# with headValue[0] = Data[1][0], headValue[1] = Data[2][0] ...
cmap = mpl.cm.get_cmap('rainbow') # I choose my colormap
rgba = [] # the color container
for value in headValues:
rgba.append(cmap(value)) # so rgba will contain a different color for each headValue
fig, (ax,ax1) = plt.subplots(2,1) # creating my figure and two axes to put the Lines and the colorbar
c = 0 # index for my colors
for i in range(1, len(Data)):
ax.plot( Data[0][1:], Data[i][1:] , color = rgba[c])
# Data[0][1:] is x, Data[i][1:] is y, and the color associated with Data[i][0]
c += 1
fig.colorbar(mpl.cm.ScalarMappable(cmap= mpl.colors.ListedColormap(rgba)), cax=ax1, orientation='horizontal')
# here I create my scalarMappable for my lineplot and with the previously selected colors 'rgba'
plt.show()
The current result:
How to add the colorbar on the side or the bottom of the first axis ?
How to properly add a scale to this colorbar corresponding to the different headValues ?
How to make the colorbar scale and colors match to the different lines on the plot with the link One color = One headValue ?
I have tried to work with scatterplot which are more convenient to use with scalarMappable but no solution allows me to do all these things at once.
Here is a possible approach. As the 'headValues' aren't sorted, nor equally spaced and one is even used twice, it is not fully clear what the most-desired result would be.
Some remarks:
The standard way of creating a colorbar in matplotlib doesn't need a separate subplot. Matplotlib will reduce the existing plot a bit and put the colorbar next to it (or below it for a horizontal bar).
Converting the 'headValues' to a numpy array allows for compact code, e.g. writing rgba = cmap(headValues) directly calculates the complete array.
Calling cmap on unchanged values will map 0 to the lowest color and 1 to the highest color, so for values only between 0.16 and 0.44 they all will be mapped to quite similar colors. One approach is to create a norm to map 0.16 to the lowest color and 0.44 to the highest. In code: norm = plt.Normalize(headValues.min(), headValues.max()) and then calculate rgba = cmap(norm(headValues)).
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

headValues = np.array([0.269, 0.362, 0.335, 0.323, 0.161, 0.338, 0.341, 0.428, 0.245, 0.305, 0.305, 0.314, 0.299, 0.395, 0.32, 0.437, 0.203, 0.41, 0.392, 0.347])
x = np.linspace(0, 200, 500)
# create Data similar to the data in the question
Data = [['Time'] + list(x)] + [[val] + list(np.sqrt(4 * x) * val + 4) for val in headValues]
headValues = np.array([d[0] for d in Data[1:]])
order = np.argsort(headValues)      # order[k] = index of the k-th smallest head value
inverse_order = np.argsort(order)   # inverse_order[j] = rank of head value j
# plt.get_cmap replaces mpl.cm.get_cmap, which newer matplotlib releases removed
cmap = plt.get_cmap('rainbow')
rgba = cmap(np.linspace(0, 1, len(headValues)))  # evenly spaced colors
fig, ax = plt.subplots(1, 1)
for i in range(1, len(Data)):
    ax.plot(Data[0][1:], Data[i][1:], color=rgba[inverse_order[i-1]])
    # Data[0][1:] is x, Data[i][1:] is y, and the color is chosen by the rank of Data[i][0]
# The ScalarMappable is not attached to any Axes, so tell colorbar which Axes to take space from.
cbar = fig.colorbar(mpl.cm.ScalarMappable(cmap=mpl.colors.ListedColormap(rgba)), ax=ax, orientation='vertical',
                    ticks=np.linspace(0, 1, len(rgba) * 2 + 1)[1::2])  # one tick at the center of each color segment
cbar.set_ticklabels(headValues[order])
plt.show()
Alternatively, the colors can be assigned using their position in the colormap, but without creating evenly spaced color segments; each line's color then depends directly on its head value:
# plt.get_cmap replaces mpl.cm.get_cmap, which newer matplotlib releases removed
cmap = plt.get_cmap('rainbow')
# Stretch the colormap over the actual range of head values.
norm = plt.Normalize(headValues.min(), headValues.max())
fig, ax = plt.subplots(1, 1)
for row in Data[1:]:
    # row[0] is the head value, row[1:] the y data; Data[0][1:] is x
    ax.plot(Data[0][1:], row[1:], color=cmap(norm(row[0])))
# The ScalarMappable is not attached to any Axes, so pass ax explicitly.
cbar = fig.colorbar(mpl.cm.ScalarMappable(cmap=cmap, norm=norm), ax=ax)
To get ticks for each of the 'headValues', these ticks can be set explicitly. As putting a label for each tick will result in overlapping text, labels that are too close to other labels can be replaced by an empty string:
# Work on a sorted copy so headValues keeps its correspondence with the rows of Data.
sorted_values = np.sort(headValues)
cbar2 = fig.colorbar(mpl.cm.ScalarMappable(cmap=cmap, norm=norm), ax=ax, ticks=sorted_values)
# Blank out a label when the next tick is closer than 0.007, to avoid overlapping text;
# the last (largest) value is always labelled.
cbar2.set_ticklabels([val if val < following - 0.007 else ''
                      for val, following in zip(sorted_values[:-1], sorted_values[1:])]
                     + [sorted_values[-1]])
At the left the result of the first approach (colors in segments), at the right the alternative colorbars (color depending on value):

create boxes on graph in python

I want this kind of result. I want my code to read elements of a text file and if element=='healthy'
it should create a box in a graph and its color should be green ('healthy written in box').
else if element=='unhealthy'
it should create a box and its color should be red (with 'unhealthy written in box').
boxes should be horizontally aligned, and if more than 5 then remaining should start from the next row. (every row should contain only 5 boxes or less).
The end result should display a graph that contains boxes,
red denoting 'unhealthy' and green denoting 'healthy'
I found the following code, but it is not working the way I want it to.
import matplotlib.pyplot as plt
plt.style.use('seaborn-white')
import numpy as np
from matplotlib import colors
#open text file (percen) that contains healthy/unhealthy
with open('percen.txt', 'r') as f:
result= [int(line) for line in f]
data = np.random.rand(10,10) * 20
cmap = colors.ListedColormap(['green'])
cmap1 = colors.ListedColormap(['red'])
bounds = [0,10,20]
norm = colors.BoundaryNorm(bounds, cmap.N)
fig, ax = plt.subplots(2,5 , sharex='col', sharey='row')
for i in range(2):
for j in range(5):
for element in result:
if (element=='healthy'):
ax[i,j].text(1, -3, 'healthy',
fontsize=15, ha='center', color='green')
ax[i,j].imshow(data,cmap=cmap, norm=norm)
else:
ax[i,j].text(1, -3, 'unhealthy',
fontsize=15, ha='center', color='red')
ax[i,j].imshow(data,cmap=cmap1,norm=norm)
fig
plt.show()
There are a few different ways you can do this and your code is probably not the best but we can use it as a starting point. Your issue is that you are looping through the plots and then looping through your data again for each plot. Your current code also adds text above the plot. If you want the text above I would recommend adding the label as a title, otherwise when you set your text inside the plot you need to specify the coordinates within the grid.
Below is a modified form of your code, play around with it some more to get what you want.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors

# NOTE(review): newer matplotlib releases appear to have renamed this style
# to 'seaborn-v0_8-white' -- confirm against the installed version.
plt.style.use('seaborn-white')

result = ['healthy', 'unhealthy', 'healthy', 'unhealthy', 'healthy', 'unhealthy', 'healthy', 'healthy', 'unhealthy', 'unhealthy']
data = np.random.rand(10, 10) * 20
# A one-color colormap paints the whole box in that color whatever the data
# values are, so no norm is needed.
cmap = colors.ListedColormap(['green'])
cmap1 = colors.ListedColormap(['red'])
box_cmaps = {'healthy': cmap, 'unhealthy': cmap1}

boxes_per_row = 5
# Ceiling division: start a new row after every 5 boxes (at least one row).
n_rows = max(1, -(-len(result) // boxes_per_row))
fig, ax = plt.subplots(n_rows, boxes_per_row, sharex='col', sharey='row',
                       figsize=(15, 4 * n_rows), squeeze=False)  # figsize to better show the plot
all_axes = list(ax.flat)

# Pair each subplot with the corresponding element instead of tracking an index by hand.
for axis, element in zip(all_axes, result):
    if element not in box_cmaps:
        continue  # unknown label: leave this box empty
    axis.text(4.5, 4.5, element, fontsize=15, ha='center', color='black', zorder=100)  # zorder puts the label over the image
    axis.imshow(data, cmap=box_cmaps[element])
    axis.set_yticklabels('')  # To remove arbitrary numbers on y axis
    axis.set_xticklabels('')  # To remove arbitrary numbers on x axis

# Hide the boxes of the last row that have no element.
for axis in all_axes[len(result):]:
    axis.set_visible(False)

plt.show()

matplotlib: change axis ticks of ndim histogram plotted with seaborn.heatmap

Motivation:
I'm trying to visualize a dataset of many n-dimensional vectors (let's say i have 10k vectors with n=300 dimensions). What i'd like to do is calculate a histogram for each of the n dimensions and plot it as a single line in a bins*n heatmap.
So far i've got this:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
# sample data:
vectors = np.random.randn(10000, 300) + np.random.randn(300)
def ndhist(vectors, bins=500):
limits = (vectors.min(), vectors.max())
hists = []
dims = vectors.shape[1]
for dim in range(dims):
h, bins = np.histogram(vectors[:, dim], bins=bins, range=limits)
hists.append(h)
hists = np.array(hists)
fig = plt.figure(figsize=(16, 9))
sns.heatmap(hists)
axes = fig.gca()
axes.set(ylabel='dimensions', xlabel='values')
print(dims)
print(limits)
ndhist(vectors)
This generates the following output:
300
(-6.538069472429366, 6.52159540162285)
Problem / Question:
How can i change the axes ticks?
for the y-axis i'd like to simply change this back to matplotlib's default, so it picks nice ticks like 0, 50, 100, ..., 250 (bonus points for 299 or 300)
for the x-axis i'd like to convert the shown bin indices into the bin (left) boundaries, then, as above, i'd like to change this back to matplotlib's default selection of some "nice" ticks like -5, -2.5, 0, 2.5, 5 (bonus points for also including the actual limits -6.538, 6.522)
Own solution attempts:
I've tried many things like the following already:
def ndhist_axlabels(vectors, bins=500):
limits = (vectors.min(), vectors.max())
hists = []
dims = vectors.shape[1]
for dim in range(dims):
h, bins = np.histogram(vectors[:, dim], bins=bins, range=limits)
hists.append(h)
hists = np.array(hists)
fig = plt.figure(figsize=(16, 9))
sns.heatmap(hists, yticklabels=False, xticklabels=False)
axes = fig.gca()
axes.set(ylabel='dimensions', xlabel='values')
#plt.xticks(np.linspace(*limits, len(bins)), bins)
plt.xticks(range(len(bins)), bins)
axes.xaxis.set_major_locator(matplotlib.ticker.AutoLocator())
plt.yticks(range(dims+1), range(dims+1))
axes.yaxis.set_major_locator(matplotlib.ticker.AutoLocator())
print(dims)
print(limits)
ndhist_axlabels(vectors)
As you can see however, the axes labels are pretty wrong. My guess is that the extent or limits are somewhere stored in the original axis, but lost when switching back to the AutoLocator. Would greatly appreciate a nudge in the right direction.
Maybe you're overthinking this. To plot image data, one can use imshow and get the ticking and formatting for free.
import numpy as np
from matplotlib import pyplot as plt
# sample data:
vectors = np.random.randn(10000, 300) + np.random.randn(300)
def ndhist(vectors, bins=500):
limits = (vectors.min(), vectors.max())
hists = []
dims = vectors.shape[1]
for dim in range(dims):
h, _ = np.histogram(vectors[:, dim], bins=bins, range=limits)
hists.append(h)
hists = np.array(hists)
fig, ax = plt.subplots(figsize=(16, 9))
extent = [limits[0], limits[-1], hists.shape[0]-0.5, -0.5]
im = ax.imshow(hists, extent=extent, aspect="auto")
fig.colorbar(im)
ax.set(ylabel='dimensions', xlabel='values')
ndhist(vectors)
plt.show()
If you read the docs, you will notice that the xticklabels/yticklabels arguments are overloaded, such that if you provide an integer instead of a string, it will interpret the argument as xtickevery/ytickevery and place ticks only at the corresponding locations. So in your case, seaborn.heatmap(hists, yticklabels=50) fixes your y-axis problem.
Regarding your xtick labels, I would simply provide them explictly:
xtickevery = 50
xticklabels = ['{:.1f}'.format(b) if ii%xtickevery == 0 else '' for ii, b in enumerate(bins)]
sns.heatmap(hists, yticklabels=50, xticklabels=xticklabels)
Finally came up with a version that works for me for now and uses AutoLocator based on some simple linear mapping...
def ndhist(vectors, bins=1000, title=None):
    """Plot one histogram per column of ``vectors`` as the rows of a heatmap.

    vectors: 2-D array; each column is histogrammed over the global
        (min, max) range of the whole array.
    bins: number of histogram bins (an int; it is also used to map values
        to heatmap column positions).
    title: optional title for the axes.

    The x axis is labelled with the two range limits plus the "nice" ticks
    chosen by matplotlib's AutoLocator, placed by linear interpolation
    between the limits. The elapsed time is printed at the end.
    """
    import time
    import matplotlib.ticker

    t = time.time()
    limits = (vectors.min(), vectors.max())
    dims = vectors.shape[1]
    hists = np.array([
        np.histogram(vectors[:, dim], bins=bins, range=limits)[0]
        for dim in range(dims)
    ])

    fig = plt.figure(figsize=(16, 12))
    sns.heatmap(
        hists,
        yticklabels=50,
        xticklabels=False
    )

    axes = fig.gca()
    axes.set(
        ylabel=f'dimensions ({dims} total)',
        xlabel=f'values (min: {limits[0]:.4g}, max: {limits[1]:.4g}, {bins} bins)',
        title=title,
    )

    def val_to_idx(val):
        # calc (linearly interpolated) index loc for given val
        return bins*(val - limits[0])/(limits[1] - limits[0])

    lo_label, hi_label = (round(l, 3) for l in limits)
    # Keep only the auto-generated ticks that fall strictly inside the range.
    ticks = [
        v for v in matplotlib.ticker.AutoLocator().tick_values(*limits)
        if limits[0] < v < limits[1]
    ]
    # drop auto-gen labels that might be too close to limits
    # (closer than a third of the tick spacing); needs two ticks to know the spacing
    if len(ticks) >= 2:
        min_gap = (ticks[1] - ticks[0]) / 3
        if hi_label - ticks[-1] < min_gap:
            ticks.pop()
        if ticks and ticks[0] - lo_label < min_gap:
            ticks.pop(0)
    xlabels = [lo_label, *ticks, hi_label]

    xticks = [val_to_idx(val) for val in xlabels]
    axes.set_xticks(xticks)
    axes.set_xticklabels([f'{l:.4g}' for l in xlabels])

    plt.show()
    print(f'histogram generated in {time.time() - t:.2f}s')
ndhist(np.random.randn(100000, 300), bins=1000, title='randn')
Thanks to Paul for his answer giving me the idea.
If there's an easier or more elegant solution, i'd still be interested though.

Multiple figures in a single window

I want to create a function which plot on screen a set of figures in a single window. By now I write this code:
import pylab as pl
def plot_figures(figures):
"""Plot a dictionary of figures.
Parameters
----------
figures : <title, figure> dictionary
"""
for title in figures:
pl.figure()
pl.imshow(figures[title])
pl.gray()
pl.title(title)
pl.axis('off')
It works perfectly but I would like to have the option for plotting all the figures in single window. And this code doesn't. I read something about subplot but it looks quite tricky.
You can define a function based on the subplots command (note the s at the end, different from the subplot command pointed by urinieto) of matplotlib.pyplot.
Below is an example of such a function, based on yours, allowing to plot multiples axes in a figure. You can define the number of rows and columns you want in the figure layout.
def plot_figures(figures, nrows = 1, ncols=1):
    """Plot a dictionary of figures.

    Parameters
    ----------
    figures : <title, figure> dictionary
    ncols : number of columns of subplots wanted in the display
    nrows : number of rows of subplots wanted in the figure

    Raises
    ------
    ValueError
        If the nrows x ncols grid has fewer cells than there are figures.
    """
    if len(figures) > nrows * ncols:
        raise ValueError(f"Too few subplots ({nrows*ncols}) specified for ({len(figures)}) figures.")

    # squeeze=False keeps the result a 2-D array even for the default 1x1
    # grid, where plt.subplots would otherwise return a bare Axes.
    fig, axeslist = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False)
    axes = axeslist.ravel()
    for ax, title in zip(axes, figures):
        # Request the gray colormap for this image only, without changing
        # the global default colormap.
        ax.imshow(figures[title], cmap='gray')
        ax.set_title(title)
        ax.set_axis_off()
    # Blank out the grid cells that received no image.
    for ax in axes[len(figures):]:
        ax.set_axis_off()
    plt.tight_layout() # optional
Basically, the function creates a number of axes in the figures, according to the number of rows (nrows) and columns (ncols) you want, and then iterates over the list of axis to plot your images and adds the title for each of them.
Note that if you only have one image in your dictionary, your previous syntax plot_figures(figures) will work since nrows and ncols are set to 1 by default.
An example of what you can obtain:
import matplotlib.pyplot as plt
import numpy as np
# generation of a dictionary of (title, images)
number_of_im = 6
figures = {'im'+str(i): np.random.randn(100, 100) for i in range(number_of_im)}
# plot of the images in a figure, with 2 rows and 3 columns
plot_figures(figures, 2, 3)
You should use subplot.
In your case, it would be something like this (if you want them one on top of the other):
fig = pl.figure(1)
# One row per figure, stacked vertically; subplot positions are 1-based.
for k, title in enumerate(figures, start=1):
    ax = fig.add_subplot(len(figures), 1, k)
    # Axes objects have no gray()/title() methods: the colormap is an
    # imshow argument and the title is set with set_title().
    ax.imshow(figures[title], cmap='gray')
    ax.set_title(title)
    ax.axis('off')
Check out the documentation for other options.
If you want to group multiple figures in one window you can do something like this:
import matplotlib.pyplot as plt
import numpy as np

img = plt.imread('C:/.../Download.jpg')  # Path to image
img = img[0:150, 50:200, 0]  # Crop to a square (or whatever shape you want), first channel only
fig = plt.figure()
nrows = 10  # Number of tile rows
ncols = 10  # Number of tile columns
# Take the tile size from the cropped image itself, so a smaller source
# image or a different crop cannot get out of sync with hard-coded numbers.
image_height, image_width = img.shape
pixels = np.zeros((nrows * image_height, ncols * image_width))  # Canvas holding every tile
for a in range(nrows):
    for b in range(ncols):
        # Rows advance by the tile height, columns by the tile width.
        pixels[a * image_height:(a + 1) * image_height,
               b * image_width:(b + 1) * image_width] = img
plt.imshow(pixels, cmap='jet')
plt.axis('off')
plt.show()
As result you receive:
Building on the answer from: How to display multiple images in one figure correctly?, here is another method:
import math
import numpy as np
import matplotlib.pyplot as plt
def plot_images(np_images, titles = None, columns = 5, figure_size = (24, 18)):
    """Show every image of ``np_images`` in one figure, ``columns`` per row.

    np_images: array whose first axis indexes the images.
    titles: optional sequence with one title per image; None or an empty
        sequence means no titles.
    columns: number of subplot columns; rows are added as needed.
    figure_size: (width, height) passed to plt.figure.
    """
    if titles is None:
        titles = []
    count = np_images.shape[0]
    rows = math.ceil(count / columns)
    fig = plt.figure(figsize=figure_size)
    for index in range(count):
        ax = fig.add_subplot(rows, columns, index + 1)  # positions are 1-based
        if len(titles):
            ax.set_title(str(titles[index]))
        # Draw on the axes just created rather than on pyplot's implicit
        # "current axes".
        ax.imshow(np_images[index])
    plt.show()
You can also do this:
import matplotlib.pyplot as plt

# squeeze=False always returns a 2-D array of axes, so this also works when
# imgs holds a single image (plt.subplots(1, 1) would otherwise return a
# bare Axes that cannot be indexed).
f, axarr = plt.subplots(1, len(imgs), squeeze=False)
for ax, img in zip(axarr[0], imgs):
    ax.imshow(img)
plt.suptitle("Your title!")
plt.show()
def plot_figures(figures, nrows=None, ncols=None):
    """Show every entry of a <title, image> dictionary in one window.

    When both nrows and ncols are given, the grid must have room for all
    figures, otherwise ValueError is raised. When either is missing, the
    figures are laid out in a single row.
    """
    if nrows and ncols:
        # check minimum grid configured
        if nrows * ncols < len(figures):
            raise ValueError(f"Too few subplots ({nrows*ncols}) specified for ({len(figures)}) figures.")
    else:
        # Plot figures in a single row if grid not specified
        nrows, ncols = 1, len(figures)

    canvas = plt.figure()
    # optional spacing between figures
    canvas.subplots_adjust(hspace=0.4, wspace=0.4)

    # Subplot positions are 1-based.
    for position, title in enumerate(figures, start=1):
        plt.subplot(nrows, ncols, position)
        plt.title(title)
        plt.imshow(figures[title])

    plt.show()
Any grid configuration (or none) can be specified as long as the product of the number of rows and the number of columns is equal to or greater than the number of figures.
For example, for len(figures) == 10, these are acceptable
plot_figures(figures)
plot_figures(figures, 2, 5)
plot_figures(figures, 3, 4)
plot_figures(figures, 4, 3)
plot_figures(figures, 5, 2)
import numpy as np
def save_image(data, ws=0.1, hs=0.1, sn='save_name'):
    """Save the images in ``data`` as one near-square grid in '<sn>.png'.

    data: array indexed as data[i, 0, :, :] for image i, e.g. (36, 1, 32, 32).
    ws, hs: horizontal / vertical spacing between the subplots, passed to
        subplots_adjust as wspace / hspace.
    sn: output file name without the '.png' extension.

    Raises ValueError if ``data`` holds no images.
    """
    import matplotlib.pyplot as plt

    count = data.shape[0]
    if count == 0:
        raise ValueError("data contains no images")
    # Round the side length UP so counts that are not perfect squares still
    # fit; for a perfect square this is the same grid as before.
    n = int(np.ceil(np.sqrt(count)))  # columns
    m = -(-count // n)                # rows (ceiling division)
    # figsize is (width, height): width follows the columns.
    fig, ax = plt.subplots(m, n, figsize=(n*6, m*6), squeeze=False)
    ax = ax.ravel()
    for i in range(count):
        ax[i].matshow(data[i,0,:,:])
        ax[i].set_xticks([])
        ax[i].set_yticks([])
    # Hide grid cells that have no image.
    for unused in ax[count:]:
        unused.set_visible(False)
    # No tight_layout() afterwards: it would recompute the spacing and
    # discard the ws/hs values requested by the caller.
    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9,
                        top=0.9, wspace=ws, hspace=hs)
    plt.savefig('{}.png'.format(sn))
    plt.close(fig)  # release the figure once it is on disk
data = np.load('img_test.npy')
save_image(data, ws=0.1, hs=0.1, sn='multiple_plot')

Categories