Skipping certain values in Python with Matplotlib - python

I am currently working on an intra-day stock chart using the Alpha Vantage API. The data frame contains values from 4:00 to 20:00. In my matplotlib.pyplot chart, however, the x-axis also includes the overnight values from 20:00 to 4:00. I don't want this, as it messes up the aesthetics and also the Volume subplot.
Q: Is there any way to skip x-axis values which don't exist in the actual data frame (the values from 20:00 to 04:00)?
As you can see, the Data Frame clearly jumps from 20:00 to 04:00
However in the Matplotlib chart, the x-Axis contains the values from 20:00 to 4:00, messing with the chart
Code so far. I believe so far everything is right:
import pandas as pd
import matplotlib.pyplot as plt
from alpha_vantage.timeseries import TimeSeries
import time
import datetime as dt
from datetime import timedelta as td
from dateutil.relativedelta import relativedelta
#Accessing and Preparing API
ts = TimeSeries(key=api_key, output_format='pandas')
ticker_input = "TSLA"
interval_input = "15min"
df, meta_data = ts.get_intraday(symbol = ticker_input, interval = interval_input, outputsize = 'full')
slice_date = 16*4*5
df = df[0:slice_date]
df = df.iloc[::-1]
df["100ma"] = df["4. close"].rolling(window = 50, min_periods = 0).mean()
df["Close"] = df["4. close"]
df["Date"] = df.index
#Plotting all as 2 different subplots
ax1 = plt.subplot2grid((7,1), (0,0), rowspan = 5, colspan = 1)
ax1.plot(df["Date"], df['Close'])
ax1.plot(df["Date"], df["100ma"], linewidth = 0.5)
plt.xticks(rotation=45)
ax2 = plt.subplot2grid((6,1), (5,0), rowspan = 2, colspan = 2, sharex = ax1)
ax2.bar(df["Date"], df["5. volume"])
ax2.axes.xaxis.set_visible(False)
plt.tight_layout()
plt.show()
It would be great if anybody could help. I'm still a complete beginner and only started Python 2 weeks ago.

We got the data from the same place, although the data acquisition method is different. After resampling it to 15-minute intervals, I created a graph by excluding the data after 8pm and before 4am. I wrote the code on the understanding that what you want to skip is the overnight pause. The values you want skipped are skipped once they are set to NaN.
import datetime
import pandas as pd
import numpy as np
import pandas_datareader.data as web
import mplfinance as mpf
# import matplotlib.pyplot as plt
with open('./alpha_vantage_api_key.txt') as f:
api_key = f.read()
now_ = datetime.datetime.today()
start = datetime.datetime(2019, 1, 1)
# Midnight at the start of yesterday. Subtracting a timedelta (instead of
# passing ``now_.day - 1`` as the day) stays valid on the first day of a
# month, where ``day - 1`` would be 0 and datetime() would raise ValueError.
end = datetime.datetime(now_.year, now_.month, now_.day) - datetime.timedelta(days=1)
symbol = 'TSLA'
df = web.DataReader(symbol, 'av-intraday', start, end, api_key=api_key)
df.columns = ['Open', 'High', 'Low', 'Close', 'Volume']
df.index = pd.to_datetime(df.index)
df["100ma"] = df["Close"].rolling(window = 50, min_periods = 0).mean()
df["Date"] = df.index
df_15 = df.asfreq('15min')
df_15 = df_15[(df_15.index.hour >= 4)&(df_15.index.hour <= 20) ]
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(8,4.5),dpi=144)
#Plotting all as 2 different subplots
ax1 = plt.subplot2grid((7,1), (0,0), rowspan = 5, colspan = 1)
ax1.plot(df_15["Date"], df_15['Close'])
ax1.plot(df_15["Date"], df_15["100ma"], linewidth = 0.5)
plt.xticks(rotation=20)
ax2 = plt.subplot2grid((6,1), (5,0), rowspan = 2, colspan = 2, sharex = ax1)
ax2.bar(df_15["Date"], df_15["Volume"])
ax2.axes.xaxis.set_visible(False)
# plt.tight_layout()
plt.show()

