Best pythonic practice to avoid repeating code? - python

I'm not sure whether this question has been asked before — I couldn't find a good-practice answer so far.
I have two Python files with exactly the same package imports and a number of functions (methods). Only a couple of variables differ between them; everything else is the same.
If I need to make a change to one file, I have to go to the other one and apply the same change, which doesn't seem robust.
I really want to keep these as two separate files. I have never had a good grasp of classes. Should I create a class in the first file containing all the methods, loops and variables, and use it from the second file, so that I can override the variables there when needed?
This is what my first file looks like. Apologies — I should have spent some time making it readable, but it's just to give you an idea of the structure. This code plots a number of matplotlib figures. The second file would read different input (CSV) files and then plot different figures.
import csv
import datetime
import pylab
import sys
import time
from inspect import getsourcefile
from os.path import abspath
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import style
from matplotlib.backends.backend_pdf import PdfPages
def get_mul_list(*args):
    """Transpose the given rows into columns.

    Each positional argument is one row (a sequence). The result is a lazy
    ``map`` object yielding one list per column; like ``zip``, it stops at
    the shortest row.
    """
    columns = zip(*args)
    return map(list, columns)
def str2float(s):
    """Convert a CSV cell to a float, mapping blank cells to NaN.

    :param s: cell value as read by csv.reader (normally a string).
    :return: float(s), or np.nan when the cell is empty or only whitespace.
    :raises ValueError: if the cell holds non-numeric text.
    """
    # Whitespace-only cells are missing data too; float('  ') would raise.
    if isinstance(s, str) and not s.strip():
        return np.nan
    return float(s)
def clean_nans(x, y, num_nan_gap=24):
    """Drop NaN samples from (x, y), keeping one NaN per long gap.

    A NaN kept in the data breaks the plotted line. This keeps exactly one
    NaN per run of consecutive NaNs -- the ``num_nan_gap``-th one -- so only
    runs at least ``num_nan_gap`` long break the line; shorter runs are
    dropped entirely and the line is drawn straight across them.

    :param x: x values, paired element-wise with y.
    :param y: y values; NaNs are detected with np.isnan.
    :param num_nan_gap: run length at which a line break is kept.
    :return: (x_clean, y_clean) as two lists.
    """
    x_clean, y_clean = [], []
    run = 0  # length of the current run of consecutive NaNs
    for x_val, y_val in zip(x, y):
        if np.isnan(y_val):
            run += 1
            if run == num_nan_gap:
                # Keep the num_nan_gap-th NaN of the run so the line breaks here.
                x_clean.append(x_val)
                y_clean.append(y_val)
            continue
        run = 0
        x_clean.append(x_val)
        y_clean.append(y_val)
    return x_clean, y_clean
def csv_store_in_dict(filepath, mode):
    """Load a CSV file into a dict mapping each header to its column.

    The first row is the header; the remaining rows are transposed into
    per-column lists with get_mul_list, giving
    {header: [row1_value, row2_value, ...]} with all values as strings.

    :param filepath: path of the CSV file.
    :param mode: mode passed to open() (callers in this file use 'r').
    :return: dict of column name -> list of cell strings ({} for an empty file).
    """
    # The with-block closes the file (the original leaked the handle);
    # newline='' is the mode the csv module documents for reading.
    with open(filepath, mode, newline='') as csv_data:
        data = list(csv.reader(csv_data))
    if not data:  # empty file: no header row to map from
        return {}
    return dict(zip(data[0], get_mul_list(*data[1:])))
# Palette for the series in one chart; each XX in a group takes the next colour.
colors_list = ['deeppink', 'aquamarine', 'yellowgreen', 'orangered', 'darkviolet',
               'darkolivegreen', 'lightskyblue', 'teal', 'seagreen', 'olivedrab', 'red', 'indigo', 'goldenrod', 'firebrick',
               'slategray', 'cornflowerblue', 'darksalmon', 'blue', 'khaki', 'wheat', 'dodgerblue', 'moccasin', 'sienna',
               'darkcyan']
current_py_filepath = abspath(getsourcefile(lambda: 0))  # python source file path for figure footnote
kkk_dict = csv_store_in_dict('CSV/qry_WatLvl_kkk_xlsTS_1c_v4.csv', 'r')  # all WL kkk data stored in a dictionary
yyyddd_dict = csv_store_in_dict('CSV/qry_WatLvl_TimeSeries2_v2.csv', 'r')  # all WL yyyddd time-series data stored in a dictionary
XX_info_dict = csv_store_in_dict('CSV/XX_info.csv', 'r')  # XX_name, XX_group_name, BB_Main, CC, dddd
# XX groups to plot: one chart (and one PDF page) per group, numbered
# Chart E1, E2, ... in this order.
XX_groups_chartE = ('XXH_05',
                    'XXH_16',
                    'XXH_11',
                    'DXX_27',
                    'DXX_22',
                    'DXX_21',
                    'DXX_09',
                    'DXX_07',
                    'DXX_01',
                    'DXX_05',)
