I am working on two datasets for churn classification. My problem, as you can see in the two graphs below, is that the y-axes are not on the same scale: Bank stops at 0.8 and telco-europa at 1. I would like to force the y-axis to always display 0, 0.2, 0.4, 0.6, 0.8, 1.
I have used the following code:
and my histogram is based on this tutorial: https://www.kaggle.com/pavanraj159/telecom-customer-churn-prediction and the bank dataset is this one https://www.kaggle.com/shrutimechlearn/churn-modelling
import plotly.graph_objs as go#visualization
import plotly.offline as py#visualization
def output_tracer(metric, color, model_performances):
    """Build one vertical bar trace for a single metric.

    The "Algorithm" column supplies the x values and the column named by
    ``metric`` supplies the y values; ``metric`` is also the trace name.
    ``color`` is used as the bar fill colour.
    """
    bar_outline = dict(width=.7)
    bar_marker = dict(line=bar_outline, color=color)
    return go.Bar(
        x=model_performances["Algorithm"],
        y=model_performances[metric],
        orientation="v",
        name=metric,
        marker=bar_marker,
    )
def output_data(model_performances):
    """Return the list of bar traces, one per plotted metric, in display order."""
    # (metric column, bar colour) pairs; the order fixes the trace order.
    metric_colours = (
        ("1-Precision", "#6699FF"),
        ("1-Recall", "red"),
        ("1-F1-score", "#33CC99"),
        ("Accuracy", "lightgrey"),
        ("AUC", "#FFCC99"),
    )
    return [
        output_tracer(metric, shade, model_performances)
        for metric, shade in metric_colours
    ]
def output_layout(model):
    """Create the figure layout, using ``model`` as the plot title."""
    background = "rgb(243,243,243)"
    grid_white = 'rgb(255, 255, 255)'

    # Both axes share the same grid/tick styling; only x gets an (empty) title.
    x_axis = dict(
        gridcolor=grid_white,
        title="",
        zerolinewidth=1,
        ticklen=5,
        gridwidth=2,
    )
    y_axis = dict(
        gridcolor=grid_white,
        zerolinewidth=1,
        ticklen=5,
        gridwidth=2,
    )

    settings = dict(
        title=model,
        plot_bgcolor=background,
        paper_bgcolor=background,
        xaxis=x_axis,
        yaxis=y_axis,
        margin=dict(l=250),
        height=400,
    )
    return go.Layout(settings)