I fixed it using matplotlib.ticker.Formatter.
I first created a formatter class and then used it like this:
class MyFormatter(Formatter):
    """Tick formatter that labels integer row positions with their dates.

    Meant for data plotted against row positions (0, 1, 2, ...) rather than
    against the dates themselves; each tick value is mapped back to the
    date stored at that position.
    """

    def __init__(self, dates, fmt='%Y-%m-%d %H:%M'):
        # dates: indexable sequence whose items provide strftime()
        # fmt: strftime pattern applied to every label
        self.dates = dates
        self.fmt = fmt

    def __call__(self, x, pos=0):
        """Return the label for tick value x (pos is not used)."""
        row = int(np.round(x))
        if not 0 <= row < len(self.dates):
            # The tick lies outside the data, so it gets no label.
            return ''
        return self.dates[row].strftime(self.fmt)
# Plot against row positions (0..len(df)-1) so that rows are evenly spaced
# and time spans without data take up no room; MyFormatter turns the
# positions back into date labels.
formatter = MyFormatter(df.index)
ax1 = plt.subplot2grid((7,1), (0,0), rowspan = 5, colspan = 1)
ax1.xaxis.set_major_formatter(formatter)
ax1.plot(np.arange(len(df)), df["Close"])
ax1.plot(np.arange(len(df)), df["100ma"], linewidth = 0.5)
# Axes objects have no xticks() method; pyplot's xticks() acts on the
# current axes, which is ax1 at this point.
plt.xticks(rotation=45)
# NOTE: xmin, xmax, ymin and ymax are not defined in this snippet; assign
# the desired view limits before this line.
ax1.axis([xmin,xmax,ymin,ymax])
ax2 = plt.subplot2grid((6,1), (5,0), rowspan = 2, colspan = 2, sharex = ax1)
ax2.bar(np.arange(len(df)), df["5. volume"])
plt.show()
This gave me a smoother graph than the one before and also than the one recommended by r-beginner.
The only issue that I have is that if I zoom in, the x-axis format doesn't really change: it always has the year, month, day, hour, and minute. Obviously I only want hour and minute when I'm zoomed in further. I have yet to figure out how to do that.

Related

mplfinance stacked plots with common, time-aligned shared axis