# y-axis limits per chart, indexed by chart number - 1; needs one entry per
# group in XX_groups_chartE.
y_range = [[5, 10],   # chart 1
           [7, 12],   # chart 2
           [3, 8],    # chart 3
           [7, 12],   # chart 4
           [5, 10],   # chart 5
           [20, 50],  # chart 6
           [12, 22],  # chart 7
           [5, 25],   # chart 8
           [10, 15],  # chart 9
           [22, 42]]  # chart 10
# Date conversion: parse the "DateTime" column of both datasets into datetimes.
x_kkk = kkk_dict["DateTime"]
x_yyyddd = yyyddd_dict["DateTime"]
x_kkk_date = [datetime.datetime.strptime(stamp, "%d/%m/%Y %H:%M:%S") for stamp in x_kkk]
x_yyy_date = [datetime.datetime.strptime(stamp, "%d/%m/%Y %H:%M:%S") for stamp in x_yyyddd]
# plotting XX groups: one figure (and one PDF page) per group in
# XX_groups_chartE, numbered Chart E1, E2, ... in plotting order.
for chart_num, XX_gr_nam in enumerate(XX_groups_chartE, start=1):
    # Names of the XXs whose group column matches this group, in CSV order.
    XXs_curr_grp = [name for name, group in zip(XX_info_dict['XX_name'], XX_info_dict['XX_group_name'])
                    if group == XX_gr_nam]
    fig = plt.figure(figsize=(14, 11))
    for col_ind, XX_v in enumerate(XXs_curr_grp):
        y_kkk_num = [str2float(i) for i in kkk_dict[XX_v]]
        y_yyyddd_num = [str2float(i) for i in yyyddd_dict[XX_v]]
        ind_XX = XX_info_dict["XX_name"].index(XX_v)
        BB_Main = XX_info_dict["BB_Main"][ind_XX]
        CC = XX_info_dict["CC"][ind_XX]
        # Wrap around the palette instead of raising IndexError when a group
        # has more XXs than colors_list has entries.
        color = colors_list[col_ind % len(colors_list)]

        def label_pl(d_type):
            """Return the legend label for the current XX and data type."""
            return "%s (%s, %s / %s)" % (XX_v, BB_Main, CC, d_type)

        # NaN runs shorter than 200 samples are bridged; longer runs break the line.
        x_kkk_date_nan_cln, y_kkk_num_nan_cln = clean_nans(x_kkk_date, y_kkk_num, 200)
        # plt.plot handles datetime x values directly; plot_date is discouraged/deprecated.
        plt.plot(x_kkk_date_nan_cln, y_kkk_num_nan_cln, '-', markeredgewidth=0,
                 label=label_pl("kkk data"), color=color)
        plt.scatter(x_yyy_date, y_yyyddd_num, label=label_pl("yyy ddds"), marker='x', linewidths=2,
                    s=50, color=color)
    XX_grp_title = XX_gr_nam.replace("_", "-")
    plt.title("kkk Levels \n" + XX_grp_title + " Group", fontsize=20)
    plt.ylabel('wwL (mmm)')
    plt.legend(loc=9, ncol=2, prop={'size': 8})
    plt.figtext(0.05, 0.05, current_py_filepath, horizontalalignment='left', fontsize=8)  # footnote for file path
    plt.figtext(0.95, 0.05, 'Chart E%s' % (chart_num,), horizontalalignment='right', fontsize=12)  # chart number
    plt.figtext(0.95, 0.95, datetime.date.today(), horizontalalignment='right', fontsize=8)  # date stamp
    # FIGURE FORMATTING
    # A new formatter per figure: matplotlib formatters must not be shared between axes.
    myFmt = mdates.DateFormatter('%d/%m/%Y')
    ax = plt.gca()
    ax.xaxis.set_major_formatter(myFmt)
    fig.autofmt_xdate()
    ax.set_ylim(y_range[chart_num - 1])  # y_range is indexed by chart number - 1
    plt.grid()
    fig.tight_layout()
    plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.15)
    fig_pdf_file = "PDF/OXX_grp_page %s.pdf" % (chart_num,)
    fig.savefig(fig_pdf_file)
plt.show()