# Keep only the rows of the scoring report whose Dataset column matches `model`.
model = "Bank"
model_performances = report_df_scoring[report_df_scoring.Dataset == model]
# Assemble the figure from the bar traces and the layout titled with `model`.
fig = go.Figure(data=output_data(model_performances),layout=output_layout(model))
# Display the figure through plotly.offline's iplot.
py.iplot(fig)
And here you can find the dataframe "report_df_scoring" as a dictionary, for the "Bank" dataset only:
{'Dataset': {0: 'Bank',
1: 'Bank',
2: 'Bank',
3: 'Bank',
4: 'Bank',
5: 'Bank',
6: 'Bank'},
'Algorithm': {0: 'LogisticRegressionNoSMOTE',
1: 'Logistic Regression',
2: 'SVM-linear',
3: 'SVM-rbf',
4: 'xgboost',
5: 'GaussianNB',
6: 'RandomForest'},
'W-Precision': {0: 0.8159638339642141,
1: 0.8229500536388679,
2: 0.8243426658647828,
3: 0.7956512785333915,
4: 0.8288351219512194,
5: 0.8302513223140496,
6: 0.8307514249037228},
'W-Recall': {0: 0.8324,
1: 0.7636,
2: 0.7628,
3: 0.8056,
4: 0.836,
5: 0.8176,
6: 0.8408},
'W-F1-score': {0: 0.810103868755423,
1: 0.7811452562742854,
2: 0.7807117770916884,
3: 0.7997335148514852,
4: 0.831622605929424,
5: 0.7598757585104978,
6: 0.8336474053248425},
'0-Precision': {0: 0.8493518104604381,
1: 0.9187236604455148,
2: 0.9206541490006056,
3: 0.8634596695821186,
4: 0.8834146341463415,
5: 0.8152892561983471,
6: 0.8789473684210526},
'0-Recall': {0: 0.958627648839556,
1: 0.7699293642785066,
2: 0.7669021190716448,
3: 0.8965691220988901,
4: 0.9137235116044399,
5: 0.9954591321897074,
6: 0.9268415741675076},
'0-F1-score': {0: 0.9006873666745674,
1: 0.8377710678012626,
2: 0.8367740159647675,
3: 0.8797029702970298,
4: 0.8983134920634921,
5: 0.8964107223989097,
6: 0.9022593320235756},
'1-Precision': {0: 0.6882129277566539,
1: 0.4564958283671037,
2: 0.4558303886925795,
3: 0.5361990950226244,
4: 0.62,
5: 0.8875,
6: 0.6463414634146342},
'1-Recall': {0: 0.34942084942084944,
1: 0.7393822393822393,
2: 0.747104247104247,
3: 0.4575289575289575,
4: 0.5386100386100386,
5: 0.13706563706563707,
6: 0.5115830115830116},
'1-F1-score': {0: 0.4635083226632522,
1: 0.5644804716285925,
2: 0.5662033650329188,
3: 0.49375,
4: 0.5764462809917356,
5: 0.2374581939799331,
6: 0.5711206896551725},
'Accuracy': {0: 0.8324,
1: 0.7636,
2: 0.7628,
3: 0.8056,
4: 0.836,
5: 0.8176,
6: 0.8408},
'AUC': {0: 0.6540242491302027,
1: 0.754655801830373,
2: 0.7570031830879459,
3: 0.6770490398139237,
4: 0.7261667751072393,
5: 0.5662623846276723,
6: 0.7192122928752596},
'SMOTE': {0: 'No',
1: 'Yes',
2: 'Yes',
3: 'Yes',
4: 'Yes',
5: 'Yes',
6: 'Yes'},
'top3var': {0: "['numofproducts_4', 'numofproducts_3', 'geography_germany']",
1: "['numofproducts_4', 'numofproducts_3', 'geography_germany']",
2: "['numofproducts_4', 'numofproducts_3', 'age']",
3: "['empty']",
4: "['numofproducts_2', 'numofproducts_1', 'isactivemember']",
5: "['empty']",
6: "['age', 'numofproducts_2', 'balance']"}}
You can access and edit the range of any axis of your figure using:
fig['layout']['yaxis']['range']
And set the range like:
fig['layout']['yaxis']['range'] = [0, 1]
The same thing goes for your tickvals:
fig['layout']['yaxis']['tickvals'] = [0, 0.2, 0.4, 0.6, 0.8, 1]
You can use:
fig.update_yaxes(tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1])
Your code example does not work for me because "report_df_scoring" is missing.
Related
I am using matplotlib for plotting and convenient visualization of some graphs in xy coordinates.
I need to highlight some regions - and I use rectangles for this.
But I am interested to add some text upon each rectangle - to be able to distinguish those regions. How to do it using patches because I have a lot of objects in a plot?
Here is the code I use to plot rectangles:
# sample data for rectangles visualization
windows_df = pd.DataFrame( {'window_index_num': {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9}, 'left_pulse_num': {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9}, 'right_pulse_num': {0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7, 6: 8, 7: 9, 8: 10, 9: 11}, 'idx_of_left_pulse': {0: 0, 1: 4036, 2: 4080, 3: 4107, 4: 4368, 5: 4491, 6: 4529, 7: 4624, 8: 4626, 9: 4639}, 'idx_of_right_pulse': {0: 4080, 1: 4107, 2: 4368, 3: 4491, 4: 4529, 5: 4624, 6: 4626, 7: 4639, 8: 4679, 9: 4781}, 'left_pulse_pos_in_E': {0: 10.002042118364418, 1: 40.29395464818188, 2: 41.19356816747343, 3: 41.76060061888303, 4: 47.90221207147802, 5: 51.27679395217831, 6: 52.39165780468267, 7: 55.37561818764979, 8: 55.47294132608167, 9: 55.99635666692289}, 'right_pulse_pos_in_E': {0: 41.19356816747343, 1: 41.76060061888303, 2: 47.90221207147802, 3: 51.27679395217831, 4: 52.39165780468267, 5: 55.37561818764979, 6: 55.47294132608167, 7: 55.99635666692289, 8: 57.33777021469516, 9: 60.984834434908144}, 'idx_window_left_border': {0: 0, 1: 3990, 2: 4058, 3: 4093, 4: 4237, 5: 4429, 6: 4510, 7: 4576, 8: 4625, 9: 4632}, 'idx_window_right_border': {0: 4094, 1: 4238, 2: 4430, 3: 4510, 4: 4577, 5: 4625, 6: 4633, 7: 4659, 8: 4730, 9: 4792}, 'left_win_pos_in_E': {0: 10.002042118364418, 1: 39.38459790393702, 2: 40.74003692229216, 3: 41.46513255508269, 4: 44.66179219947279, 5: 49.53272998148, 6: 51.82972979173252, 7: 53.82159300113625, 8: 55.40803086073492, 9: 55.76645477820397}, 'right_win_pos_in_E': {0: 41.48613320837913, 1: 44.6852679849016, 2: 49.56014983071213, 3: 51.82972979173252, 4: 53.85265044341121, 5: 55.40803086073492, 6: 55.79921126600202, 7: 56.66110947958804, 8: 59.119140585251095, 9: 61.39880967219205}, 'window_width': {0: 4095, 1: 249, 2: 373, 3: 418, 4: 341, 5: 197, 6: 124, 7: 84, 8: 106, 9: 161}, 'window_width_in_E': {0: 31.48409109001471, 1: 5.300670080964579, 2: 8.820112908419965, 3: 10.364597236649828, 4: 9.190858243938415, 5: 5.875300879254915, 6: 3.9694814742695, 7: 
2.8395164784517917, 8: 3.7111097245161773, 9: 5.632354893988079}, 'sum_pulses_duration_in_E': {0: 0.5157099691135514, 1: 0.5408987779694527, 2: 0.6869248977656355, 3: 0.7304908951030242, 4: 0.7269657511683718, 5: 0.537271616198268, 6: 0.7609034761658222, 7: 0.6178183490930067, 8: 0.8269277926972265, 9: 0.5591109437337494}, 'sum_pulse_sq': {0: 3.7944375922206044, 1: 3.8756992116858715, 2: 2.9661915477796663, 3: 3.070559830941317, 4: 3.0597037730539385, 5: 10.2020204659669, 6: 45.77535573608872, 7: 45.87630607524008, 8: 39.10335270063814, 9: 3.437205923490125}, 'pulse_to_window_rate': {0: 0.01638001769335214, 1: 0.10204347180781788, 2: 0.07788164447530765, 3: 0.0704794290047244, 4: 0.0790966122938326, 5: 0.09144580460471718, 6: 0.1916883807363909, 7: 0.2175787158769594, 8: 0.22282493757444324, 9: 0.09926770493999569}, 'max_height_in_window': {0: 20.815950580921104, 1: 20.815950580921104, 2: 5.324888970962656, 3: 5.324888970962656, 4: 5.14075603114903, 5: 86.81228155905252, 6: 110.06755904473022, 7: 110.06755904473022, 8: 110.06755904473022, 9: 14.735092268739246}, 'min_height_in_window': {0: -0.011928180619527797, 1: 1.6172637244080776, 2: 1.6172637244080776, 3: 0.8658702248969847, 4: 0.8658702248969847, 5: 0.8658702248969847, 6: 1.8476229914953515, 7: 2.918666252051556, 8: 3.2397786967451707, 9: 2.4893555139463266}, 'windows_sq': {0: 655.3712842149647, 1: 110.33848645112575, 2: 46.96612194869083, 3: 55.19032951390669, 4: 47.24795994896218, 5: 510.0482741740266, 6: 436.911136546121, 7: 312.538647650477, 8: 408.4727887246568, 9: 82.9932690531994}} )
fig_w, axs_w = plt.subplots()
#theoretical cross-section
#axs_w.plot(df_wo_NANS['E'], df_wo_NANS['theo_cs'], marker = "o", markersize = 1, linewidth = 1.0, alpha=0.6, color = 'green', label = 'Theo Cross Section')
axs_w.grid(color = 'grey', linestyle = '--', linewidth = 0.2)
#windows rectangular
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
# Text styling shared by the start-of-window and end-of-window marks.
label_style = dict(
    horizontalalignment='center',
    verticalalignment='center',
    fontsize=5,
)

