Python: How to create a legend using an example

This is from Chapter 2 in the book Machine Learning In Action and I am trying to make the plot pictured here:
The author has posted the plot's code here, which I believe may be a bit hacky (he also mentions this code is sloppy since it is out of the book's scope).
Here is my attempt to re-create the plot:
First, the .txt file holding the data is as follows (source: "datingTestSet2.txt" in Ch.2 here):
40920 8.326976 0.953952 largeDoses
14488 7.153469 1.673904 smallDoses
26052 1.441871 0.805124 didntLike
75136 13.147394 0.428964 didntLike
38344 1.669788 0.134296 didntLike
...
Assume datingDataMat is a numpy.ndarray of shape (1000, 3) where column 0 is "Frequent Flier Miles Per Year", column 1 is "% Time Playing Video Games", and column 2 is "Liters of Ice Cream Consumed Per Week", as shown in the sample above.
Assume datingLabels is a list of ints 1, 2, or 3 meaning "Did Not Like", "Liked in Small Doses", and "Liked in Large Doses" respectively - associated with column 3 above.
Here is the code I have to create the plot (full details for file2matrix are at the end):
import numpy as np  # needed for np.array below
import matplotlib.pyplot as plt
datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")
fig = plt.figure()
ax = fig.add_subplot(111)
plt.xlabel("Freq flier miles")
plt.ylabel("% time video games")
# Not sure how to finish this: plt.legend([1, 2, 3], ["did not like", "small doses", "large doses"])
plt.scatter(datingDataMat[:,0], datingDataMat[:,1], 15.0*np.array(datingLabels), 15.0*np.array(datingLabels)) # Change marker color and size
plt.show()
The output is here:
My main concern is how to create this legend. Is there a way to do this without needing a direct handle to the points?
Next, I am curious whether I can find a way to switch the colors to match those of the plot. Is there a way to do this without having some kind of "handle" on the individual points?
Also, if interested, here is the file2matrix implementation:
def file2matrix(filename):
    fr = open(filename)
    numberOfLines = len(fr.readlines())
    returnMat = np.zeros((numberOfLines,3)) #numpy.zeros(shape, dtype=float, order='C')
    classLabelVector = []
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3] # FFmiles/yr, % time gaming, L ice cream/wk
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector

Here's an example that mimics the code you already have that shows the approach described in Saullo Castro's example.
It also shows how to set the colors in the example.
If you want more information on the colors available, see the documentation at http://matplotlib.org/api/colors_api.html
It would also be worth looking at the scatter plot documentation at http://matplotlib.org/1.3.1/api/pyplot_api.html#matplotlib.pyplot.scatter
from numpy.random import rand, randint
from matplotlib import pyplot as plt
n = 1000
# Generate random data
data = rand(n, 2)
# Make a random array to mimic datingLabels
labels = randint(1, 4, n)
# Separate the data according to the labels
data_1 = data[labels==1]
data_2 = data[labels==2]
data_3 = data[labels==3]
# Plot each set of points separately
# 's' is the size parameter.
# 'c' is the color parameter.
# I have chosen the colors so that they match the plot shown.
# With each set of points, input the desired label for the legend.
plt.scatter(data_1[:,0], data_1[:,1], s=15, c='r', label="label 1")
plt.scatter(data_2[:,0], data_2[:,1], s=30, c='g', label="label 2")
plt.scatter(data_3[:,0], data_3[:,1], s=45, c='b', label="label 3")
# Put labels on the axes
plt.ylabel("ylabel")
plt.xlabel("xlabel")
# Place the Legend in the plot.
plt.gca().legend(loc="upper left")
# Display it.
plt.show()
The gray borders should become white if you use plt.savefig to save the figure to file instead of displaying it.
Remember to run plt.clf() or plt.cla() after saving to file to clear the axes so you don't end up replotting the same data on top of itself over and over again.
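As a minimal sketch of the save-then-clear pattern described above (the toy data, the Agg backend and the temporary output path are assumptions made for illustration, not part of the original answer):

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window is needed
import matplotlib.pyplot as plt

# Toy data standing in for the real scatter plot
fig = plt.figure()
plt.scatter([1, 2, 3], [4, 5, 6], label="first batch")
plt.legend(loc="upper left")

# Save to a file instead of calling plt.show()
out_path = os.path.join(tempfile.gettempdir(), "scatter_demo.png")
plt.savefig(out_path)

# Clear the figure so the next plot does not draw on top of this one
plt.clf()
axes_after_clear = len(fig.axes)
print(axes_after_clear)
```

plt.clf() clears the whole current figure, while plt.cla() clears only the current axes, so which one to call depends on whether the axes layout should be kept.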

To create the legend you have to:
give labels to each curve
call the legend() method from the current AxesSubplot object, which can be obtained using plt.gca(), for example.
See the example below:
plt.scatter(datingDataMat[:,0], datingDataMat[:,1],
15.0*np.array(datingLabels), 15.0*np.array(datingLabels),
label='Label for this data')
plt.gca().legend(loc='upper left')
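Note that a single scatter call with one label produces a single legend entry. If one entry per class is wanted, a possible sketch is to loop over the label values and plot each subset with its own label. The random arrays below are stand-ins for datingDataMat and datingLabels, and the color choices and the class_styles name are assumptions for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
dating_data = rng.random((1000, 3))             # stand-in for datingDataMat
dating_labels = rng.integers(1, 4, size=1000)   # stand-in for datingLabels (1, 2 or 3)

# Hypothetical mapping: label value -> (legend text, color)
class_styles = {
    1: ("Did Not Like", "r"),
    2: ("Liked in Small Doses", "g"),
    3: ("Liked in Large Doses", "b"),
}

fig, ax = plt.subplots()
for value, (name, color) in class_styles.items():
    mask = dating_labels == value  # boolean mask selecting one class
    ax.scatter(dating_data[mask, 0], dating_data[mask, 1],
               s=15 * value, c=color, label=name)

ax.set_xlabel("Freq flier miles")
ax.set_ylabel("% time video games")
legend = ax.legend(loc="upper left")
legend_texts = [text.get_text() for text in legend.get_texts()]
print(legend_texts)
```

Because every subset carries its own label, no explicit handles to the points are needed when calling legend().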

Related

Is there a way to improve the line quality when exporting streamplots from matplotlib?

I am drawing streamplots using matplotlib, and exporting them to a vector format. However, I find the streamlines are exported as a series of separate lines - not joined objects. This has the effect of reducing the quality of the image, and making for an unwieldy file for further manipulation. An example; the following images are of a pdf generated by exportfig and viewed in Acrobat Reader:
This is the entire plot
and this is a zoom of the center.
Interestingly, the length of these short line segments is affected by 'density' - increasing the density decreases the length of the lines. I get the same behavior whether exporting to svg, pdf or eps.
Is there a way to get a streamplot to export streamlines as a single object, preferably as a curved line?
MWE
import matplotlib.pyplot as plt
import numpy as np
square_size = 101
x = np.linspace(-1,1,square_size)
y = np.linspace(-1,1,square_size)
u, v = np.meshgrid(-x,y)
fig, axis = plt.subplots(1, figsize = (4,3))
axis.streamplot(x,y,u,v)
fig.savefig('YourDirHere\\test.pdf')
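To see the fragmentation numerically rather than in a PDF viewer, one option is to count the segments held by the streamplot's line collection for a few density values. This is only a sketch: the chosen densities and the Agg backend are assumptions, and the exact numbers depend on the matplotlib version.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no display is needed
import matplotlib.pyplot as plt
import numpy as np

square_size = 101
x = np.linspace(-1, 1, square_size)
y = np.linspace(-1, 1, square_size)
u, v = np.meshgrid(-x, y)

segment_counts = {}
for density in (0.5, 1, 2):
    fig, axis = plt.subplots(1, figsize=(4, 3))
    stream = axis.streamplot(x, y, u, v, density=density)
    # stream.lines is a LineCollection; each entry is one exported path
    segments = stream.lines.get_segments()
    segment_counts[density] = len(segments)
    plt.close(fig)

for density, count in segment_counts.items():
    print(f"density={density}: {count} segments")
```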
In the end, it seemed like the best solution was to extract the lines from the streamplot object, and plot them using axis.plot. The lines are stored as individual segments with no clue as to which line they belong, so it is necessary to stitch them together into continuous lines.
Code follows:
import matplotlib.pyplot as plt
import numpy as np
def extract_streamlines(sl):
    # empty list for extracted lines, flag
    new_lines = []
    for line in sl:
        # ignore zero-length lines
        if np.array_equiv(line[0],line[1]):
            continue
        ap_flag = 1
        for new_line in new_lines:
            # append the line segment to either the start or the end of an existing line,
            # if either the start or the end of the segment is close.
            if np.allclose(line[0],new_line[-1]):
                new_line.append(list(line[1]))
                ap_flag = 0
                break
            elif np.allclose(line[1],new_line[-1]):
                new_line.append(list(line[0]))
                ap_flag = 0
                break
            elif np.allclose(line[0],new_line[0]):
                new_line.insert(0,list(line[1]))
                ap_flag = 0
                break
            elif np.allclose(line[1],new_line[0]):
                new_line.insert(0,list(line[0]))
                ap_flag = 0
                break
        # otherwise start a new line
        if ap_flag:
            new_lines.append(line.tolist())
    return [np.array(line) for line in new_lines]
square_size = 101
x = np.linspace(-1,1,square_size)
y = np.linspace(-1,1,square_size)
u, v = np.meshgrid(-x,y)
fig_stream, axis_stream = plt.subplots(1, figsize = (4,3))
stream = axis_stream.streamplot(x,y,u,v)
np_new_lines = extract_streamlines(stream.lines.get_segments())
fig, axis = plt.subplots(1, figsize = (4,4))
for line in np_new_lines:
    axis.plot(line[:,0], line[:,1])
fig.savefig('YourDirHere\\test.pdf')
A quick solution to this issue is to change the default cap styles of those tiny segments drawn by the streamplot function. In order to do this, follow the below steps.
Extract all the segments from the stream plot.
Bundle these segments through LineCollection function.
Set the collection's cap style to round.
Set the collection's zorder value smaller than the stream plot's default 2. If it is higher than the default value, the arrows of the stream plot will be overdrawn by the lines of the new collection.
Add the collection to the figure.
The solution of the example code is presented below.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection # Import LineCollection function.
square_size = 101
x = np.linspace(-1,1,square_size)
y = np.linspace(-1,1,square_size)
u, v = np.meshgrid(-x,y)
fig, axis = plt.subplots(1, figsize = (4,3))
strm = axis.streamplot(x,y,u,v)
# Extract all the segments from streamplot.
strm_seg = strm.lines.get_segments()
# Bundle segments with round capstyle. The `zorder` value should be less than 2 to not
# overlap streamplot's arrows.
lc = LineCollection(strm_seg, zorder=1.9, capstyle='round')
# Add the bundled segment to the subplot.
axis.add_collection(lc)
fig.savefig('streamline.pdf')
Additionally, if you want the streamlines to keep their varying line widths throughout the graph, you have to extract the widths and pass them to LineCollection.
strm_lw = strm.lines.get_linewidths()
lc = LineCollection(strm_seg, zorder=1.9, capstyle='round', linewidths=strm_lw)
Sadly, applying a color map is not as straightforward as the solution above, so using a color map with this approach will not look very pleasing. You can still automate the coloring process, as shown below.
strm_col = strm.lines.get_color()
lc = LineCollection(strm_seg, zorder=1.9, capstyle='round', color=strm_col)
Lastly, I opened a pull request in the matplotlib repository to change the default capstyle option; it can be seen here. You can also apply this commit using the diff below. If you prefer to do so, you do not need any of the tricks explained above.
diff --git a/lib/matplotlib/streamplot.py b/lib/matplotlib/streamplot.py
index 95ce56a512..0229ae107c 100644
--- a/lib/matplotlib/streamplot.py
+++ b/lib/matplotlib/streamplot.py
@@ -222,7 +222,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
arrows.append(p)
lc = mcollections.LineCollection(
- streamlines, transform=transform, **line_kw)
+ streamlines, transform=transform, **line_kw, capstyle='round')
lc.sticky_edges.x[:] = [grid.x_origin, grid.x_origin + grid.width]
lc.sticky_edges.y[:] = [grid.y_origin, grid.y_origin + grid.height]
if use_multicolor_lines:

Python - Interpolation of plots

For my evaluation, I have used gnuplot to plot data from two separate csv files (found in this link: https://drive.google.com/open?id=0B2Iv8dfU4fTUZGV6X1Bvb3c4TWs) with a different number of rows which generates the following graph.
These data seem to have no common timestamp (the first column) in both csv files and yet gnuplot seems to fit the plotting as shown above.
Here is the gnuplot script that I use to generate my plot.
# ###### GNU Plot
set style data lines
set terminal postscript eps enhanced color "Times" 20
set output "output.eps"
set title "Actual vs. Estimated Comparison"
set style line 99 linetype 1 linecolor rgb "#999999" lw 2
#set border 1 back ls 11
set key right top
set key box linestyle 50
set key width -2
set xrange [0:10]
set key spacing 1.2
#set nokey
set grid xtics ytics mytics
#set size 2
#set size ratio 0.4
#show timestamp
set xlabel "Time [Seconds]"
set ylabel "Segments"
set style line 1 lc rgb "#ff0000" lt 1 pi 0 pt 4 lw 4 ps 0
plot "estimated.csv" using ($1):2 with lines title "Estimated", "actual.csv" using ($1):2 with lines title "Actual";
I wanted to interpolate my green line into the grid where my pink line is defined, then compare the two. Here is my initial approach
#!/usr/bin/env python
import sys
import numpy as np
from shapely.geometry import LineString
#-------------------------------------------------------------------------------
def load_data(fname):
    return LineString(np.genfromtxt(fname, delimiter = ','))
#-------------------------------------------------------------------------------
lines = list(map(load_data, sys.argv[1:]))
for g in lines[0].intersection(lines[1]):
    if g.geom_type != 'Point':
        continue
    print('%f,%f' % (g.x, g.y))
Then in Gnuplot, one can invoke it directly:
set terminal pngcairo
set output 'fig.png'
set datafile separator comma
set yr [0:700]
set xr [0:10]
set xtics 0,2,10
set ytics 0,100,700
set grid
set xlabel "Time [seconds]"
set ylabel "Segments"
plot \
'estimated.csv' w l lc rgb 'dark-blue' t 'Estimated', \
'actual.csv' w l lc rgb 'green' t 'Actual', \
'<python filter.py estimated.csv actual.csv' w p lc rgb 'red' ps 0.5 pt 7 t ''
which gives us the following plot
I wrote the filtered points to another file (filtered_points.csv found in this link:https://drive.google.com/open?id=0B2Iv8dfU4fTUSHVOMzYySjVzZWc) from this script. However, the filtered points are less than 10% of the actual dataset (which is the ground truth).
Is there any way where we can interpolate the two lines by ignoring the pink high peaks above the green plot using python? Gnuplot doesn't seem to be the best tool for this. If the pink line doesn't touch the green line (i.e. if it is way below the green line), I want to take the values of the closest green line so that it will be a one-to-one correspondence (or very close) with the actual dataset. I want to return the interpolated values for the green line in the pink line grid so that we can compare both lines since they have the same array size.
Getting both data sets to the same size by interpolation is pretty simple with numpy.interp(). For me, this code works:
import numpy as np
import matplotlib.pyplot as plt
names = ['actual.csv','estimated.csv']
#-------------------------------------------------------------------------------
def load_data(fname):
    return np.genfromtxt(fname, delimiter = ',')
#-------------------------------------------------------------------------------
data = [load_data(name) for name in names]
actual_data = data[0]
estimated_data = data[1]
interpolated_estimation = np.interp(estimated_data[:,0],actual_data[:,0],actual_data[:,1])
plt.figure()
plt.plot(actual_data[:,0],actual_data[:,1], label='actual')
plt.plot(estimated_data[:,0],estimated_data[:,1], label='estimated')
plt.plot(estimated_data[:,0],interpolated_estimation, label='interpolated')
plt.legend()
plt.show(block=True)
After this interpolation, interpolated_estimation has the same size as the x axis of estimated_data (the grid passed as the first argument to np.interp), as the plot suggests. The slicing is a bit confusing, but I tried to use your function and make the plot calls as clear as possible.
To save the result to a file and plot it as suggested, I changed the code to:
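To make the behavior of np.interp concrete, here is a small sketch on made-up arrays (all values and variable names are hypothetical, not taken from the CSV files). It resamples a coarse "actual" series onto the time stamps of an "estimated" series and then compares the two point by point:

```python
import numpy as np

# Hypothetical "actual" signal sampled at five time stamps
actual_t = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
actual_y = np.array([0.0, 10.0, 20.0, 30.0, 40.0])

# Hypothetical "estimated" signal on a different time grid
estimated_t = np.array([0.5, 1.5, 2.5])
estimated_y = np.array([6.0, 14.0, 25.0])

# np.interp(x, xp, fp): evaluate the (xp, fp) curve at the new positions x
actual_on_estimated_grid = np.interp(estimated_t, actual_t, actual_y)
print(actual_on_estimated_grid)  # [ 5. 15. 25.]

# Both arrays now have the same length, so they can be compared directly
residual = estimated_y - actual_on_estimated_grid
mean_abs_error = np.mean(np.abs(residual))
print(residual, mean_abs_error)
```

The result always has the length of the first argument, and xp must be increasing for np.interp to give meaningful values.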
import numpy as np
import matplotlib.pyplot as plt
names = ['actual.csv','estimated.csv']
#-------------------------------------------------------------------------------
def load_data(fname):
    return np.genfromtxt(fname, delimiter = ',')
#-------------------------------------------------------------------------------
data = [load_data(name) for name in names]
actual_data = data[0]
estimated_data = data[1]
interpolated_estimation = np.interp(estimated_data[:,0],actual_data[:,0],actual_data[:,1])
plt.figure()
plt.plot(actual_data[:,0],actual_data[:,1], label='actual')
#plt.plot(estimated_data[:,0],estimated_data[:,1], label='estimated')
plt.plot(estimated_data[:,0],interpolated_estimation, label='interpolated')
np.savetxt('interpolated.csv',
np.vstack((estimated_data[:,0],interpolated_estimation)).T,
delimiter=',', fmt='%10.5f') # saves data to file
plt.legend()
plt.title('Actual vs. Interpolated')
plt.xlim(0,10)
plt.ylim(0,500)
plt.xlabel('Time [Seconds]')
plt.ylabel('Segments')
plt.grid()
plt.show(block=True)
This produces the following output:

Python - finding pattern in a plot

This graph is generated by the following gnuplot script. The estimated.csv file is found in this link: https://drive.google.com/open?id=0B2Iv8dfU4fTUaGRWMm9jWnBUbzg
# ###### GNU Plot
set style data lines
set terminal postscript eps enhanced color "Times" 20
set output "cubic33_cwd_estimated.eps"
set title "Estimated signal"
set style line 99 linetype 1 linecolor rgb "#999999" lw 2
#set border 1 back ls 11
set key right top
set key box linestyle 50
set key width -2
set xrange [0:10]
set key spacing 1.2
#set nokey
set grid xtics ytics mytics
#set size 2
#set size ratio 0.4
#show timestamp
set xlabel "Time [Seconds]"
set ylabel "Segments"
set style line 1 lc rgb "#ff0000" lt 1 pi 0 pt 4 lw 4 ps 0
# Congestion control send window
plot "estimated.csv" using ($1):2 with lines title "Estimated";
I wanted to find the pattern of the estimated signal of the previous plot something close to the following plot. My ground truth (actual signal is shown in the following plot)
Here is my initial approach
#!/usr/bin/env python
import sys
import numpy as np
from shapely.geometry import LineString
#-------------------------------------------------------------------------------
def load_data(fname):
    return LineString(np.genfromtxt(fname, delimiter = ','))
#-------------------------------------------------------------------------------
lines = list(map(load_data, sys.argv[1:]))
for g in lines[0].intersection(lines[1]):
    if g.geom_type != 'Point':
        continue
    print('%f,%f' % (g.x, g.y))
Then invoke this python script in my gnuplot directly as in the following:
set terminal pngcairo
set output 'fig.png'
set datafile separator comma
set yr [0:700]
set xr [0:10]
set xtics 0,2,10
set ytics 0,100,700
set grid
set xlabel "Time [seconds]"
set ylabel "Segments"
plot \
'estimated.csv' w l lc rgb 'dark-blue' t 'Estimated', \
'actual.csv' w l lc rgb 'green' t 'Actual', \
'<python filter.py estimated.csv actual.csv' w p lc rgb 'red' ps 0.5 pt 7 t ''
which gives us the following plot. But this does not seem to give me the right pattern as gnuplot is not the best tool for such tasks.
Is there any way where we can find the pattern of the first graph (estimated.csv) by forming the peaks into a plot using python? If we see from the end, the pattern actually seems to be visible. Any help would be appreciated.
I think a rolling maximum is the right approach here (pandas.rolling_max() in old pandas versions, Series.rolling(...).max() in current ones). We load the data into a DataFrame and calculate the rolling maximum over 8500 values. Afterwards the curves look similar. You may want to experiment with the window size a little to optimize the result.
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
plt.ion()
names = ['actual.csv','estimated.csv']
#-------------------------------------------------------------------------------
def load_data(fname):
    return np.genfromtxt(fname, delimiter = ',')
#-------------------------------------------------------------------------------
data = [load_data(name) for name in names]
actual_data = data[0]
estimated_data = data[1]
df = pd.read_csv('estimated.csv', names=('x','y'))
df['rolling_max'] = df['y'].rolling(8500).max()  # replaces pd.rolling_max(), which was removed from pandas
plt.figure()
plt.plot(actual_data[:,0],actual_data[:,1], label='actual')
plt.plot(estimated_data[:,0],estimated_data[:,1], label='estimated')
plt.plot(df['x'], df['rolling_max'], label = 'rolling')
plt.legend()
plt.title('Actual vs. Interpolated')
plt.xlim(0,10)
plt.ylim(0,500)
plt.xlabel('Time [Seconds]')
plt.ylabel('Segments')
plt.grid()
plt.show(block=True)
To answer the question from the comments:
Since rolling() generates fixed-size windows over your data, the first values of rolling().max() will be NaN. To replace these NaNs, I suggest reversing the whole Series and calculating the windows backwards. Afterwards, we can replace all the NaNs with the values from the backwards calculation. I adjusted the window length for the backwards calculation; otherwise we get erroneous data.
This code works:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
plt.ion()
df = pd.read_csv('estimated.csv', names=('x','y'))
df['rolling_max'] = df['y'].rolling(8500).max()
df['rolling_max_backwards'] = df['y'][::-1].rolling(850).max()
df['rolling_max'] = df['rolling_max'].fillna(df['rolling_max_backwards'])  # explicit assignment instead of chained inplace fill
plt.figure()
plt.plot(df['x'], df['rolling_max'], label = 'rolling')
plt.legend()
plt.title('Actual vs. Interpolated')
plt.xlim(0,10)
plt.ylim(0,700)
plt.xlabel('Time [Seconds]')
plt.ylabel('Segments')
plt.grid()
plt.show(block=True)
And we get the following result:
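The NaN behavior and the backwards fill described above can be checked on a tiny Series. This is a sketch on made-up numbers with a window of 3 instead of 8500. The min_periods=1 variant at the end is an alternative that is not part of the original answer and gives slightly different leading values:

```python
import pandas as pd

s = pd.Series([3, 1, 4, 1, 5, 9, 2, 6])

# Forward rolling maximum: the first window-1 entries are NaN
forward = s.rolling(3).max()

# Rolling maximum over the reversed Series; the original index labels are kept
backward = s[::-1].rolling(3).max()

# fillna aligns on the index, so only the leading NaNs are replaced
filled = forward.fillna(backward)
print(filled.tolist())  # [4.0, 4.0, 4.0, 4.0, 5.0, 9.0, 9.0, 9.0]

# Alternative (assumption, not from the answer): allow shorter windows at the start
partial = s.rolling(3, min_periods=1).max()
print(partial.tolist())  # [3.0, 3.0, 4.0, 4.0, 5.0, 9.0, 9.0, 9.0]
```

The backwards fill looks ahead into later samples, while min_periods=1 only uses the samples seen so far, which is why the first two values differ.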

MatPlotlib Seaborn Multiple Plots formatting

I am translating a set of R visualizations to Python. I have the following target R multiple plot histograms:
Using Matplotlib and Seaborn combination and with the help of a kind StackOverflow member (see the link: Python Seaborn Distplot Y value corresponding to a given X value), I was able to create the following Python plot:
I am satisfied with its appearance, except, I don't know how to put the Header information in the plots. Here is my Python code that creates the Python Charts
""" Program to draw the sampling histogram distributions """
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
def main():
    """ Main routine for the sampling histogram program """
    sns.set_style('whitegrid')
    markers_list = ["s", "o", "*", "^", "+"]
    # create the data dataframe as df_orig
    df_orig = pd.read_csv('lab_samples.csv')
    df_orig = df_orig.loc[df_orig.hra != -9999]
    hra_list_unique = df_orig.hra.unique().tolist()
    # create and subset df_hra_colors to match the actual hra colors in df_orig
    df_hra_colors = pd.read_csv('hra_lookup.csv')
    df_hra_colors['hex'] = np.vectorize(rgb_to_hex)(df_hra_colors['red'], df_hra_colors['green'], df_hra_colors['blue'])
    df_hra_colors.drop(labels=['red', 'green', 'blue'], axis=1, inplace=True)
    df_hra_colors = df_hra_colors.loc[df_hra_colors['hra'].isin(hra_list_unique)]
    # hard coding the current_component to pc1 here, we will extend it by looping
    # through the list of components
    current_component = 'pc1'
    num_tests = 5
    df_columns = df_orig.columns.tolist()
    start_index = 5
    for test in range(num_tests):
        current_tests_list = df_columns[start_index:(start_index + num_tests)]
        # now create the sns distplots for each HRA color and overlay the tests
        i = 1
        for _, row in df_hra_colors.iterrows():
            plt.subplot(3, 3, i)
            select_columns = ['hra', current_component] + current_tests_list
            df_current_color = df_orig.loc[df_orig['hra'] == row['hra'], select_columns]
            y_data = df_current_color.loc[df_current_color[current_component] != -9999, current_component]
            axs = sns.distplot(y_data, color=row['hex'],
                               hist_kws={"ec":"k"},
                               kde_kws={"color": "k", "lw": 0.5})
            data_x, data_y = axs.lines[0].get_data()
            axs.text(0.0, 1.0, row['hra'], horizontalalignment="left", fontsize='x-small',
                     verticalalignment="top", transform=axs.transAxes)
            for current_test_index, current_test in enumerate(current_tests_list):
                # this_x defines the series of current_component(pc1,pc2,rhob) for this test
                # indicated by 1, corresponding R program calls this test_vector
                x_series = df_current_color.loc[df_current_color[current_test] == 1, current_component].tolist()
                for this_x in x_series:
                    this_y = np.interp(this_x, data_x, data_y)
                    axs.plot([this_x], [this_y - current_test_index * 0.05],
                             markers_list[current_test_index], markersize = 3, color='black')
            axs.xaxis.label.set_visible(False)
            axs.xaxis.set_tick_params(labelsize=4)
            axs.yaxis.set_tick_params(labelsize=4)
            i = i + 1
        start_index = start_index + num_tests
    # plt.show()
    pp = PdfPages('plots.pdf')
    pp.savefig()
    pp.close()
def rgb_to_hex(red, green, blue):
    """Return color as #rrggbb for the given color values."""
    return '#%02x%02x%02x' % (red, green, blue)
if __name__ == "__main__":
    main()
The Pandas code works fine and it is doing what it is supposed to. It is my lack of knowledge and experience with 'PdfPages' in Matplotlib that is the bottleneck. How can I show the header information in Python/Matplotlib/Seaborn that I can show in the corresponding R visualization? By the header information, I mean what the R visualization has at the top before the histograms, i.e., 'pc1', MRP, XRD, ...
I can get their values easily from my program, e.g., current_component is 'pc1', etc. But I don't know how to format the plots with the Header. Can someone provide some guidance?
You may be looking for a figure title or super title, fig.suptitle:
fig.suptitle('this is the figure title', fontsize=12)
In your case you can easily get the figure with plt.gcf(), so try
plt.gcf().suptitle("pc1")
The rest of the information in the header would be called a legend.
For the following let's suppose all subplots have the same markers. It would then suffice to create a legend for one of the subplots.
To create legend labels, you can pass the label argument to the plot call, i.e.
axs.plot( ... , label="MRP")
When later calling axs.legend() a legend will automatically be generated with the respective labels. Ways to position the legend are detailed e.g. in this answer.
Here, you may want to place the legend in terms of figure coordinates, i.e.
axs.legend(loc="lower center",bbox_to_anchor=(0.5,0.8),bbox_transform=plt.gcf().transFigure)
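Putting the pieces together, a minimal sketch of a 3x3 grid with a super title and one shared legend placed in figure coordinates could look like this. The random histograms, the marker positions, the two test names taken from the question's header, and the spacing values are assumptions for illustration only:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
test_names = ["MRP", "XRD"]   # header entries mentioned in the question
test_markers = ["s", "o"]     # hypothetical marker per test

fig, axes = plt.subplots(3, 3, figsize=(6, 6))
fig.subplots_adjust(top=0.80)            # leave room for title and legend
title = fig.suptitle("pc1", fontsize=12)

# Stand-in histograms for the per-HRA distributions
for axs in axes.flat:
    axs.hist(rng.normal(size=200), bins=20, color="lightgray", edgecolor="k")

# Labelled markers on one subplot are enough to build the shared legend
first_axs = axes.flat[0]
for name, marker in zip(test_names, test_markers):
    first_axs.plot([0.0], [0.0], marker, color="black", markersize=3, label=name)

legend = first_axs.legend(loc="lower center", bbox_to_anchor=(0.5, 0.85),
                          bbox_transform=fig.transFigure,
                          ncol=len(test_names), fontsize="x-small")
legend_labels = [text.get_text() for text in legend.get_texts()]
print(legend_labels)  # ['MRP', 'XRD']
```

With bbox_transform set to the figure transform, the anchor (0.5, 0.85) is read as a fraction of the whole figure, so the legend sits between the super title and the top row of subplots.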

Python: Legend has wrong colors on Pandas MultiIndex plot

I'm trying to plot data from 2 separate MultiIndex DataFrames, with the same data as levels in each.
Currently, this generates two separate plots, and I'm unable to customise the legend by appending a string that distinguishes each line on the graph. Any help would be appreciated!
Here is the method so far:
def plot_lead_trail_res(df_ante, df_post, symbols=[]):
    if len(symbols) < 1:
        print("Try again with a symbol list. (Time constraints)")
    else:
        df_ante = df_ante.loc[symbols]
        df_post = df_post.loc[symbols]
        ante_leg = [str(x)+'_ex-ante' for x in df_ante.index.levels[0]]
        post_leg = [str(x)+'_ex-post' for x in df_post.index.levels[0]]
        print("ante_leg", ante_leg)
        ax = df_ante.unstack(0).plot(x='SHIFT', y='MUTUAL_INFORMATION', legend=ante_leg)
        ax = df_post.unstack(0).plot(x='SHIFT', y='MUTUAL_INFORMATION', legend=post_leg)
        ax.set_xlabel('Time-shift of sentiment data (days) with financial data')
        ax.set_ylabel('Mutual Information')
Using this function call:
sentisignal.plot_lead_trail_res(data_nasdaq_top_100_preprocessed_mi_res, data_nasdaq_top_100_preprocessed_mi_res_validate, ['AAL', 'AAPL'])
I obtain the following figure:
Current plots
Ideally, both sets of lines would be on the same graph with the same axes!
Update 2 [Concatenation Solution]
I've solved the issues of plotting from multiple frames using concatenation, however the legend does not match the line colors on the graph.
There are no explicit calls to legend, and the label parameter in plot() has not been used.
Code:
df_ante = data_nasdaq_top_100_preprocessed_mi_res
df_post = data_nasdaq_top_100_preprocessed_mi_res_validate
symbols = ['AAL', 'AAPL']
df_ante = df_ante.loc[symbols]
df_post = df_post.loc[symbols]
df_ante.index.set_levels([[str(x)+'_ex-ante' for x in df_ante.index.levels[0]],df_ante.index.levels[1]], inplace=True)
df_post.index.set_levels([[str(x)+'_ex-post' for x in df_post.index.levels[0]],df_post.index.levels[1]], inplace=True)
df_merge = pd.concat([df_ante, df_post])
df_merge['SHIFT'] = abs(df_merge['SHIFT'])
df_merge.unstack(0).plot(x='SHIFT', y='MUTUAL_INFORMATION')
Image:
MultiIndex Plot Image
I think, with
ax = df_ante.unstack(0).plot(x='SHIFT', y='MUTUAL_INFORMATION', legend=ante_leg)
you put the output of plot() into ax, including the lines, and it then gets overwritten by the second function call. Am I right that the lines which were plotted first are missing?
The usual procedure would rather be something like
fig = plt.figure(figsize=(5, 5)) # size in inches
ax = fig.add_subplot(111) # if you want only one axes
Now you have an axes object in ax and can pass it as input to the next plots.
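As a sketch of that idea, the axes can be handed to pandas through the ax keyword so that both frames draw onto the same plot. The two small DataFrames and the label strings below are made-up stand-ins for the question's MultiIndex data, which is simplified here to one symbol per frame:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-ins for one symbol from df_ante and df_post
df_ante = pd.DataFrame({"SHIFT": [0, 1, 2, 3],
                        "MUTUAL_INFORMATION": [0.10, 0.20, 0.15, 0.05]})
df_post = pd.DataFrame({"SHIFT": [0, 1, 2, 3],
                        "MUTUAL_INFORMATION": [0.12, 0.18, 0.11, 0.07]})

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)

# Passing ax=ax makes both calls draw onto the same axes
df_ante.plot(x="SHIFT", y="MUTUAL_INFORMATION", ax=ax, label="AAPL_ex-ante")
df_post.plot(x="SHIFT", y="MUTUAL_INFORMATION", ax=ax, label="AAPL_ex-post")

ax.set_xlabel("Time-shift of sentiment data (days) with financial data")
ax.set_ylabel("Mutual Information")
ax.legend()  # rebuild the legend from the labels of all lines on the axes

line_labels = [line.get_label() for line in ax.get_lines()]
print(line_labels)
```

Since the legend is built from the line labels, each entry is drawn with the color of its own line.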