No, you do not need to define a class. You need to remove the shared functions from one file, and have it import the other. An import statement can import installed python packages, but also python files. Use it like this:
# myfile.py
def f(x):
    """Return *x* multiplied by two."""
    doubled = x * 2
    return doubled
# main.py
# Import the sibling file myfile.py as a module, then call its function
# through the module name.
import myfile
myfile.f(2)
Note that for this example, both files must be in the same directory.
However, if you would like to store myfile.py in a different directory, i.e. in this hierarchy:
my_project
----main.py
----my_modules
--------myfile.py
--------__init__.py
Simply create an empty __init__.py file in the 'my_modules' directory, and change your import statement to import my_modules.myfile (then call the function as my_modules.myfile.f(2)).

Related

plt.legend() blocks matplotlib to display plot

Without plt.legend() called, the plot gets displayed. With it, I just get:
<matplotlib.legend.Legend at 0x1189a404c50>
I'm working in JupyterLab, Python 3, Anaconda
I do not understand what is preventing the legend from displaying. Without the final for loop that iterates over xarray (i.e. if I load just one spectrum to plot), the legend works fine. Any ideas? Thanks!
Here is the code:
import colour
from colour.plotting import *
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib
import scipy.integrate as integrate
from scipy import interpolate
### Read Spectrum file ###
# Interactive loop: for each spectrum CSV the user names, compute its CIE xy
# chromaticity coordinates with the colour package and collect them in
# xarray / yarray; the collected points are plotted after the loop.
# NOTE(review): indentation was lost in this copy; the loop/branch nesting
# below is inferred from the statements, not from the original layout.
xarray = []
yarray = []
while True:
# Separator in CSV
separator = input("Separator: [, or tab or semicolon]")
if separator in {',','comma'}:
separator = ','
elif separator in {'tab','TAB'}:
separator = '\t'
elif separator in {';','semicolon'}:
separator = ';'
else:
# NOTE(review): an unrecognised answer only prints this message; the raw
# input is still passed to read_csv below as the separator.
print("Separator must be one of the listed")
# Header in CSV
headerskip = input("Header [y/n]")
if headerskip in {'y','yes','Y','Yes','YES'}:
headerskip = 1
elif headerskip in {'n','no','N','No','NO'}:
headerskip = 0
else:
# NOTE(review): likewise, headerskip stays the raw string and is still
# passed to read_csv as skiprows.
print("Header?")
# Choose CSV file
filename = input("Filename")
try:
spectrum = pd.read_csv(filename, sep = separator, header = None, skiprows = headerskip, index_col=0)
except FileNotFoundError:
# NOTE(review): execution continues after this message, so the next
# statement reads a `spectrum` that is undefined (first pass) or left over
# from the previous file.
print("Wrong file or file path")
# Convert to our dictionary
# Column 1 holds the values, keyed by the first column (index_col=0); the
# keys are cast to int (presumably wavelengths - confirm).
spec = spectrum[1].to_dict() #functional dictionary
sample_sd_data = {int(k):v for k,v in spec.items()} # changes index to integer
# Do color calculations
# Spectrum -> XYZ tristimulus values (CIE 1931 2 degree observer, D65
# illuminant) -> xy chromaticity.
sd = colour.SpectralDistribution(sample_sd_data)
cmfs = colour.STANDARD_OBSERVERS_CMFS['CIE 1931 2 Degree Standard Observer']
illuminant = colour.ILLUMINANTS_SDS['D65']
XYZ = colour.sd_to_XYZ(sd, cmfs, illuminant) #tristimulus values.
print(XYZ)
xy = colour.XYZ_to_xy(XYZ) # chromaticity coordinates
x, y = xy
print(xy)
xarray.append(x)
yarray.append(y)
# Query to add another file
# Any answer other than a listed "yes" also ends the loop.
addfile = input("Add another spectrum [y,n]")
if addfile in {'y','yes','Y','Yes','YES'}:
print("adding another file")
elif addfile in {'n','no','N','No','NO'}:
print("done with adding files")
break
else:
print("Add another file? Breaking loop")
break
# Plotting the *CIE 1931 Chromaticity Diagram*.
# The argument *standalone=False* is passed so that the plot doesn't get
# displayed and can be used as a basis for other plots.
#plot_single_sd(sd)
print(xarray)
print(yarray)
plot_chromaticity_diagram_CIE1931(standalone=False)
# Plotting the *CIE xy* chromaticity coordinates.
# One grey marker per loaded spectrum, labelled by its load order (0, 1, ...).
for i in range(len(xarray)):
x = xarray[i]
y = yarray[i]
plt.plot(x, y, '-p', color='gray',
markersize=15, linewidth=4,
markerfacecolor='None',
markeredgecolor='gray',
markeredgewidth=2,
label=str(i))
# NOTE(review): indentation was lost; this 'test' point may have been inside
# the loop above or after it.
plt.plot(.3,.3, '-o', label='test')
# Customizing plot
plt.grid(True, linestyle=':')
plt.axis('equal') # disable this to go x to zero, however it will hide 500nm label
plt.xlim(0,.8)
plt.ylim(0,.9)
# The legend call the question is about; it is commented out in this copy.
#plt.legend(framealpha=1, frameon=True, handlelength=0) # set handlelength to 0 to destroy line over the symbol
You need to call the magic function %matplotlib notebook or %matplotlib inline after your imports in jupyter.