I am plotting stacked graphs of 5, 15, 30 and 60 minute candles on top of one another.
I would like to please:
Have all charts with the price axis right-aligned (y_on_right=True seems not to be used)
For the times/grid on the 5 minute graph to be every 60 mins on the hour
For all the other graphs to use the same as the above, every 60 mins, all aligned
Optionally if possible, to remove the space on the left and the right (so the first bar is up against the left edge, and last bar up against the right edge)
This is my output so far:
And code is below:
import mplfinance as mpf
import pandas as pd
from polygon import RESTClient
def main():
key = "key"
with RESTClient(key) as client:
start = "2019-02-01"
end = "2019-02-02"
ticker = "TVIX"
resp5 = client.stocks_equities_aggregates(ticker, 5, "minute", start, end, unadjusted=False)
resp15 = client.stocks_equities_aggregates(ticker, 15, "minute", start, end, unadjusted=False)
resp30 = client.stocks_equities_aggregates(ticker, 30, "minute", start, end, unadjusted=False)
resp60 = client.stocks_equities_aggregates(ticker, 60, "minute", start, end, unadjusted=False)
print(f'5 min data is {len(resp5.results)} long')
print(f'15 min data is {len(resp15.results)} long')
print(f'30 min data is {len(resp30.results)} long')
print(f'60 min data is {len(resp60.results)} long')
df5 = pd.DataFrame(resp5.results)
df5.index = pd.DatetimeIndex( pd.to_datetime(df5['t']/1000, unit='s') )
df15 = pd.DataFrame(resp15.results)
df15.index = pd.DatetimeIndex( pd.to_datetime(df15['t']/1000, unit='s') )
df30 = pd.DataFrame(resp30.results)
df30.index = pd.DatetimeIndex( pd.to_datetime(df30['t']/1000, unit='s') )
df60 = pd.DataFrame(resp60.results)
df60.index = pd.DatetimeIndex( pd.to_datetime(df60['t']/1000, unit='s') )
df60.index.name = df30.index.name = df15.index.name = df5.index.name = 'Timestamp'
# mpf expects a dataframe containing Open, High, Low, and Close data with a Pandas TimetimeIndex
df60.columns = df30.columns = df15.columns = df5.columns = ['Volume', 'Volume Weighted', 'Open', 'Close', 'High', 'Low', 'Time', 'Num Items']
fig = mpf.figure(figsize=(32, 32))
ax1 = fig.add_subplot(4, 1, 1)
ax2 = fig.add_subplot(4, 1, 2)
ax3 = fig.add_subplot(4, 1, 3)
ax4 = fig.add_subplot(4, 1, 4)
ap = [
mpf.make_addplot(df15, type='candle', ax=ax2, y_on_right=True),
mpf.make_addplot(df30, type='candle', ax=ax3, y_on_right=True),
mpf.make_addplot(df60, type='candle', ax=ax4, y_on_right=True)
]
s = mpf.make_mpf_style(base_mpf_style='default',y_on_right=True)
mpf.plot(df5, style=s, ax=ax1, addplot=ap, xrotation=0, datetime_format='%H:%M', type='candlestick')
if __name__ == '__main__':
main()
I don't have the corresponding API key, so I used Yahoo Finance's stock prices instead. As for the issue of placing the price on the right side, you can achieve this by changing the style. Also, judging from this information, it seems that y_on_right is only valid for the first graph. To remove the first and last margins, use tight_layout=True, and to align the x-axis to the hour, you need to check how far the mpl time series formatter can go.
import yfinance as yf
import pandas as pd
import mplfinance as mpf
import numpy as np
import datetime
import matplotlib.dates as mdates
start = '2021-12-22'
end = '2021-12-23'
intervals = [5,15,30,60]
for i in intervals:
vars()[f'df{i}'] = yf.download("AAPL", start=start, end=end, period='1d', interval=str(i)+'m')
for df in [df5,df15,df30,df60]:
df.index = pd.to_datetime(df.index)
df.index = df.index.tz_localize(None)
df5 = df5[df5.index.date == datetime.date(2021,12,21)]
df15 = df15[df15.index.date == datetime.date(2021,12,21)]
df30 = df30[df30.index.date == datetime.date(2021,12,21)]
df60 = df60[df60.index.date == datetime.date(2021,12,21)]
fig = mpf.figure(style='yahoo', figsize=(12,9))
ax1 = fig.add_subplot(4,1,1)
ax2 = fig.add_subplot(4,1,2)
ax3 = fig.add_subplot(4,1,3)
ax4 = fig.add_subplot(4,1,4)
mpf.plot(df5, type='candle', ax=ax1, xrotation=0, datetime_format='%H:%M', tight_layout=True)
mpf.plot(df15, type='candle', ax=ax2, xrotation=0, datetime_format='%H:%M', tight_layout=True)
mpf.plot(df30, type='candle', ax=ax3, xrotation=0, datetime_format='%H:%M', tight_layout=True)
mpf.plot(df60, type='candle', ax=ax4, xrotation=0, datetime_format='%H:%M', tight_layout=True)
ax3_ticks = ax3.get_xticks()
print(ax3_ticks)

Seaborn heatmap change date frequency of yticks

