How to plot a PolyCollection created with .fill_between() in matplotlib? - python

I have this code to plot
import matplotlib.pyplot as plt
import numpy as np
# Data
time = np.arange(0,20,0.05)
data = np.sin(time)
# Plot
main_plot = plt.plot(time, data)
error_plot = plt.fill_between(time, data+0.5, data-0.5, alpha=0.2)
plt.show()
and I would like to draw the same plot again later in the code, using the main_plot and error_plot variables. So far I have this
plt.plot(*main_plot[0].get_data())
# Here I want to plot the error_plot (PolyCollection type)
# but I don't know how. Something like
# plt.plot(error_plot)
I tried the approach from a related post without success. I've also tried
fig = plt.figure()
ax = fig.add_subplot(111)
ax.add_collection(error_plot)
but this results in the following error
1 fig = plt.figure()
2 ax = fig.add_subplot(111)
----> 3 ax.add_collection(error_plot)
RuntimeError: Can not put single artist in more than one figure
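One possible workaround, offered as a sketch rather than a definitive fix: an artist can belong to only one figure, so instead of re-adding error_plot, its polygon vertices can be read back with get_paths() and used to build a fresh PolyCollection on the new axes. The names fig2, ax2 and new_band are made up for this illustration, and the Agg backend is selected only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PolyCollection

# Data
time = np.arange(0, 20, 0.05)
data = np.sin(time)

# Original plot
fig1, ax1 = plt.subplots()
main_plot = ax1.plot(time, data)
error_plot = ax1.fill_between(time, data + 0.5, data - 0.5, alpha=0.2)

# Read the geometry back from the existing collection
verts = [path.vertices for path in error_plot.get_paths()]

# Second plot: rebuild the band as a new artist instead of reusing the old one
fig2, ax2 = plt.subplots()
ax2.plot(*main_plot[0].get_data())
new_band = PolyCollection(verts, facecolors=error_plot.get_facecolor())
ax2.add_collection(new_band)
ax2.autoscale_view()  # make sure the limits include the band
```

The face color returned by get_facecolor() already carries the alpha of the original band, so it does not need to be passed again.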

Related

Plotting multiple PSDs with MNE for Python in the same figure

I would like to plot multiple PSD obtained with plot_psd() from MNE python.
I tried the following code
import matplotlib.gridspec as gridspec
gs = gridspec.GridSpec(3,1)
plt.figure()
ax = plt.axes()
# First plot
ax1 = fig.add_subplot(gs[0]
raw_egi.plot_psd(ax=ax1)
ax2=fig.add_subplot(gs[1]
raw_ws_ds_hp_lp.plot_psd(ax=ax2)
ax3= fig.add_subplot(gs[2]
raw_ws_ds_hp_lp_nf.plot_psd(ax=ax3)
plt.show()
It tells me that I have an invalid syntax.
The following code is working but all plots are superimposed
import matplotlib.gridspec as gridspec
gs = gridspec.GridSpec(3,1)
plt.figure()
ax = plt.axes()
# First plot
raw_egi.plot_psd(ax=ax)
raw_ws_ds_hp_lp.plot_psd(ax=ax)
raw_ws_ds_hp_lp_nf.plot_psd(ax=ax)
plt.show()
Could you tell me how to plot these 3 figures stacked vertically (one per row) instead of superimposed? Below you will find the figure produced by the working code (i.e. 3 superimposed plots). Thanks for your help.
Here is how I solved the problem for 2 plots:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2)
raw_bp.plot_psd(ax=ax[0], show=False)
raw_bp_nf.plot_psd(ax=ax[1], show=False)
ax[0].set_title('PSD before filtering')
ax[1].set_title('PSD after filtering')
ax[1].set_xlabel('Frequency (Hz)')
fig.set_tight_layout(True)
plt.show()
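The same idea should extend to three rows. The sketch below is only an illustration: the MNE raw objects are not available here, so matplotlib's own Axes.psd stands in for plot_psd, and the signals, the sampling rate fs and the titles are made-up placeholders. With MNE, each ax.psd call would be replaced by the corresponding plot_psd(ax=..., show=False) call.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
fs = 250.0  # hypothetical sampling rate in Hz
t = np.arange(0, 4, 1 / fs)

# Three synthetic signals standing in for the three raw recordings
signals = {
    "signal 1": np.sin(2 * np.pi * 10 * t) + 1.0 * rng.normal(size=t.size),
    "signal 2": np.sin(2 * np.pi * 10 * t) + 0.5 * rng.normal(size=t.size),
    "signal 3": np.sin(2 * np.pi * 10 * t) + 0.1 * rng.normal(size=t.size),
}

# One row per PSD; each call draws into its own axes
fig, axes = plt.subplots(3, 1, sharex=True, figsize=(6, 8))
for ax, (title, sig) in zip(axes, signals.items()):
    ax.psd(sig, NFFT=256, Fs=fs)
    ax.set_title(title)
fig.tight_layout()
```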

How can I add jitter to my seaborn and matplotlib plots?

I am trying to add jitter to my plots using seaborn and matplotlib. I am getting mixed information from what I read online: some sources say custom code is needed, while others show it as being as simple as jitter = True. Is there another library or something else that I should be importing that I am not aware of? Below is the code that I am running and trying to add jitter to:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
filename = 'https://library.startlearninglabs.uw.edu/DATASCI410/Datasets/JitteredHeadCount.csv'
headcount_df = pd.read_csv(filename)
headcount_df.describe()
%matplotlib inline
ax = plt.figure(figsize=(12, 6)).gca() # define axis
headcount_df.plot.scatter(x = 'Hour', y = 'TablesOpen', ax = ax, alpha = 0.2)
# auto_price.plot(kind = 'scatter', x = 'city-mpg', y = 'price', ax = ax)
ax.set_title('Hour vs TablesOpen') # Give the plot a main title
ax.set_ylabel('TablesOpen')# Set text for y axis
ax.set_xlabel('Hour')
ax = sns.kdeplot(headcount_df.loc[:, ['TablesOpen', 'Hour']], shade = True, cmap = 'PuBu')
headcount_df.plot.scatter(x = 'Hour', y = 'TablesOpen', ax = ax, jitter = True)
ax.set_title('Hour vs TablesOpen') # Give the plot a main title
ax.set_ylabel('TablesOpen')# Set text for y axis
ax.set_xlabel('Hour')
I receive the error: AttributeError: 'PathCollection' object has no property 'jitter' when trying to add the jitter. Any help or more information on this would be much appreciated
To add jitter to a scatter plot, first get a handle to the collection that contains the scatter dots. Right after a scatter plot has been created on an ax, ax.collections[-1] is the desired collection.
Calling get_offsets() on the collection returns the xy coordinates of all the dots. Add a small random number to each of them. Since all coordinates are integers in this case, adding a random number between 0 and 1 spreads the dots out evenly.
The number of dots here is very large. To better see where they are concentrated, they can be made very small (marker=',', linewidth=0, s=1) and very transparent (e.g. alpha=0.1).
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
filename = 'https://library.startlearninglabs.uw.edu/DATASCI410/Datasets/JitteredHeadCount.csv'
headcount_df = pd.read_csv(filename)
fig, ax = plt.subplots(figsize=(12, 6))
headcount_df.plot.scatter(x='Hour', y='TablesOpen', marker=',', linewidth=0, s=1, alpha=.1, color='crimson', ax=ax)
dots = ax.collections[-1]
offsets = dots.get_offsets()
jittered_offsets = offsets + np.random.uniform(0, 1, offsets.shape)
dots.set_offsets(jittered_offsets)
ax.set_title('Hour vs TablesOpen') # Give the plot a main title
ax.set_ylabel('TablesOpen') # Set text for y axis
ax.set_xlabel('Hour')
ax.set_xticks(range(25))
ax.autoscale(enable=True, tight=True)
plt.tight_layout()
plt.show()
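For trying the technique without downloading the CSV, here is a minimal sketch on synthetic integer data. The arrays hours and tables are invented stand-ins for the real columns. As a variation on the code above, the jitter is drawn from -0.5 to 0.5 so that each dot stays centred on its original integer position; that choice is an assumption of this sketch, not part of the answer.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
hours = rng.integers(0, 24, size=2000)    # synthetic stand-in for 'Hour'
tables = rng.integers(0, 25, size=2000)   # synthetic stand-in for 'TablesOpen'

fig, ax = plt.subplots(figsize=(12, 6))
dots = ax.scatter(hours, tables, s=1, alpha=0.3)  # ax.scatter returns the collection directly

# Read the xy positions, add centred uniform noise, and write them back
original = np.asarray(dots.get_offsets()).copy()
jitter = rng.uniform(-0.5, 0.5, size=original.shape)
dots.set_offsets(original + jitter)

ax.set_xlabel('Hour')
ax.set_ylabel('TablesOpen')
```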
Because there are so many points, drawing the 2D KDE takes a long time. Taking a random sample of the rows reduces that time. Note that to draw a 2D KDE, recent versions of seaborn expect each column to be passed as a separate parameter.
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
filename = 'https://library.startlearninglabs.uw.edu/DATASCI410/Datasets/JitteredHeadCount.csv'
headcount_df = pd.read_csv(filename)
fig, ax = plt.subplots(figsize=(12, 6))
N = 5000
rand_sel_df = headcount_df.iloc[np.random.choice(range(len(headcount_df)), N)]
ax = sns.kdeplot(x=rand_sel_df['Hour'], y=rand_sel_df['TablesOpen'], fill=True, cmap='PuBu', ax=ax)  # keyword x/y and fill= (replaces shade=) for recent seaborn
ax.set_title('Hour vs TablesOpen')
ax.set_xticks(range(25))
plt.tight_layout()
plt.show()
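As a side note on the sampling step: np.random.choice draws with replacement by default, so the same row can be picked more than once. If that matters, pandas' DataFrame.sample draws without replacement by default. A minimal sketch, using a synthetic frame df as a made-up stand-in for headcount_df:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Synthetic stand-in for headcount_df (values are invented)
df = pd.DataFrame({
    "Hour": rng.integers(0, 24, size=100_000),
    "TablesOpen": rng.integers(0, 25, size=100_000),
})

N = 5000
# Draw N distinct rows; random_state makes the selection reproducible
sample_df = df.sample(n=N, random_state=0)
```

The resulting sample_df can then be passed to the KDE call in place of rand_sel_df.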

Show legend that matplotlib dynamically created

My df has 4 columns: x, y, z, and grouping. I have created a 3D plot, with the assigned color of each point being decided by what grouping it belongs to in that row. For reference, a "grouping" can be any number from 1 to 6. The code is shown below:
fig = plt.figure()
ax = Axes3D(fig)
ax.scatter3D(df.x, df.y, df.z, c=df.grouping)
plt.show()
I would like to show a legend on the plot that shows which color belongs to which grouping. Previously, I was using Seaborn for a 2D plot and the legend was automatically plotted. How can I add this feature with matplotlib?
If the values to be colormapped are numeric, the solution can be as simple as:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
a = np.random.rand(3,40)
c = np.random.randint(1,7, size=a.shape[1])
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
sc = ax.scatter3D(*a, c=c)
plt.legend(*sc.legend_elements())
plt.show()
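If more descriptive legend entries are wanted, the handles from legend_elements() can be combined with custom labels. The following is a sketch under a few assumptions: the data are random placeholders, the label text "group N" and the legend title are invented, and there are few enough distinct values that legend_elements() returns one handle per unique group in sorted order.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
points = rng.random((3, 60))          # random x, y, z coordinates
groups = np.arange(60) % 6 + 1        # every grouping from 1 to 6 occurs

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
sc = ax.scatter(*points, c=groups)

# One handle per distinct value; replace the default labels with our own
handles, _ = sc.legend_elements()
labels = [f"group {g}" for g in np.unique(groups)]
legend = ax.legend(handles, labels, title="grouping")
```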

Changing axis without changing data (Python)

How could I plot some data, remove the axes created by that data, and replace them with axes of a different scale?
Say I have something like:
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
plt.xlim([0,5])
plt.ylim([0,5])
plt.plot([0,1,2,3,4,5])
plt.show()
This plots a line in a 5x5 plot with ranges from 0 to 5 on both axes. I would like to remove the 0-to-5 axes and replace them with, say, -25 to 25 axes. This should change only the axes, without moving any of the data, i.e., the plot should look identical to the original, just with different axes. I realize this could be done simply by shifting the data, but I do not wish to alter the data.
You could use plt.xticks to find the location of the labels, and then set the labels to 5 times the location values. The underlying data does not change; only the labels.
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
plt.xlim([0,5])
plt.ylim([0,5])
plt.plot([0,1,2,3,4,5])
locs, labels = plt.xticks()
labels = [float(item)*5 for item in locs]
plt.xticks(locs, labels)
plt.show()
Alternatively, you could change the ticker formatter:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
N = 128
fig = plt.figure()
ax = fig.add_subplot(111)
plt.plot(range(N+1))
plt.xlim([0,N])
plt.ylim([0,N])
ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: ('%g') % (x * 5.0)))
plt.show()
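For the -25 to 25 range mentioned in the question, a pure scale factor is not enough; the labels need a linear map. A minimal sketch, assuming the data range 0 to 5 should be labelled -25 to 25 (so label = 10 * x - 25); the helper name relabel is invented for this example:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker


def relabel(x, pos):
    """Map a data position in [0, 5] to a label in [-25, 25] (illustrative mapping)."""
    return f"{10 * x - 25:g}"


fig, ax = plt.subplots()
ax.plot([0, 1, 2, 3, 4, 5])
ax.set_xlim(0, 5)
ax.set_ylim(0, 5)

# Only the tick labels change; the plotted data stay at 0..5
ax.xaxis.set_major_formatter(ticker.FuncFormatter(relabel))
ax.yaxis.set_major_formatter(ticker.FuncFormatter(relabel))
```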

Python: 3D scatter losing colormap

I'm creating a 3D scatter plot with multiple sets of data and using a colormap for the whole figure. The code looks like this:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for R in [range(0,10), range(5,15), range(10,20)]:
    data = [np.array(R), np.array(range(10)), np.array(range(10))]
    AX = ax.scatter(*data, c=data[0], vmin=0, vmax=20, cmap=plt.cm.jet)
    def forceUpdate(event): AX.changed()
    fig.canvas.mpl_connect('draw_event', forceUpdate)
plt.colorbar(AX)
This works fine but as soon as I save it or rotate the plot, the colors on the first and second scatters turn blue.
The force update is working by keeping the colors but only on the last scatter plot drawn. I tried making a loop that updates all the scatter plots but I get the same result as above:
AX = []
for R in [range(0,10), range(5,15), range(10,20)]:
    data = [np.array(R), np.array(range(10)), np.array(range(10))]
    AX.append(ax.scatter(*data, c=data[0], vmin=0, vmax=20, cmap=plt.cm.jet))
for i in AX:
    def forceUpdate(event): i.changed()
    fig.canvas.mpl_connect('draw_event', forceUpdate)
Any idea how I can make sure all scatters are being updated so the colors don't disappear?
Thanks!
After modifying your code so that it actually runs:
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from mpl_toolkits.mplot3d import Axes3D
>>> AX = []
>>> fig = plt.figure()
>>> ax = fig.add_subplot(111, projection='3d')
>>> for R in [range(0,10), range(5,15), range(10,20)]:
...     data = [np.array(R), np.array(range(10)), np.array(range(10))]
...     AX = ax.scatter(*data, c=data[0], vmin=0, vmax=20, cmap=plt.cm.jet)
...     def forceUpdate(event): AX.changed()
...     fig.canvas.mpl_connect('draw_event', forceUpdate)
...
9
10
11
>>> plt.colorbar(AX)
<matplotlib.colorbar.Colorbar instance at 0x36265a8>
>>> plt.show()
then I get a working plot, so the above code is fine. If your existing code does not work, try the exact code above. If that also fails, check the versions of the libraries you are using. If it works, investigate the differences between it and your actual code (rather than your example code).
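One plausible reason the loop in the question only ever updates the last scatter is Python's late-binding closures: every forceUpdate defined inside "for i in AX" looks up i when it is called, and by then i refers to the last element. A sketch of one way around this, assuming the draw_event workaround is still needed in your matplotlib version, is to register a single callback that walks over all scatters. The names scatters and force_update are chosen for this example.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

scatters = []
for R in [range(0, 10), range(5, 15), range(10, 20)]:
    data = [np.array(R), np.arange(10), np.arange(10)]
    scatters.append(ax.scatter(*data, c=data[0], vmin=0, vmax=20, cmap="jet"))


def force_update(event, artists=tuple(scatters)):
    # The default argument captures the scatters at definition time,
    # so the callback does not depend on a loop variable.
    for artist in artists:
        artist.changed()


cid = fig.canvas.mpl_connect("draw_event", force_update)
fig.colorbar(scatters[-1], ax=ax)
```

Binding the loop variable as a default argument (def forceUpdate(event, i=i): ...) inside the original loop would be an alternative with the same effect.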