ternary plots as subplot

I want to draw multiple ternary graphs and thought to do this using matplotlib's subplot.
I'm just getting empty 'regular' plots though, not the ternary graphs I want in there. I found the usage of
# Build a normal figure/Axes pair, then wrap that Axes in a ternary plot.
figure, ax = plt.subplots()
tax = ternary.TernaryAxesSubplot(ax=ax)
so this seems to be possible, but can't really find out how to get this working. Any ideas?
Code I'm using:
I'm using a for loop as the data has columns named tria1-a, tria2-a, etc for the different triads
import ternary
import matplotlib.pyplot as plt
import pandas as pd
# Load the triad data, set up one ternary plot, then scatter each of the
# five triads (columns tria<i>-a / -b / -c) onto it.
#configure file to import.
filename = 'somecsv.csv'
filelocation = 'location'
# NOTE(review): plain '+' adds no path separator, so filelocation must end
# with '/' (os.path.join would avoid this).
dfTriad = pd.read_csv(filelocation+filename)
# plot the data
scale = 33
figure, ax = plt.subplots()
# tax wraps this single Axes for the rest of the script.
tax = ternary.TernaryAxesSubplot(ax=ax, scale=scale)
figure.set_size_inches(10, 10)
tax.set_title("Scatter Plot", fontsize=20)
tax.boundary(linewidth=2.0)
tax.gridlines(multiple=1, color="blue")
tax.legend()
tax.ticks(axis='lbr', linewidth=1, multiple=5)
tax.clear_matplotlib_ticks()
#extract the xyz columns for the triads from the full dataset
for i in range(1,6) :
key_x = 'tria'+ str(i) + '-a'
key_y = 'tria' + str(i) + '-b'
key_z = 'tria' + str(i) + '-c'
#construct dataframe from the extracted xyz columns
# NOTE(review): dfTriad[key_x] is a Series named key_x; confirm that
# columns=['X'] keeps its values rather than selecting a missing column.
dfTriad_data = pd.DataFrame(dfTriad[key_x], columns=['X'])
dfTriad_data['Y'] = dfTriad[key_y]
dfTriad_data['Z'] = dfTriad[key_z]
#create list of tuples from the constructed dataframe
triad_data = [tuple(x) for x in dfTriad_data.to_records(index=False)]
# NOTE(review): plt.subplot creates a new, ordinary Axes, but `tax` still
# wraps `ax` from plt.subplots() above, so every scatter lands on that first
# Axes and the new subplots stay empty. Creating the ternary figure on
# plt.gca() after each plt.subplot call avoids this.
plt.subplot(2, 3, i)
tax.scatter(triad_data, marker='D', color='green', label="")
tax.show()
I had the same problem and could solve it by first "going" into the subplot, then creating the ternary figure in there by giving plt.gca() as keyword argument ax:
# Imports the snippet relies on (np was used without being imported).
import matplotlib.pyplot as plt
import numpy as np
import ternary

# Select the subplot first, hide its ordinary x/y axes, then build the
# ternary figure on that Axes via ax=plt.gca().
plt.subplot(2, 2, 4, frameon=False)
scale = 10
plt.gca().get_xaxis().set_visible(False)
plt.gca().get_yaxis().set_visible(False)
figure, tax = ternary.figure(ax=plt.gca(), scale=scale)
# now you can use ternary normally:
tax.line(scale * np.array((0.5, 0.5, 0.0)), scale * np.array((0.0, 0.5, 0.5)))
tax.boundary(linewidth=1.0)
# ...

Number of legend entries equals the size of data set