My problem is similar to the one encountered on this topic: Change heatmap's yticks for multi-index dataframe
I would like to have yticks every 6 months, with them being the index of my dataframe. But I can't manage to make it work.
The issue is that my dataframe is 13500*290 and the answer given in the link takes a long time and doesn't really work (see image below).
This is an example of my code without the solution from the link, this part works fine for me:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
df = pd.DataFrame(index = pd.date_range(datetime(1984, 6, 10), datetime(2021, 1, 14), freq='1D') )
for i in range(0,290):
df['Pt{0}'.format(i)] = np.random.random(size=len(df))
f, ax = plt.subplots(figsize=(20,20))
sns.heatmap(df, cmap='PuOr', vmin = np.min(np.min(df)), vmax = np.max(np.max(df)), cbar_kws={"label": "Ice Velocity (m/yr)"})
This part does not work for me and produces the figure below, which shouldn't have the stack of ylabels on the yaxis:
f, ax = plt.subplots(figsize=(20,20))
years = df.index.get_level_values(0)
ytickvalues = [year if index in (2, 7, 12) else '' for index, year in enumerate(years)]
sns.heatmap(df, cmap='PuOr', vmin = np.min(np.min(df)), vmax = np.max(np.max(df)), cbar_kws={"label": "Ice Velocity (m/yr)"}, yticklabels = ytickvalues)
Here are a couple ways to adapt that link for your use case (1 label per 6 months):
Either: Show an empty string except on Jan 1 and Jul 1 (i.e., when %m%d evals to 0101 or 0701)
labels = [date if date.strftime('%m%d') in ['0101', '0701'] else ''
for date in df.index.date]
Or: Show an empty string except every ~365/2 days (i.e., when row % 183 == 0)
labels = [date if row % 183 == 0 else ''
for row, date in enumerate(df.index.date)]
Note that you don't have a MultiIndex, so you can just use df.index.date (no need for get_level_values).
Here is the output with a minimized version of your df:
sns.heatmap(df, cmap='PuOr', cbar_kws={'label': 'Ice Velocity (m/yr)'},
vmin=df.values.min(), vmax=df.values.max(),
yticklabels=labels)

Cannot prepare proper labels in Matplotlib