boxes = []
for index, row in windows_df.iterrows():
    left, right = row['left_win_pos_in_E'], row['right_win_pos_in_E']
    bottom, top = row['min_height_in_window'], row['max_height_in_window']

    # Rectangle anchored at (left, bottom) spanning the window width and height.
    boxes.append(Rectangle((left, bottom), row['window_width_in_E'], top - bottom))

    # mark of the start of the current window: at the left edge, on the top line
    axs_w.text(left, top, str(index), **label_style)
    # mark of the end of the current window: at the right edge, shifted above the top
    axs_w.text(right, top + 0.5*bottom, str(index) + 'e', **label_style)

# Draw all rectangles at once as a single collection.
pc = PatchCollection(boxes, facecolor='y', alpha=0.2, edgecolor='black')
axs_w.add_collection(pc)
I added the text marks using a loop, but is it possible to do it using patches and collections to make the code more efficient?
Here is my code for some NBA project:
# (stat column, trace name) for each panel, in left-to-right order.
panels = [
    ('trb', 'Rebounds'),
    ('ast', 'Assists'),
    ('stl', 'Steals'),
    ('blk', 'Blocks'),
    ('pts', 'Points'),
]

# One row of subplots, one column per statistic.
fig = make_subplots(rows=1, cols=len(panels))
for position, (stat, title) in enumerate(panels, start=1):
    box = go.Box(x=per_game_player['HOF'], y=per_game_player[stat], name=title)
    fig.add_trace(box, row=1, col=position)