I'm plotting many data sets in a for loop. The number and size of the sets cause no problems when plotting. When I try to add a legend, things get interesting: I get a legend, but the first label shows up hundreds of times! One data set has 887 points, and I get 887 legend entries. Here is the plot I get:
You can access the .py and .xlsx files here:
https://drive.google.com/drive/folders/1QCVw2yqIHexNCvgz4QQfJQDGYql1hGW8?usp=sharing
Here is the code that is generating the plot.
# Temperature Data plotting
# NOTE(review): the '=' rule on the next line is the tail of this header
# comment, split onto its own line when the post was copied; it is not valid
# Python by itself.
=================================================
#initialize figure
# Plot one temperature series per index k (0..nt-1), labelled 'T<k>'.
# nt, t, t_adjust, Temps and the FigTitle inputs are defined elsewhere (not shown).
plt.figure(figsize=(11,8))
Color = 'C'
Marks = '*','o','+','x','s','d','.'
nm = len(Marks)
q = 0 # Marks counter
c = 0 # color counter
for k in range(0,nt):
style = 'C' + str(c) + Marks[q]
test = 'T' + str(k)
# NOTE(review): wrapping the arrays in [...] makes x and y 2-D with a single
# row (assuming t and Temps[:,k] are 1-D of length N). plt.plot draws one
# data set per column of 2-D input, so this creates N one-point lines that
# all carry label=test, i.e. N legend entries per k. Passing
# t+t_adjust[k] and Temps[:,k] without the brackets gives one line per k.
plt.plot([t+t_adjust[k]],[Temps[:,k]],style,label=test)
#, label = 'test'
# Colour index cycles 0-5 then 9 (6 jumps to 9, 10 wraps to 0).
# NOTE(review): nesting lost; q += 1 presumably belongs to the c==10 branch.
c += 1
if(c==6):
c = 9
if(c==10):
c = 0
q += 1
# The last 9 series (k > nt-10) all use the final marker in Marks.
if(k > nt-10):
q = nm - 1
# Formatting Figure
#names = '1','2','3','4','5'
#name1 = '1'
#pylab.legend([name1])
#from collections import OrderedDict
#import matplotlib.pyplot as plt
#handles, labels = plt.gca().get_legend_handles_labels()
#by_label = OrderedDict(zip(labels, handles))
#plt.legend(by_label.values(), by_label.keys())
plt.legend(loc = 'upper right')
# NOTE(review): the xlim, labels and title below are set after plt.show(), so
# they are probably missing from the shown figure; they likely belong before this call.
plt.show()
# x axis limits, in seconds
plt.xlim(0,60)
plt.xlabel('t (s)')
plt.ylabel('T (deg C)')
FigTitle = (oper_name + '; ' + str(pres_val) + pres_unit + '; d=' +
str(diam_val) + diam_unit + '; H=' + str(dist_val) + dist_unit)
plt.title(FigTitle)
# End Temperature Data Plotting
# NOTE(review): the '=' rule below is likewise the tail of the comment above.
==============================================
I have 14 data sets with 887 points each, yet there are clearly more than 14 legend entries. I'm not sure why the legend is somehow tied to the length of the data. I found the code below for getting the handles and labels, but I need each data set to get its own label and style, instead of the first one repeated for the length of the data.
# Legend de-duplication recipe (commented out): collect the current handles
# and labels, keep one handle per distinct label (the last one seen), and
# build the legend from those.
#from collections import OrderedDict
#import matplotlib.pyplot as plt
#handles, labels = plt.gca().get_legend_handles_labels()
#by_label = OrderedDict(zip(labels, handles))
#plt.legend(by_label.values(), by_label.keys())
Hard to say without having a look at the data, but you can always control what goes into the legend manually like so:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0., 2*np.pi, 101, endpoint=True)
lns = []
for i in range(1, 10):
    for j in range(10):
        # All 10 curves of series i share one colour and one label.
        ln = plt.plot(x, j*np.sin(x/i), color="C{:d}".format(i), label="series i={:d}".format(i))
        if j == 0:
            # Keep one handle per series, so the legend shows 9 entries instead of 90.
            lns += ln  # note plt.plot returns a list of entities
plt.legend(lns, [l.get_label() for l in lns], loc="best")
plt.show()

How to plot data from multiple files in a loop

I have more than 1000 .csv files (data_1.csv … data_1000.csv), each containing X and Y values:
x1 y1 x2 y2
5.0 60 5.5 500
6.0 70 6.5 600
7.0 80 7.5 700
8.0 90 8.5 800
9.0 100 9.5 900
I have made a subplot program in python which can give two plots (plot1 - X1vsY1, Plot2 - X2vsY2) at a time using one file.
I need help in looping all the files, (open a file, read it, plot it, pick another file, open it, read it, plot it, ... until all the files in a folder get plotted)
I have the following code:
import pandas as pd
import matplotlib.pyplot as plt
df1 = pd.read_csv("data_csv", header=1, sep=',')
fig = plt.figure()
# Top subplot: positional columns 1 vs 2; bottom subplot: columns 3 vs 4.
for panel, (x_col, y_col) in enumerate(((1, 2), (3, 4)), start=1):
    plt.subplot(2, 1, panel)
    plt.plot(df1.iloc[:, [x_col]], df1.iloc[:, [y_col]])