I have very simple code:
from matplotlib import dates
import matplotlib.ticker as ticker
my_plot=df_h.boxplot(by='Day',figsize=(12,5), showfliers=False, rot=90)
I've got:
but I would like to have fewer labels on the X axis. To do this I've added:
my_plot.xaxis.set_major_locator(ticker.MaxNLocator(12))
It generates fewer labels, but the labels have the wrong values (they are just the first few labels from the whole list).
What am I doing wrong?
I have added additional information:
I forgot to show what is inside the DataFrame.
I have three columns:
reg_Date - datetime64 (index)
temperature - float64
Day - date converted from reg_Date to string, it looks like '2017-10' (YYYY-MM)
The box plot groups the data by 'Day', and I would like to show the 'Day' values as labels, but not all of them, for example every third one.
You were almost there. Just set ticker.MultipleLocator.
The pandas.DataFrame.boxplot also returns axes, which is an object of class matplotlib.axes.Axes. So you can use this code snippet to customize your labels:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
center = np.random.randint(50,size=(10, 20))
spread = np.random.rand(10, 20) * 30
flier_high = np.random.rand(10, 20) * 30 + 30
flier_low = np.random.rand(10, 20) * -30
y = np.concatenate((spread, center, flier_high, flier_low))
fig, ax = plt.subplots(figsize=(10, 5))
ax.boxplot(y)
x = ['Label '+str(i) for i in range(20)]
ax.set_xticklabels(x)
ax.set_xlabel('Day')
# Set a tick on each integer multiple of a base within the view interval.
ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
plt.xticks(rotation=90)
I think there is a compatibility issue with Pandas plots and Matplotlib formatters.
With the following code:
df = pd.read_csv('lt_stream-1001-full.csv', header=0, encoding='utf8')
df['reg_date'] = pd.to_datetime(df['reg_date'] , format='%Y-%m-%d %H:%M:%S')
df.set_index('reg_date', inplace=True)
df_h = df.resample(rule='H').mean()
df_h['Day']=df_h.index.strftime('%Y-%m')
print(df_h)
f, ax = plt.subplots()
my_plot = df_h.boxplot(by='Day',figsize=(12,5), showfliers=False, rot=90, ax=ax)
locs, labels = plt.xticks()
# Keep every 3rd tick label and blank out the others, so that the list
# still has one entry per tick.
new_labels = [label if i % 3 == 0 else '' for i, label in enumerate(labels)]
ax.set_xticklabels(new_labels)
plt.show()
You get this chart:
But I notice that this is grouped by month instead of by day. It may not be what you wanted.
Adding the day component to the string 'Day' messes up the chart as there seems to be too many boxes.
df = pd.read_csv('lt_stream-1001-full.csv', header=0, encoding='utf8')
df['reg_date'] = pd.to_datetime(df['reg_date'] , format='%Y-%m-%d %H:%M:%S')
df.set_index('reg_date', inplace=True)
df_h = df.resample(rule='H').mean()
df_h['Day']=df_h.index.strftime('%Y-%m-%d')
print(df_h)
f, ax = plt.subplots()
my_plot = df_h.boxplot(by='Day',figsize=(12,5), showfliers=False, rot=90, ax=ax)
locs, labels = plt.xticks()
# Keep every 15th tick label and blank out the others, so that the list
# still has one entry per tick.
new_labels = [label if i % 15 == 0 else '' for i, label in enumerate(labels)]
ax.set_xticklabels(new_labels)
plt.show()
The for loop keeps a tick label only once every N periods, for whatever N you choose. In the first chart the labels were set every 3 months; in the second one, every 15 days.
If you would like to see fewer grid lines:
df = pd.read_csv('lt_stream-1001-full.csv', header=0, encoding='utf8')
df['reg_date'] = pd.to_datetime(df['reg_date'] , format='%Y-%m-%d %H:%M:%S')
df.set_index('reg_date', inplace=True)
df_h = df.resample(rule='H').mean()
df_h['Day']=df_h.index.strftime('%Y-%m-%d')
print(df_h)
f, ax = plt.subplots()
my_plot = df_h.boxplot(by='Day',figsize=(12,5), showfliers=False, rot=90, ax=ax)
locs, labels = plt.xticks()
# Keep only every 3rd tick, both its position and its label; the remaining
# ticks are dropped entirely (no tick mark, no grid line).
kept_ticks = list(zip(locs, labels))[::3]
new_locs = [loc for loc, _ in kept_ticks]
new_labels = [label for _, label in kept_ticks]
ax.set_xticks(new_locs)
ax.set_xticklabels(new_labels)
ax.grid(axis='y')
plt.show()
I've read about x_compat in Pandas plot in order to apply Matplotlib formatters, but I get an error when trying to apply it. I'll give it another shot later.
Old unsuccessful answer
The tick labels seem to be dates. If they are set as datetime in your dataframe, you can:
months = mdates.MonthLocator(1,4,7,10) #Choose the months you like the most
ax.xaxis.set_major_locator(months)
Otherwise, you can let Matplotlib know they are dates by:
ax.xaxis_date()
Your comment:
I have add additional information:
I've forgoten to show what is inside DataFrame.
I have three columns:
reg_Date - datetime64 (index)
temperature - float64
Day - date converted from reg_Date to string, it looks like '2017-10' *(YYYY-MM) *
Box plot group date by 'Day' and I would like to show values 'Day" as a label but not all values
, for example every third one.
Based on your comment in italic above, I would use reg_Date as the input and the following lines:
days = mdates.DayLocator(interval=3)
daysFmt = mdates.DateFormatter('%Y-%m') #to format display
ax.xaxis.set_major_locator(days)
ax.xaxis.set_major_formatter(daysFmt)
I forgot to mention that you will need to:
import matplotlib.dates as mdates
Does this work?