# Hover text: the value from customdata in bold, followed by the y value.
hovertemp = '<b>%{customdata}: </b> %{y} <br>'
fig.update_traces(
    showlegend=False,
    customdata=per_game_player['player'],
    hovertemplate=hovertemp,
)
fig.update_layout(title_text='top stats')
fig.show()
Here's the result:
Do you have any idea how I can replace the '0' and '1' on the horizontal axis?
'0' and '1' also appear on hover once I slide through the chart - but not on the outliers.
My table head:
{'player_id': {0: 2218, 1: 3168, 2: 2560, 3: 3228, 4: 4374},
'player': {0: 'A.C. Green',
1: 'A.J. Bramlett',
2: 'A.J. English',
3: 'A.J. Guyton',
4: 'A.J. Hammons'},
'g': {0: 1278.0, 1: 8.0, 2: 151.0, 3: 80.0, 4: 22.0},
'fg': {0: 3.56, 1: 0.5, 2: 4.09, 3: 2.08, 4: 0.77},
'fga': {0: 7.2, 1: 2.62, 2: 9.39, 3: 5.5, 4: 1.91},
'trb': {0: 7.41, 1: 2.75, 2: 2.09, 3: 1.0, 4: 1.64},
'ast': {0: 1.1, 1: 0.0, 2: 2.12, 3: 1.84, 4: 0.18},
'stl': {0: 0.81, 1: 0.12, 2: 0.38, 3: 0.25, 4: 0.05},
'blk': {0: 0.43, 1: 0.0, 2: 0.16, 3: 0.15, 4: 0.59},
'pts': {0: 9.65, 1: 1.0, 2: 9.95, 3: 5.52, 4: 2.18},
'all_star': {0: 1, 1: 0, 2: 0, 3: 0, 4: 0},
'outliers': {0: 0, 1: 0, 2: 0, 3: 0, 4: 0},
'HOF': {0: 0, 1: 0, 2: 0, 3: 0, 4: 0}}
Thx.
I'd simply just alter the data in your dataframe.
import numpy as np
per_game_player['HOF'] = np.where(per_game_player['HOF'] == 1, 'HOF', 'Not HOF')
Output:
I found the way to implement the stackplot if my x-axis is just a list of numbers.
import pandas as pd
import matplotlib.pyplot as plt  # the module is "pyplot"; "pyplt" raises ImportError