plt.show()
How can this be accomplished more efficiently?
You can generate a list of filenames using glob and then plot them in a for loop.
import glob
import pandas as pd
import matplotlib.pyplot as plt
# File pattern, e.g. '*.csv' or 'data_*.csv'. (In the original, a comment
# inside the call swallowed the closing parenthesis.) sorted() gives a
# stable order, because glob returns files in arbitrary order.
files = sorted(glob.glob('*.csv'))
for file in files:
    df1 = pd.read_csv(file, header=1, sep=',')
    fig = plt.figure()
    plt.subplot(2, 1, 1)
    plt.plot(df1.iloc[:, [1]], df1.iloc[:, [2]])
    plt.subplot(2, 1, 2)
    plt.plot(df1.iloc[:, [3]], df1.iloc[:, [4]])
    plt.show()  # this will stop the loop until you close the plot
I used NetCDF (.nc) files, in case anyone is interested in using NetCDF data; you could use .txt files instead — the idea is the same. I used this for a loop of contour plots.
import os

import matplotlib.pyplot as plt
import xarray as xr

path_to_folder = '#type the path to the files'
count = 0
fig = plt.figure(figsize=(10, 5))
files = []
for name in sorted(os.listdir(path_to_folder)):
    if name.endswith('.nc'):
        count += 1
        # listdir returns bare names; resolve them against the folder instead
        # of the current working directory.
        path = os.path.join(path_to_folder, name)
        files.append(path)  # record the path rather than leaking an open file handle
        with xr.open_dataset(path) as data:
            prec = data['tp']
            plt.subplot(1, 2, count)  # change 1 and 2 to the shape you want
            # this is a contour plot, but you can replace it with any plot command
            prec.groupby('time.month').mean(dim=('time', 'longitude')).T.plot.contourf(cmap='Purples')
print(files)
plt.savefig('try.png', dpi=500, orientation='landscape', format='png')
Here is the basic setup I am using at work. This code plots the data from each file, going through the files one at a time. It works for any number of files as long as the column names stay the same; just point it at the right folder.
import os
import csv
def graphWriterIRIandRut():
    """Overlay the data of every CSV file in the current directory on two subplots.

    Each file adds one line to the top subplot ('Col 2 Name' against
    'Col 1 Name') and one to the bottom subplot ('Col 3 Name' against
    'Col 1 Name'). Replace the placeholder column names with your headers.
    All files are shown together in one window at the end.
    """
    import matplotlib.pyplot as plt  # not imported at the top of this snippet

    fig, (ax_top, ax_bottom) = plt.subplots(2, 1)
    for ax in (ax_top, ax_bottom):
        ax.grid(True)
        ax.tick_params(axis='both', which='major', labelsize=8)

    for file_name in sorted(os.listdir(os.getcwd())):
        if not file_name.lower().endswith('.csv'):
            continue  # skip scripts, images, sub-directories, ...
        # Fresh lists per file; appending to shared lists re-plotted every
        # earlier file's data again with each new file.
        col1, col2, col3 = [], [], []
        # The 'rU' mode was removed in Python 3.11; newline='' is what csv expects.
        with open(file_name, newline='') as csv_file:
            for row in csv.DictReader(csv_file):
                # DictReader yields strings; convert so matplotlib plots numbers,
                # not categories.
                col1.append(float(row['Col 1 Name']))
                col2.append(float(row['Col 2 Name']))
                col3.append(float(row['Col 3 Name']))
        # matplotlib's colour cycle gives each file its own colour.
        ax_top.plot(col1, col2)
        ax_bottom.plot(col1, col3)

    plt.show()
    plt.close('all')
# plotting the data from every CSV file in the current directory
import os
import csv
import matplotlib.pyplot as plt