Set maximum of datapoints per plot

I'm using the following code:
import matplotlib.pyplot as pyplot
import pandas as pandas
from datetime import datetime
dataset = pandas.read_csv("HugLog_17.01.11.csv", sep=",", header=0)
print('filter data for SrcAddr')
dataset_filtered = dataset[dataset['SrcAddr']=='0x1FD3']
print('get Values')
varY = dataset_filtered.Battery_Millivolt.values
varX = dataset_filtered.Timestamp.values
print('Convert the date-strings in date-objects.')
dates_list = [datetime.strptime(date, '%y-%m-%d %H:%M:%S') for date in varX]
fig = pyplot.figure()
ax1 = fig.add_subplot(1,1,1)
ax1.set_xlabel('Time')
ax1.set_ylabel('Millivolt')
ax1.bar(dates_list, varY)
pyplot.locator_params(axis='x',nbins=10)
pyplot.show()
The problem I have is that it's a large data collection with 180k data points.
Pyplot displays all the points on the graph, which makes it slow, and the bars overlap. Is there a way to set a maximum limit on how many data points are displayed in a "view"?
What I mean by that is that when the graph is first rendered there are only 50 data points, and when I zoom in I again get a maximum of only 50 data points.
Resampling can be done with the resample function from pandas.
Note that the resample syntax has changed between version 0.17 and 0.19 of pandas. The example below uses the old style. See e.g. this tutorial for the new style.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# generate some data for every second over a whole day
times = pd.date_range(start='2017-01-11',periods=86400, freq='1S')
df = pd.DataFrame(index = times)
df['data'] = np.sort(np.random.randint(low=1300, high=1600, size=len(df.index)) )[::-1] + \
np.random.rand(len(df.index))*100
# resample the data, taking the mean over 1 hours ("H")
t = "H" # for hours, try "T" for minutes as well
width=1./24 #matplotlib default uses a width of 1 day per bar
# try width=1./(24*60) for minutes
df_resampled = pd.DataFrame()
df_resampled['data'] = df.data.resample(t, how="mean")
fig, ax = plt.subplots()
#ax.bar(df.index, df['data'], width=1./(24*60*60)) # original data, takes too long to plot
ax.bar(df_resampled.index, df_resampled['data'], width=width)
ax.xaxis_date()
plt.show()
Automatic adaption of the resampling when zooming would indeed require some manual work. There is a resampling example on the matplotlib event handling page, which does not work out of the box but could be adapted accordingly.
This is how it would look:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
import matplotlib.dates
class Sampler():
    """Resample a time-indexed frame to a bin size suited to the visible x-range.

    The wider the requested range, the coarser the bins, so the number of
    bars drawn stays small however far the view is zoomed in or out.
    """

    def __init__(self, df):
        # df: DataFrame with a datetime index and a 'data' column.
        self.df = df

    def resample(self, limits):
        """Return ``(index, values, width)`` for the data inside *limits*.

        limits: two-item sequence ``[low, high]``. The items are either
            datetime-like values or floats (Matplotlib date numbers, which
            is what the axes' view limits deliver).
        width: bar width matching the chosen bin size, in days.

        The resampled frame is also stored on ``self.resampled``.
        """
        print(limits)
        dt = limits[1] - limits[0]
        # Float limits give a float difference measured in days. pd.Timedelta
        # subclasses datetime.timedelta, so one isinstance check covers both.
        if not isinstance(dt, datetime.timedelta):
            dt = datetime.timedelta(days=dt)
        print(dt)
        # Lower-case offset aliases are used because the upper-case "H", "T"
        # and "S" spellings are deprecated in recent pandas releases.
        # see http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases
        if dt > datetime.timedelta(hours=5):
            t = "h"
            width = 1. / 24
        elif dt > datetime.timedelta(minutes=60):
            t = "15min"
            width = 15. / (24. * 60)
        elif dt > datetime.timedelta(minutes=5):
            t = "min"
            width = 1. / (24. * 60)
        elif dt > datetime.timedelta(seconds=60):
            t = "15s"
            width = 15. / (24. * 60 * 60)
        else:
            # dt <= datetime.timedelta(seconds=60)
            t = "s"
            width = 1. / (24. * 60 * 60)

        # resample(rule, how=...) was removed from pandas; aggregate explicitly.
        self.resampled = self.df.data.resample(t).mean().to_frame(name='data')
        print(t, len(self.resampled['data']))
        print("indextype", type(self.resampled.index[0]))
        print("limitstype", type(limits[1]))

        if isinstance(limits[1], float):  # also true for np.float64
            dlowlimit = matplotlib.dates.num2date(limits[0])
            duplimit = matplotlib.dates.num2date(limits[1])
            if self.resampled.index.tz is None:
                # num2date returns timezone-aware datetimes; a tz-naive index
                # cannot be compared with them, so drop the tzinfo first.
                dlowlimit = dlowlimit.replace(tzinfo=None)
                duplimit = duplimit.replace(tzinfo=None)
            print(type(duplimit), duplimit)
        else:
            dlowlimit, duplimit = limits[0], limits[1]

        index = self.resampled.index
        in_view = (index >= dlowlimit) & (index <= duplimit)
        self.resampled = self.resampled.loc[in_view]
        return self.resampled.index, self.resampled['data'], width

    def update(self, ax):
        """Redraw *ax* with data resampled for its current x-limits.

        Intended as a handler for the axes' 'xlim_changed' callback.
        """
        print("update")
        start, stop = ax.viewLim.intervalx
        ax.clear()
        x, y, width = self.resample([start, stop])
        ax.bar(x, y, width=width)
        ax.set_xlim([start, stop])
        # Re-register only after set_xlim(), so that our own limit change
        # does not call this handler again.
        # NOTE(review): presumably needed because ax.clear() discards the
        # registered callbacks - confirm against the Matplotlib version used.
        ax.callbacks.connect('xlim_changed', self.update)
        ax.figure.canvas.draw()