# Single-date sample: one row per target class with its count and fraction.
d = {'time_key': {0: '2021-03-01',
                  1: '2021-03-01',
                  2: '2021-03-01',
                  3: '2021-03-01'},
     'target': {0: 2, 1: 1, 2: 0, 3: 3},
     'count': {0: 400, 1: 300, 2: 200, 3: 100},
     'fraction': {0: 0.4, 1: 0.3, 2: 0.2, 3: 0.1}}
df = pd.DataFrame(d)

# One stacked layer per target class (0..3); x is just the numbers 0 and 1.
# The frame is named `df`, so it is read back as `df` (the original used an
# undefined name `s`, which raises NameError).
plt.stackplot(range(2), df[df.target==0].fraction, df[df.target==1].fraction,
              df[df.target==2].fraction, df[df.target==3].fraction)
But I want to generalize the plot to many dates list.
d = {'time_key': {0: '2021-03-01',
1: '2021-03-01',
2: '2021-03-01',
3: '2021-03-01',
4: '2021-04-01',
5: '2021-04-01',
6: '2021-04-01',
7: '2021-04-01',
8: '2021-05-01',
9: '2021-05-01',
10: '2021-05-01',
11: '2021-05-01'},
'target': {0: 2,
1: 1,
2: 0,
3: 3,
4: 2,
5: 1,
6: 0,
7: 3,
8: 2,
9: 1,
10: 0,
11: 3},
'count': {0: 163,
1: 110,
2: 90,
3: 38,
4: 113,
5: 97,
6: 56,
7: 34,
8: 85,
9: 57,
10: 42,
11: 16},
'fraction': {0: 0.18091009988901222,
1: 0.1220865704772475,
2: 0.09988901220865705,
3: 0.042175360710321866,
4: 0.12541620421753608,
5: 0.1076581576026637,
6: 0.06215316315205328,
7: 0.03773584905660377,
8: 0.09433962264150944,
9: 0.06326304106548279,
10: 0.04661487236403995,
11: 0.017758046614872364}}
And I'd like to put the dates on the x-axis in ascending order to see the dynamics of the proportions.
Is there a proper way to implement this?
The approximate desired output plot (I need time_key x-axis though):
Try:
# Pivot to wide form: one row per time_key, one column per target, values are counts.
dfp = df.set_index(['time_key','target'])['count'].unstack()
# Divide each row by its own total so the values become per-date shares, then draw stacked bars.
dfp.div(dfp.sum(axis=1), axis=0).plot.bar(stacked=True)
Output:
Another useful solution is:
# Nested dict: the outer keys (0-3) become the DataFrame columns and the
# date strings become the index.
d = {0: {'2021-03-01': 0.2, '2021-04-01': 0.25, '2021-05-01': 0.3},
1: {'2021-03-01': 0.3, '2021-04-01': 0.25, '2021-05-01': 0.3},
2: {'2021-03-01': 0.4, '2021-04-01': 0.25, '2021-05-01': 0.3},
3: {'2021-03-01': 0.1, '2021-04-01': 0.25, '2021-05-01': 0.1}}
df = pd.DataFrame(d)
fig, ax = plt.subplots(figsize=(9, 6))
# NOTE(review): the style is set after the figure/axes already exist —
# confirm whether it was meant to be applied before plt.subplots().
plt.style.use('classic')
# Area chart of every column, drawn on the axes created above.
df.plot.area(ax=ax)
I am trying to use plotly to compare the coefficients of regression models using error bars for the confidence intervals. I used the following code to plot it, using the variable as a categorical y axis in a scatter plot. The problem is that the points are overlapping, and I'd like to dodge them, as happens in bar charts when you set barmode='group'. If I had a numerical axis I could manually dodge them, but I can't do that with a categorical one.
fig = px.scatter(
df, y='index', x='coef', text='label', color='model',
error_x_minus='lerr', error_x='uerr',
hover_data=['coef', 'pvalue', 'lower', 'upper']
)
fig.update_traces(textposition='top center')
fig.update_yaxes(autorange="reversed")
Using facets I get almost the result I want, but some of the labels go off-plot and are not visible:
fig = px.scatter(
df, y='model', x='coef', text='label', color='model',
facet_row='index',
error_x_minus='lerr', error_x='uerr',
hover_data=['coef', 'pvalue', 'lower', 'upper']
)
fig.update_traces(textposition='top center')
fig.update_yaxes(visible=False)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
Does anybody have an idea or a workaround for either dodging the points in the first case or displaying the labels in the second case?
Thanks in advance.
PS: Here's the random fake dataframe I made to generate the plots:
df = pd.DataFrame({'coef': {0: 1.0018729737113143,
1: 0.9408864645423858,
2: 0.29796556981484884,
3: -0.6844053575764955,
4: -0.13689631932690113,
5: 0.1473096200402363,
6: 0.9564712505670716,
7: 0.956099003887811,
8: 0.33319108930207175,
9: -0.7022778825729681,
10: -0.1773916842612131,
11: 0.09485417304851751},
'index': {0: 'const',
1: 'x1',
2: 'x2',
3: 'x3',
4: 'x4',
5: 'x5',
6: 'const',
7: 'x1',
8: 'x2',
9: 'x3',
10: 'x4',
11: 'x5'},
'label': {0: '1.002***',
1: '0.941***',
2: '0.298***',
3: '-0.684***',
4: '-0.137',
5: '0.147',
6: '0.956***',
7: '0.956***',
8: '0.333***',
9: '-0.702***',
10: '-0.177',
11: '0.095'},
'lerr': {0: 0.19788416996400904,
1: 0.19972987383410545,
2: 0.0606849959013587,
3: 0.1772734289533593,
4: 0.1988122854078155,
5: 0.21870366703236832,
6: 0.2734783191688098,
7: 0.2760291042678362,
8: 0.08386739920069491,
9: 0.2449940255063039,
10: 0.27476098595116555,
11: 0.3022511162310027},
'lower': {0: 0.8039888037473053,
1: 0.7411565907082803,
2: 0.23728057391349014,
3: -0.8616787865298547,
4: -0.33570860473471664,
5: -0.07139404699213203,
6: 0.6829929313982618,
7: 0.6800698996199748,
8: 0.24932369010137684,
9: -0.947271908079272,
10: -0.45215267021237865,
11: -0.2073969431824852},
'model': {0: 'OLS',
1: 'OLS',
2: 'OLS',
3: 'OLS',
4: 'OLS',
5: 'OLS',
6: 'QuantReg',
7: 'QuantReg',
8: 'QuantReg',
9: 'QuantReg',
10: 'QuantReg',
11: 'QuantReg'},
'pvalue': {0: 1.4211692095019375e-16,
1: 4.3583690618389965e-15,
2: 6.278403727223468e-16,
3: 1.596372747840846e-11,
4: 0.17483151363955116,
5: 0.18433051296752084,
6: 4.877385844808361e-10,
7: 6.665860891682504e-10,
8: 5.476882838731488e-12,
9: 1.4240852942202845e-07,
10: 0.20303143985022934,
11: 0.5347222575215599},
'uerr': {0: 0.19788416996400904,
1: 0.19972987383410556,
2: 0.06068499590135873,
3: 0.1772734289533593,
4: 0.19881228540781554,
5: 0.21870366703236832,
6: 0.27347831916880994,
7: 0.2760291042678362,
8: 0.08386739920069491,
9: 0.2449940255063039,
10: 0.27476098595116555,
11: 0.3022511162310027},
'upper': {0: 1.1997571436753234,
1: 1.1406163383764913,
2: 0.35865056571620757,
3: -0.5071319286231362,
4: 0.0619159660809144,
5: 0.3660132870726046,
6: 1.2299495697358815,
7: 1.2321281081556472,
8: 0.41705848850276667,
9: -0.4572838570666642,
10: 0.09736930168995245,
11: 0.3971052892795202}})
You were very close to a working solution with your second attempt. Just make more room for your labels with:
height=600, width=800
And then place the labels for the traces named 'OLS' within the boundaries of each subplot with:
fig.for_each_trace(lambda t: t.update(textposition='bottom center') if t.name == 'OLS' else ())
Plot:
Complete code:
import plotly.express as px
import pandas as pd
df = pd.DataFrame({'coef': {0: 1.0018729737113143,
1: 0.9408864645423858,
2: 0.29796556981484884,
3: -0.6844053575764955,
4: -0.13689631932690113,
5: 0.1473096200402363,
6: 0.9564712505670716,
7: 0.956099003887811,
8: 0.33319108930207175,
9: -0.7022778825729681,
10: -0.1773916842612131,
11: 0.09485417304851751},
'index': {0: 'const',
1: 'x1',
2: 'x2',
3: 'x3',
4: 'x4',
5: 'x5',
6: 'const',
7: 'x1',
8: 'x2',
9: 'x3',
10: 'x4',
11: 'x5'},
'label': {0: '1.002***',
1: '0.941***',
2: '0.298***',
3: '-0.684***',
4: '-0.137',
5: '0.147',
6: '0.956***',
7: '0.956***',
8: '0.333***',
9: '-0.702***',
10: '-0.177',
11: '0.095'},
'lerr': {0: 0.19788416996400904,
1: 0.19972987383410545,
2: 0.0606849959013587,
3: 0.1772734289533593,
4: 0.1988122854078155,
5: 0.21870366703236832,
6: 0.2734783191688098,
7: 0.2760291042678362,
8: 0.08386739920069491,
9: 0.2449940255063039,
10: 0.27476098595116555,
11: 0.3022511162310027},
'lower': {0: 0.8039888037473053,
1: 0.7411565907082803,
2: 0.23728057391349014,
3: -0.8616787865298547,
4: -0.33570860473471664,
5: -0.07139404699213203,
6: 0.6829929313982618,
7: 0.6800698996199748,
8: 0.24932369010137684,
9: -0.947271908079272,
10: -0.45215267021237865,
11: -0.2073969431824852},
'model': {0: 'OLS',
1: 'OLS',
2: 'OLS',
3: 'OLS',
4: 'OLS',
5: 'OLS',
6: 'QuantReg',
7: 'QuantReg',
8: 'QuantReg',
9: 'QuantReg',
10: 'QuantReg',
11: 'QuantReg'},
'pvalue': {0: 1.4211692095019375e-16,
1: 4.3583690618389965e-15,
2: 6.278403727223468e-16,
3: 1.596372747840846e-11,
4: 0.17483151363955116,
5: 0.18433051296752084,
6: 4.877385844808361e-10,
7: 6.665860891682504e-10,
8: 5.476882838731488e-12,
9: 1.4240852942202845e-07,
10: 0.20303143985022934,
11: 0.5347222575215599},
'uerr': {0: 0.19788416996400904,
1: 0.19972987383410556,
2: 0.06068499590135873,
3: 0.1772734289533593,
4: 0.19881228540781554,
5: 0.21870366703236832,
6: 0.27347831916880994,
7: 0.2760291042678362,
8: 0.08386739920069491,
9: 0.2449940255063039,
10: 0.27476098595116555,
11: 0.3022511162310027},
'upper': {0: 1.1997571436753234,
1: 1.1406163383764913,
2: 0.35865056571620757,
3: -0.5071319286231362,
4: 0.0619159660809144,
5: 0.3660132870726046,
6: 1.2299495697358815,
7: 1.2321281081556472,
8: 0.41705848850276667,
9: -0.4572838570666642,
10: 0.09736930168995245,
11: 0.3971052892795202}})
# Scatter of the coefficients, one facet row per value of the 'index' column,
# with separate lower ('lerr') and upper ('uerr') horizontal error columns.
fig = px.scatter(
df, y='model', x='coef', text='label', color='model',
facet_row='index',
error_x_minus='lerr', error_x='uerr',
hover_data=['coef', 'pvalue', 'lower', 'upper'],
# Explicit figure size to make more room for the text labels.
height=600, width=800,
)
# Default label position for every trace: above the point.
fig.update_traces(textposition='top center')
# Hide the y axes.
fig.update_yaxes(visible=False)
# Keep only the part after the last "=" in each annotation's text.
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
# For the traces named 'OLS', place the labels below the points instead.
fig.for_each_trace(lambda t: t.update(textposition='bottom center') if t.name == 'OLS' else ())
fig.show()
I have this dataframe:
df = pd.DataFrame({'ymin': {0: 0.0,
1: 0.0,
2: 0.0,
3: 0.0,
4: 0.511,
5: 0.571,
6: 0.5329999999999999,
7: 0.5389999999999999},
'ymax': {0: 0.511,
1: 0.571,
2: 0.533,
3: 0.539,
4: 1.0,
5: 1.0,
6: 1.0,
7: 1.0},
'xmin': {0: 0.0,
1: 0.14799999999999996,
2: 0.22400000000000003,
3: 0.5239999999999999,
4: 0.0,
5: 0.14799999999999996,
6: 0.22400000000000003,
7: 0.5239999999999999},
'xmax': {0: 0.148,
1: 0.22399999999999998,
2: 0.524,
3: 1.001,
4: 0.148,
5: 0.22399999999999998,
6: 0.524,
7: 1.001},
'variable': {0: 'A', 1: 'A', 2: 'A', 3: 'A', 4: 'B', 5: 'B', 6: 'B', 7: 'B'}})
Where I plot this:
(ggplot(df, aes(ymin = "ymin", ymax = "ymax",
xmin = "xmin", xmax = "xmax", fill = "variable"))
+ geom_rect(colour = "grey", alpha=0.7))
I'm looking to change the order of the legend entries so that it matches the positions in the plot: blue on top and red at the bottom. Also, A should always be red and B should always be blue.
There might be a more standard way to do it, but here is a quick hack to fix your problem:
Change the order of your variable
Assign colors manually (You could also look for exact color codes and replace it with the color names if it matters in your case)
# Make 'variable' a categorical with the explicit order 'B', then 'A'.
df = df.assign(variable = pd.Categorical(df['variable'], ['B', 'A']))
(ggplot(df, aes(ymin = "ymin", ymax = "ymax",
xmin = "xmin", xmax = "xmax", fill = "variable"))+
geom_rect(colour = "grey", alpha=0.7)+
# Manual fill colours; presumably matched to the category order above
# ('B' -> blue, 'A' -> red) — verify against the rendered legend.
scale_fill_manual(values = ["blue", "red"]))
output looks like this:
You could set the order of the levels with df$variable <- factor(df$variable, levels = c("B","A"))