def graphWriterIRIandRut():
    """Overlay x1/y1 and x2/y2 from every CSV file in the current directory.

    Each file adds a y1-vs-x1 line to the top subplot and a y2-vs-x2 line to
    the bottom subplot. All files are shown together at the end (nothing is
    saved to disk).
    """
    fig, (ax_top, ax_bottom) = plt.subplots(2, 1)
    for ax in (ax_top, ax_bottom):
        ax.grid(True)
        ax.tick_params(axis='both', which='major', labelsize=8)

    for file_name in sorted(os.listdir(os.getcwd())):
        if not file_name.lower().endswith('.csv'):
            continue  # skip scripts, images, sub-directories, ...
        # Fresh lists per file so earlier files are not re-plotted.
        x1, y1, x2, y2 = [], [], [], []
        # The 'rU' mode was removed in Python 3.11; newline='' is what csv expects.
        with open(file_name, newline='') as csv_file:
            for row in csv.DictReader(csv_file):
                # Convert from str so the axes are numeric, not categorical.
                x1.append(float(row['x1']))
                y1.append(float(row['y1']))
                x2.append(float(row['x2']))
                y2.append(float(row['y2']))
        ax_top.plot(x1, y1)
        # Second plot is X2 vs Y2 (the original plotted x1 against x2).
        ax_bottom.plot(x2, y2)

    plt.show()
    plt.close('all')
What we want is a new, empty set of lists for each iteration (i.e. each file). That way each file's data is plotted, and the next file starts from fresh empty lists instead of re-plotting the previous data. Once the data from every file has been plotted, call plt.show() once at the end to show all the plots together. Here is a link to a similar problem I was having: Traceback lines on plot of multiple files. Good luck!
import csv
import matplotlib.pyplot as plt
def graphWriter():
    """Plot every CSV file in the current directory, then show them all together.

    Each file adds a y1-vs-x1 line to the top subplot and a y2-vs-x2 line to
    the bottom subplot. The lists are rebuilt for every file, so one file's
    data is not re-plotted with the next.
    """
    import os  # used below but not imported at the top of this snippet

    for file_name in sorted(os.listdir(os.getcwd())):
        if not file_name.lower().endswith('.csv'):
            continue  # skip non-CSV entries
        x1, y1, x2, y2 = [], [], [], []
        # The original opened the undefined name `filename` (NameError).
        with open(file_name, 'r', newline='') as csv_file:
            for row in csv.DictReader(csv_file):
                # Convert from str so the axes are numeric, not categorical.
                x1.append(float(row['x1']))
                y1.append(float(row['y1']))
                x2.append(float(row['x2']))
                y2.append(float(row['y2']))
        # Colours come from matplotlib's cycle; np.random.rand(2) was not a
        # valid colour (and np was never imported).
        plt.subplot(2, 1, 1)
        plt.grid(True)
        plt.plot(x1, y1)
        plt.tick_params(axis='both', which='major', labelsize=8)
        plt.subplot(2, 1, 2)
        plt.grid(True)
        plt.plot(x2, y2)  # X2 vs Y2 (the original plotted x1 against x2)
        plt.tick_params(axis='both', which='major', labelsize=8)
    plt.show()
    plt.close('all')
If for some reason @Neill Herbst's answer (which I consider the easiest way) doesn't work as expected — I ran into a problem reading the files — here is the rearranged code that worked for me:
import glob
import pandas as pd
import matplotlib.pyplot as plt
import os  # needed for os.chdir but missing from the snippet's imports

os.chdir(r'path')  # folder that contains the CSV files
for file in glob.glob("*.csv"):  # single ':' (the original '::' is a syntax error)
    df1 = pd.read_csv(file, header=1, sep=',')
    fig = plt.figure()
    plt.subplot(2, 1, 1)
    plt.plot(df1.iloc[:, [1]], df1.iloc[:, [2]])
    plt.subplot(2, 1, 2)
    plt.plot(df1.iloc[:, [3]], df1.iloc[:, [4]])
    plt.show()  # plot one csv when you close it, plots next one
# Alternatively, remove the plt.show() inside the loop and call plt.show()
# once here, after the loop, to see all the plots in different windows.
Using p = Path(...): p → WindowsPath('so_data/files')
files = p.rglob(...) yields all files matching the pattern
the first file yielded by files → WindowsPath('so_data/files/data_1.csv')
p.parent / 'plots' / f'{file.stem}.png' → WindowsPath('so_data/plots/data_1.png')
p.parent → WindowsPath('so_data')
file.stem → data_1
This assumes all directories exist. Directory creation / checking is not included.
This example uses pandas, as does the OP.
Plotted with pandas.DataFrame.plot, which uses matplotlib as the default backend.
Use .iloc to specify the columns, and then x=0 will always be the x-axis data, based on the given example data.
Tested in python 3.8.11, pandas 1.3.2, matplotlib 3.4.3
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
p = Path('so_data/files')  # folder that holds the CSV files
files = p.rglob('data_*.csv')  # lazy generator over every matching file, searched recursively
for file in files:
    # Adjust the header row and separator to match your files.
    df = pd.read_csv(file, header=0, sep=',')
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7, 5))
    # Each x/y pair gets its own subplot; x=0 makes the first selected
    # column the x-axis data.
    for ax, cols in ((ax1, [0, 1]), (ax2, [2, 3])):
        df.iloc[:, cols].plot(x=0, ax=ax)
    fig.savefig(p.parent / 'plots' / f'{file.stem}.png')
    plt.close(fig)  # release the figure; otherwise every figure stays in memory