times = pd.date_range(start='2017-01-11',periods=86400, freq='1S')
df = pd.DataFrame(index = times)
df['data'] = np.sort(np.random.randint(low=1300, high=1600, size=len(df.index)) )[::-1] + \
np.random.rand(len(df.index))*500
sampler = Sampler(df)
x,y,width = sampler.resample( [df.index[0],df.index[-1] ] )
fig, ax = plt.subplots()
ax.bar(x,y, width=width)
ax.xaxis_date()
# connect to limits changes
ax.callbacks.connect('xlim_changed', sampler.update)
plt.show()
One thing you can do is plot a random subset of the data by using the sample method on your pandas DataFrame. Use the frac argument to determine the fraction of points you want to use. It ranges from 0 to 1.
After you get your dataset_filtered DataFrame, take a sample of it like this
dataset_filtered_sample = dataset_filtered.sample(frac=.001)

Create gantt chart with hlines?

I've tried for several hours to make this work. I tried using 'python-gantt' package, without luck. I also tried plotly (which was beautiful, but I can't host my sensitive data on their site, so that won't work).
My starting point is code from here:
How to plot stacked event duration (Gantt Charts) using Python Pandas?
Three Requirements:
Include the 'Name' on the y axis rather than the numbers.
If someone has multiple events, put all the event periods on one line (this will make pattern identification easier), e.g. Lisa will only have one line on the visual.
Include the 'Event' listed on top of the corresponding line (if possible), e.g. Lisa's first line would say "Hire".
The code will need to be dynamic to accommodate many more people and more possible event types...
I'm open to suggestions to visualize: I want to show the duration for various staffing events throughout the year, as to help identify patterns.
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as dt
df = pd.DataFrame({'Name': ['Joe','Joe','Lisa','Lisa','Lisa','Alice'],
'Event': ['Hire','Term','Hire','Transfer','Term','Term'],
'Start_Date': ["2014-01-01","2014-02-01","2015-01-01","2015-02-01","2015-03-01","2016-01-01"],
'End_Date': ["2014-01-31","2014-03-15","2015-01-31","2015-02-28","2015-05-01","2016-09-01"]
})
df = df[['Name','Event','Start_Date','End_Date']]
df.Start_Date = pd.to_datetime(df.Start_Date).astype(datetime)
df.End_Date = pd.to_datetime(df.End_Date).astype(datetime)
fig = plt.figure()
ax = fig.add_subplot(111)
ax = ax.xaxis_date()
ax = plt.hlines(df.index, dt.date2num(df.Start_Date), dt.date2num(df.End_Date))
I encountered the same problem in the past. You seem to appreciate the aesthetics of Plotly. Here is a little piece of code which uses matplotlib.pyplot.broken_barh instead of matplotlib.pyplot.hlines.
from collections import defaultdict
from datetime import datetime
from datetime import date
import pandas as pd
import matplotlib.dates as mdates
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
df = pd.DataFrame({
'Name': ['Joe', 'Joe', 'Lisa', 'Lisa', 'Lisa', 'Alice'],
'Event': ['Hire', 'Term', 'Hire', 'Transfer', 'Term', 'Term'],
'Start_Date': ['2014-01-01', '2014-02-01', '2015-01-01', '2015-02-01', '2015-03-01', '2016-01-01'],
'End_Date': ['2014-01-31', '2014-03-15', '2015-01-31', '2015-02-28', '2015-05-01', '2016-09-01']
})
df = df[['Name', 'Event', 'Start_Date', 'End_Date']]
df.Start_Date = pd.to_datetime(df.Start_Date).astype(datetime)
df.End_Date = pd.to_datetime(df.End_Date).astype(datetime)
names = df.Name.unique()
nb_names = len(names)
fig = plt.figure()
ax = fig.add_subplot(111)
bar_width = 0.8
default_color = 'blue'
colors_dict = defaultdict(lambda: default_color, Hire='green', Term='red', Transfer='orange')
# Plot the events
for index, name in enumerate(names):
    # All events of one person go on the same horizontal line (y = index).
    events = df.loc[df.Name == name]
    start_dates = mdates.date2num(events.Start_Date)
    end_dates = mdates.date2num(events.End_Date)
    durations = end_dates - start_dates
    # broken_barh needs a real sequence of (start, width) pairs; on Python 3
    # zip() returns a one-shot iterator, so materialise it as a list.
    xranges = list(zip(start_dates, durations))
    # (lower edge, height) of the bar, centred on y = index
    yrange = (index - bar_width / 2.0, bar_width)
    facecolors = [colors_dict[event] for event in events.Event]
    ax.broken_barh(xranges, yrange, facecolors=facecolors, alpha=1.0)
    # you can set alpha to 0.6 to check if there are some overlaps
# Shrink the x-axis
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
# Add the legend
patches = [mpatches.Patch(color=color, label=key) for (key, color) in colors_dict.items()]
patches = patches + [mpatches.Patch(color=default_color, label='Other')]
plt.legend(handles=patches, bbox_to_anchor=(1, 0.5), loc='center left')
# Format the x-ticks
ax.xaxis.set_major_locator(mdates.YearLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
ax.xaxis.set_minor_locator(mdates.MonthLocator())
# Format the y-ticks
ax.set_yticks(range(nb_names))
ax.set_yticklabels(names)
# Set the limits
date_min = date(df.Start_Date.min().year, 1, 1)
date_max = date(df.End_Date.max().year + 1, 1, 1)
ax.set_xlim(date_min, date_max)
# Format the coords message box
ax.format_xdata = mdates.DateFormatter('%Y-%m-%d')
# Set the title
ax.set_title('Gantt Chart')
plt.show()
I hope this will help you.

Categories