Sample Data
This is for testing the plotting code
Create a so_data/files directory manually.
df = pd.DataFrame({
    'x1': [5.0, 6.0, 7.0, 8.0, 9.0],
    'y1': [60, 70, 80, 90, 100],
    'x2': [5.5, 6.5, 7.5, 8.5, 9.5],
    'y2': [500, 600, 700, 800, 900],
})
# Write the same frame 1000 times: data_1.csv ... data_1000.csv
for x in range(1, 1001):
    df.to_csv(f'so_data/files/data_{x}.csv', index=False)
Alternate Answer
This answer addresses cases where there are many consecutive pairs of x/y columns
df.columns holds the column labels, which can be chunked into pairs
For consecutive column pairs, this answer works
list(zip(*[iter(df.columns)]*2)) → [('x1', 'y1'), ('x2', 'y2')]
If necessary, use some other pattern to create pairs of columns
Use .loc, since there will be column names, instead of .iloc for column indices.
p = Path('so_data/files')
files = p.rglob('data_*.csv')
for file in files:
    df = pd.read_csv(file, header=0, sep=',')
    col_pair = list(zip(*[iter(df.columns)]*2))  # extract consecutive column pairs
    if not col_pair:
        continue  # fewer than two columns: nothing to plot
    # squeeze=False always returns a 2-D array of Axes, so .ravel() also works
    # when there is a single pair (plt.subplots(1, 1) returns a bare Axes).
    fig, axes = plt.subplots(len(col_pair), 1, squeeze=False)
    axes = axes.ravel()
    for cols, ax in zip(col_pair, axes):
        df.loc[:, list(cols)].plot(x=0, ax=ax)  # assumes x data is at position 0
    fig.savefig(p.parent / 'plots' / f'{file.stem}.png')
    plt.close(fig)

No exponential form of the z-axis in matplotlib-3D-plots

I have a similar problem as described in How to prevent numbers being changed to exponential form in Python matplotlib figure:
I don't want that (in my special case) weird scientific formatting of the axis. My problem is different as I have this problem at my z-Axis. For 2-D plots I can use ax.get_yaxis().get_major_formatter().set_useOffset(False). And there is no function ax.get_zaxis()
What do I use to format my z-Axis the same way?
EDIT: Example:
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import sys
import matplotlib
import matplotlib.pyplot as pyplot
def func(xi, ti):
    """Return 10e3 + cos(ti) * sin(xi), evaluated element-wise by NumPy."""
    return 10e3 + np.cos(ti) * np.sin(xi)
if __name__ == '__main__':
    timeSpacing = 20
    timeStart = 0
    timeEnd = 1
    time = np.linspace(timeStart, timeEnd, timeSpacing)
    widthSpacing = 50
    widthStart = 0
    widthEnd = 3
    width = np.linspace(widthStart, widthEnd, widthSpacing)
    matplotlib.rcParams['legend.fontsize'] = 10
    fig = pyplot.figure()
    # Figure.gca() no longer accepts projection= (removed in Matplotlib 3.6);
    # add the 3-D axes explicitly, which also works on older versions.
    ax = fig.add_subplot(111, projection='3d')
    # One curve f(x, t) per time sample, drawn at constant t.
    for item in time:
        ti = np.full_like(width, item)  # the same t for every x sample
        res = func(width, ti)
        ax.plot(width, ti, res, 'b')
    ax.set_xlabel('x')
    ax.set_ylabel('t')
    ax.set_zlabel('f(x,t)')
    pyplot.show()
As you say, there is no get_zaxis() method. But, fortunately, there is zaxis field (so don't add ()). There are also xaxis and yaxis fields, so you can use all of those uniformly instead of get_...axis() if you like.
For example:
if __name__ == '__main__':
    ...
    # Figure.gca() no longer accepts projection= (Matplotlib 3.6+); add the
    # 3-D axes explicitly instead.
    ax = fig.add_subplot(111, projection='3d')
    ax.zaxis.get_major_formatter().set_useOffset(False)  # here: disable the z-axis offset
    for i, item in enumerate(time):
        ...
and the end result should look something like this:
As you can see, for large numbers it might not look so good...

Categories