I have this dataframe of Euclidean distances:
import pandas as pd
df = pd.DataFrame({
'O1': [0.0, 1.7, 1.4, 0.4, 2.2, 3.7, 5.2, 0.2, 4.3, 6.8, 6.0],
'O2': [1.7, 0.0, 1.0, 2.0, 1.3, 2.6, 4.5, 1.8, 3.2, 5.9, 5.2],
'O3': [1.4, 1.0, 0.0, 1.7, 0.9, 2.4, 4.1, 1.5, 3.0, 5.5, 4.8],
'O4': [0.4, 2.0, 1.7, 0.0, 2.6, 4.0, 5.5, 0.3, 4.6, 7.1, 6.3],
'O5': [2.2, 1.3, 0.9, 2.6, 0.0, 1.7, 3.4, 2.4, 2.1, 4.8, 4.1],
'O6': [3.7, 2.6, 2.4, 4.0, 1.7, 0.0, 2.0, 3.8, 1.6, 3.3, 2.7],
'O7': [5.2, 4.5, 4.1, 5.5, 3.4, 2.0, 0.0, 5.4, 2.5, 1.6, 0.9],
'O8': [0.2, 1.8, 1.5, 0.3, 2.4, 3.8, 5.4, 0.0, 4.4, 6.9, 6.1],
'O9': [4.3, 3.2, 3.0, 4.6, 2.1, 1.6, 2.5, 4.4, 0.0, 3.4, 2.9],
'O10':[6.8, 5.9, 5.5, 7.1, 4.8, 3.3, 1.6, 6.9, 3.4, 0.0, 1.0],
'O11': [6.0, 5.2, 4.8, 6.3, 4.1, 2.7, 0.9, 6.1, 2.9, 1.0, 0.0]
})
Here O1 through O8 are class 0, and O9, O10, and O11 are class 1.
I want to reshape the dataframe above into a dataframe with columns x, y, and class, so that I can split it into train and test sets and then fit a simple classifier.
I am confused about how to achieve the dataframe described above. How is this done in Python? Is it possible?
Steps afterwards when dataframe is achieved:
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
import seaborn as sns
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25)
model = GaussianNB()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
y_proba = model.predict_proba(X_test)
sns.scatterplot(x = X_test['x'], y = X_test['y'], hue = y_pred)
You mainly want to include the point name as an additional column in the dataframe. Here I am using point indices as x and y:
import pandas as pd
df = pd.DataFrame({
'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
1: [0.0, 1.7, 1.4, 0.4, 2.2, 3.7, 5.2, 0.2, 4.3, 6.8, 6.0],
2: [1.7, 0.0, 1.0, 2.0, 1.3, 2.6, 4.5, 1.8, 3.2, 5.9, 5.2],
3: [1.4, 1.0, 0.0, 1.7, 0.9, 2.4, 4.1, 1.5, 3.0, 5.5, 4.8],
4: [0.4, 2.0, 1.7, 0.0, 2.6, 4.0, 5.5, 0.3, 4.6, 7.1, 6.3],
5: [2.2, 1.3, 0.9, 2.6, 0.0, 1.7, 3.4, 2.4, 2.1, 4.8, 4.1],
6: [3.7, 2.6, 2.4, 4.0, 1.7, 0.0, 2.0, 3.8, 1.6, 3.3, 2.7],
7: [5.2, 4.5, 4.1, 5.5, 3.4, 2.0, 0.0, 5.4, 2.5, 1.6, 0.9],
8: [0.2, 1.8, 1.5, 0.3, 2.4, 3.8, 5.4, 0.0, 4.4, 6.9, 6.1],
9: [4.3, 3.2, 3.0, 4.6, 2.1, 1.6, 2.5, 4.4, 0.0, 3.4, 2.9],
10: [6.8, 5.9, 5.5, 7.1, 4.8, 3.3, 1.6, 6.9, 3.4, 0.0, 1.0],
11: [6.0, 5.2, 4.8, 6.3, 4.1, 2.7, 0.9, 6.1, 2.9, 1.0, 0.0]
})
That allows you to reshape the dataframe to your desired form:
model_df = df.melt(id_vars='x', var_name='y', value_name='distance')
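For reference, on the 11 × 11 input above this produces 121 rows, one per (x, y) pair; the first few look like this:
print(model_df.head(3))
#    x  y  distance
# 0  1  1       0.0
# 1  2  1       1.7
# 2  3  1       1.4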
Finally, define a class e.g. using:
def assign_class(x):
    return 0 if x <= 8 else 1

model_df["class_x"] = model_df["x"].apply(assign_class)
model_df["class_y"] = model_df["y"].apply(assign_class)
This will give you a dataframe that you can pass to the model. Note that the input matrix is symmetric, so you may want to only keep unique records (drop [y, x] if you already have [x, y]).
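A minimal way to do that, assuming the integer x and y columns built above (this also drops the zero-distance self-pairs):
model_df = model_df[model_df["x"] < model_df["y"]]  # keep each unordered pair once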
from pylab import *

def x(t):
    if 0 <= t < 8:
        return 2 * t
    elif 8 <= t < 20:
        return t ** 3

t = arange(5.0, 20, 0.3)
print([i for i in t])
Output is
[5.0, 5.3, 5.6, 5.8999999999999995, 6.199999999999999, 6.499999999999999, 6.799999999999999, 7.099999999999999, 7.399999999999999, 7.699999999999998, 7.999999999999998, 8.299999999999997, 8.599999999999998, 8.899999999999999, 9.199999999999998, 9.499999999999996, 9.799999999999997, 10.099999999999998, 10.399999999999997, 10.699999999999996, 10.999999999999996, 11.299999999999997, 11.599999999999996, 11.899999999999995, 12.199999999999996, 12.499999999999996, 12.799999999999995, 13.099999999999994, 13.399999999999995, 13.699999999999996, 13.999999999999995, 14.299999999999994, 14.599999999999994, 14.899999999999995, 15.199999999999994, 15.499999999999993, 15.799999999999994, 16.099999999999994, 16.39999999999999, 16.699999999999992, 16.999999999999993, 17.299999999999994, 17.599999999999994, 17.89999999999999, 18.199999999999992, 18.499999999999993, 18.79999999999999, 19.09999999999999, 19.39999999999999, 19.699999999999992, 19.999999999999993]
What I want is
[5.0, 5.3, 5.6, 5.9, 6.2, 6.5, 6.8, 7.1, 7.4, 7.7, 8.0, and so on]
When it reaches 8.0, my output is 7.999999999999998 < 8, so x(t) takes the wrong branch and gives the wrong answer. I want 8.0 so that I can plot the function:
plot(t, array([x(i) for i in t]))
I guess a simple rounding off is all you need.
Change the last line to this:
print([round(i,1) for i in t])
Output:
[5.0, 5.3, 5.6, 5.9, 6.2, 6.5, 6.8, 7.1, 7.4, 7.7, 8.0, 8.3, 8.6, 8.9, 9.2, 9.5, 9.8, 10.1, 10.4, 10.7, 11.0, 11.3, 11.6, 11.9, 12.2, 12.5, 12.8, 13.1, 13.4, 13.7, 14.0, 14.3, 14.6, 14.9, 15.2, 15.5, 15.8, 16.1, 16.4, 16.7, 17.0, 17.3, 17.6, 17.9, 18.2, 18.5, 18.8, 19.1, 19.4, 19.7]
So in your case the code becomes something like:
from pylab import *

def x(t):
    if 0 <= t < 8:
        return 2 * t
    elif 8 <= t < 20:
        return t ** 3

t = arange(5.0, 20, 0.3)
t = [round(i, 1) for i in t]
print(t)
Now you can use this t to plot the function.
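As a side note, here is a sketch of an alternative (my suggestion, not part of the original answer) that avoids the drift at the branch point altogether: build t from exact integers and divide once, since 0.3 itself has no exact binary representation:
t = arange(0, 50) * 3.0 / 10.0 + 5.0  # 10 * 3.0 / 10.0 is exactly 3.0, so t[10] == 8.0 exactly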
I'm trying to take advantage of NumPy broadcasting and backend array computations to significantly speed up this function. Unfortunately, it doesn't scale well, so I'm hoping to greatly improve its performance. Right now the code isn't properly utilizing broadcasting for the computations.
I'm using WGCNA's bicor function as a gold standard as this is the fastest implementation I know of at the moment. The Python version outputs the same results as the R function.
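For reference, the quantity being computed (my transcription from the code below, so treat it as a sketch rather than the authoritative WGCNA definition) is

$$\operatorname{bicor}(a, b) = \frac{\sum_i \tilde a_i \tilde b_i}{\sqrt{\sum_i \tilde a_i^2}\,\sqrt{\sum_i \tilde b_i^2}}, \qquad \tilde a_i = (a_i - \operatorname{med}(a))\,(1 - u_i^2)^2\,\mathbb{1}[|u_i| < 1], \qquad u_i = \frac{a_i - \operatorname{med}(a)}{9\,\operatorname{mad}(a)},$$

with mad the (unscaled) median absolute deviation and b̃ defined analogously.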
# ==============================================================================
# Imports
# ==============================================================================
# Built-ins
import os, sys, time, multiprocessing
# 3rd party
import numpy as np
import pandas as pd
# ==============================================================================
# R Imports
# ==============================================================================
from rpy2 import robjects, rinterface
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri
pandas2ri.activate()
R = robjects.r
NULL = robjects.rinterface.NULL
rinterface.set_writeconsole_regular(None)
WGCNA = importr("WGCNA")
# Python
def _biweight_midcorrelation(a, b):
    a_median = np.median(a)
    b_median = np.median(b)
    # Median absolute deviation
    a_mad = np.median(np.abs(a - a_median))
    b_mad = np.median(np.abs(b - b_median))
    u = (a - a_median) / (9 * a_mad)
    v = (b - b_median) / (9 * b_mad)
    w_a = np.square(1 - np.square(u)) * ((1 - np.abs(u)) > 0)
    w_b = np.square(1 - np.square(v)) * ((1 - np.abs(v)) > 0)
    a_item = (a - a_median) * w_a
    b_item = (b - b_median) * w_b
    return (a_item * b_item).sum() / (
        np.sqrt(np.square(a_item).sum()) *
        np.sqrt(np.square(b_item).sum()))

def biweight_midcorrelation(X):
    return X.corr(method=_biweight_midcorrelation)
# # OLD IMPLEMENTATION
# def biweight_midcorrelation(X):
#     median = X.median()
#     mad = (X - median).abs().median()
#     U = (X - median) / (9 * mad)
#     adjacency = np.square(1 - np.square(U)) * ((1 - U.abs()) > 0)
#     estimator = (X - median) * adjacency
#     bicor_matrix = np.empty((X.shape[1], X.shape[1]), dtype=float)
#     for i, ac in enumerate(estimator):
#         for j, bc in enumerate(estimator):
#             a = estimator[ac]
#             b = estimator[bc]
#             c = (a * b).sum() / (
#                 np.sqrt(np.square(a).sum()) * np.sqrt(np.square(b).sum()))
#             bicor_matrix[i, j] = c
#             bicor_matrix[j, i] = c
#     return pd.DataFrame(bicor_matrix, index=X.columns, columns=X.columns)
# R
def biweight_midcorrelation_r_wrapper(X, n_jobs=-1, r_package=None):
    """
    WGCNA: bicor
    function (x, y = NULL, robustX = TRUE, robustY = TRUE, use = "all.obs",
        maxPOutliers = 1, qu <...> dian absolute deviation, or zero variance."))
    """
    if r_package is None:
        r_package = importr("WGCNA")
    if n_jobs == -1:
        n_jobs = multiprocessing.cpu_count()
    labels = X.columns
    r_df_sim = r_package.bicor(pandas2ri.py2ri(X), nThreads=n_jobs)
    df_bicor = pd.DataFrame(pandas2ri.ri2py(r_df_sim), index=labels, columns=labels)
    return df_bicor
# X.shape = (150,4)
X = pd.DataFrame({'sepal_length': {'iris_0': 5.1, 'iris_1': 4.9, 'iris_2': 4.7, 'iris_3': 4.6, 'iris_4': 5.0, 'iris_5': 5.4, 'iris_6': 4.6, 'iris_7': 5.0, 'iris_8': 4.4, 'iris_9': 4.9, 'iris_10': 5.4, 'iris_11': 4.8, 'iris_12': 4.8, 'iris_13': 4.3, 'iris_14': 5.8, 'iris_15': 5.7, 'iris_16': 5.4, 'iris_17': 5.1, 'iris_18': 5.7, 'iris_19': 5.1, 'iris_20': 5.4, 'iris_21': 5.1, 'iris_22': 4.6, 'iris_23': 5.1, 'iris_24': 4.8, 'iris_25': 5.0, 'iris_26': 5.0, 'iris_27': 5.2, 'iris_28': 5.2, 'iris_29': 4.7, 'iris_30': 4.8, 'iris_31': 5.4, 'iris_32': 5.2, 'iris_33': 5.5, 'iris_34': 4.9, 'iris_35': 5.0, 'iris_36': 5.5, 'iris_37': 4.9, 'iris_38': 4.4, 'iris_39': 5.1, 'iris_40': 5.0, 'iris_41': 4.5, 'iris_42': 4.4, 'iris_43': 5.0, 'iris_44': 5.1, 'iris_45': 4.8, 'iris_46': 5.1, 'iris_47': 4.6, 'iris_48': 5.3, 'iris_49': 5.0, 'iris_50': 7.0, 'iris_51': 6.4, 'iris_52': 6.9, 'iris_53': 5.5, 'iris_54': 6.5, 'iris_55': 5.7, 'iris_56': 6.3, 'iris_57': 4.9, 'iris_58': 6.6, 'iris_59': 5.2, 'iris_60': 5.0, 'iris_61': 5.9, 'iris_62': 6.0, 'iris_63': 6.1, 'iris_64': 5.6, 'iris_65': 6.7, 'iris_66': 5.6, 'iris_67': 5.8, 'iris_68': 6.2, 'iris_69': 5.6, 'iris_70': 5.9, 'iris_71': 6.1, 'iris_72': 6.3, 'iris_73': 6.1, 'iris_74': 6.4, 'iris_75': 6.6, 'iris_76': 6.8, 'iris_77': 6.7, 'iris_78': 6.0, 'iris_79': 5.7, 'iris_80': 5.5, 'iris_81': 5.5, 'iris_82': 5.8, 'iris_83': 6.0, 'iris_84': 5.4, 'iris_85': 6.0, 'iris_86': 6.7, 'iris_87': 6.3, 'iris_88': 5.6, 'iris_89': 5.5, 'iris_90': 5.5, 'iris_91': 6.1, 'iris_92': 5.8, 'iris_93': 5.0, 'iris_94': 5.6, 'iris_95': 5.7, 'iris_96': 5.7, 'iris_97': 6.2, 'iris_98': 5.1, 'iris_99': 5.7, 'iris_100': 6.3, 'iris_101': 5.8, 'iris_102': 7.1, 'iris_103': 6.3, 'iris_104': 6.5, 'iris_105': 7.6, 'iris_106': 4.9, 'iris_107': 7.3, 'iris_108': 6.7, 'iris_109': 7.2, 'iris_110': 6.5, 'iris_111': 6.4, 'iris_112': 6.8, 'iris_113': 5.7, 'iris_114': 5.8, 'iris_115': 6.4, 'iris_116': 6.5, 'iris_117': 7.7, 'iris_118': 7.7, 'iris_119': 6.0, 'iris_120': 6.9, 'iris_121': 5.6, 'iris_122': 7.7, 'iris_123': 6.3, 'iris_124': 6.7, 'iris_125': 7.2, 'iris_126': 6.2, 'iris_127': 6.1, 'iris_128': 6.4, 'iris_129': 7.2, 'iris_130': 7.4, 'iris_131': 7.9, 'iris_132': 6.4, 'iris_133': 6.3, 'iris_134': 6.1, 'iris_135': 7.7, 'iris_136': 6.3, 'iris_137': 6.4, 'iris_138': 6.0, 'iris_139': 6.9, 'iris_140': 6.7, 'iris_141': 6.9, 'iris_142': 5.8, 'iris_143': 6.8, 'iris_144': 6.7, 'iris_145': 6.7, 'iris_146': 6.3, 'iris_147': 6.5, 'iris_148': 6.2, 'iris_149': 5.9}, 'sepal_width': {'iris_0': 3.5, 'iris_1': 3.0, 'iris_2': 3.2, 'iris_3': 3.1, 'iris_4': 3.6, 'iris_5': 3.9, 'iris_6': 3.4, 'iris_7': 3.4, 'iris_8': 2.9, 'iris_9': 3.1, 'iris_10': 3.7, 'iris_11': 3.4, 'iris_12': 3.0, 'iris_13': 3.0, 'iris_14': 4.0, 'iris_15': 4.4, 'iris_16': 3.9, 'iris_17': 3.5, 'iris_18': 3.8, 'iris_19': 3.8, 'iris_20': 3.4, 'iris_21': 3.7, 'iris_22': 3.6, 'iris_23': 3.3, 'iris_24': 3.4, 'iris_25': 3.0, 'iris_26': 3.4, 'iris_27': 3.5, 'iris_28': 3.4, 'iris_29': 3.2, 'iris_30': 3.1, 'iris_31': 3.4, 'iris_32': 4.1, 'iris_33': 4.2, 'iris_34': 3.1, 'iris_35': 3.2, 'iris_36': 3.5, 'iris_37': 3.6, 'iris_38': 3.0, 'iris_39': 3.4, 'iris_40': 3.5, 'iris_41': 2.3, 'iris_42': 3.2, 'iris_43': 3.5, 'iris_44': 3.8, 'iris_45': 3.0, 'iris_46': 3.8, 'iris_47': 3.2, 'iris_48': 3.7, 'iris_49': 3.3, 'iris_50': 3.2, 'iris_51': 3.2, 'iris_52': 3.1, 'iris_53': 2.3, 'iris_54': 2.8, 'iris_55': 2.8, 'iris_56': 3.3, 'iris_57': 2.4, 'iris_58': 2.9, 'iris_59': 2.7, 'iris_60': 2.0, 'iris_61': 3.0, 'iris_62': 2.2, 'iris_63': 2.9, 'iris_64': 2.9, 'iris_65': 3.1, 'iris_66': 3.0, 
'iris_67': 2.7, 'iris_68': 2.2, 'iris_69': 2.5, 'iris_70': 3.2, 'iris_71': 2.8, 'iris_72': 2.5, 'iris_73': 2.8, 'iris_74': 2.9, 'iris_75': 3.0, 'iris_76': 2.8, 'iris_77': 3.0, 'iris_78': 2.9, 'iris_79': 2.6, 'iris_80': 2.4, 'iris_81': 2.4, 'iris_82': 2.7, 'iris_83': 2.7, 'iris_84': 3.0, 'iris_85': 3.4, 'iris_86': 3.1, 'iris_87': 2.3, 'iris_88': 3.0, 'iris_89': 2.5, 'iris_90': 2.6, 'iris_91': 3.0, 'iris_92': 2.6, 'iris_93': 2.3, 'iris_94': 2.7, 'iris_95': 3.0, 'iris_96': 2.9, 'iris_97': 2.9, 'iris_98': 2.5, 'iris_99': 2.8, 'iris_100': 3.3, 'iris_101': 2.7, 'iris_102': 3.0, 'iris_103': 2.9, 'iris_104': 3.0, 'iris_105': 3.0, 'iris_106': 2.5, 'iris_107': 2.9, 'iris_108': 2.5, 'iris_109': 3.6, 'iris_110': 3.2, 'iris_111': 2.7, 'iris_112': 3.0, 'iris_113': 2.5, 'iris_114': 2.8, 'iris_115': 3.2, 'iris_116': 3.0, 'iris_117': 3.8, 'iris_118': 2.6, 'iris_119': 2.2, 'iris_120': 3.2, 'iris_121': 2.8, 'iris_122': 2.8, 'iris_123': 2.7, 'iris_124': 3.3, 'iris_125': 3.2, 'iris_126': 2.8, 'iris_127': 3.0, 'iris_128': 2.8, 'iris_129': 3.0, 'iris_130': 2.8, 'iris_131': 3.8, 'iris_132': 2.8, 'iris_133': 2.8, 'iris_134': 2.6, 'iris_135': 3.0, 'iris_136': 3.4, 'iris_137': 3.1, 'iris_138': 3.0, 'iris_139': 3.1, 'iris_140': 3.1, 'iris_141': 3.1, 'iris_142': 2.7, 'iris_143': 3.2, 'iris_144': 3.3, 'iris_145': 3.0, 'iris_146': 2.5, 'iris_147': 3.0, 'iris_148': 3.4, 'iris_149': 3.0}, 'petal_length': {'iris_0': 1.4, 'iris_1': 1.4, 'iris_2': 1.3, 'iris_3': 1.5, 'iris_4': 1.4, 'iris_5': 1.7, 'iris_6': 1.4, 'iris_7': 1.5, 'iris_8': 1.4, 'iris_9': 1.5, 'iris_10': 1.5, 'iris_11': 1.6, 'iris_12': 1.4, 'iris_13': 1.1, 'iris_14': 1.2, 'iris_15': 1.5, 'iris_16': 1.3, 'iris_17': 1.4, 'iris_18': 1.7, 'iris_19': 1.5, 'iris_20': 1.7, 'iris_21': 1.5, 'iris_22': 1.0, 'iris_23': 1.7, 'iris_24': 1.9, 'iris_25': 1.6, 'iris_26': 1.6, 'iris_27': 1.5, 'iris_28': 1.4, 'iris_29': 1.6, 'iris_30': 1.6, 'iris_31': 1.5, 'iris_32': 1.5, 'iris_33': 1.4, 'iris_34': 1.5, 'iris_35': 1.2, 'iris_36': 1.3, 'iris_37': 1.4, 'iris_38': 1.3, 'iris_39': 1.5, 'iris_40': 1.3, 'iris_41': 1.3, 'iris_42': 1.3, 'iris_43': 1.6, 'iris_44': 1.9, 'iris_45': 1.4, 'iris_46': 1.6, 'iris_47': 1.4, 'iris_48': 1.5, 'iris_49': 1.4, 'iris_50': 4.7, 'iris_51': 4.5, 'iris_52': 4.9, 'iris_53': 4.0, 'iris_54': 4.6, 'iris_55': 4.5, 'iris_56': 4.7, 'iris_57': 3.3, 'iris_58': 4.6, 'iris_59': 3.9, 'iris_60': 3.5, 'iris_61': 4.2, 'iris_62': 4.0, 'iris_63': 4.7, 'iris_64': 3.6, 'iris_65': 4.4, 'iris_66': 4.5, 'iris_67': 4.1, 'iris_68': 4.5, 'iris_69': 3.9, 'iris_70': 4.8, 'iris_71': 4.0, 'iris_72': 4.9, 'iris_73': 4.7, 'iris_74': 4.3, 'iris_75': 4.4, 'iris_76': 4.8, 'iris_77': 5.0, 'iris_78': 4.5, 'iris_79': 3.5, 'iris_80': 3.8, 'iris_81': 3.7, 'iris_82': 3.9, 'iris_83': 5.1, 'iris_84': 4.5, 'iris_85': 4.5, 'iris_86': 4.7, 'iris_87': 4.4, 'iris_88': 4.1, 'iris_89': 4.0, 'iris_90': 4.4, 'iris_91': 4.6, 'iris_92': 4.0, 'iris_93': 3.3, 'iris_94': 4.2, 'iris_95': 4.2, 'iris_96': 4.2, 'iris_97': 4.3, 'iris_98': 3.0, 'iris_99': 4.1, 'iris_100': 6.0, 'iris_101': 5.1, 'iris_102': 5.9, 'iris_103': 5.6, 'iris_104': 5.8, 'iris_105': 6.6, 'iris_106': 4.5, 'iris_107': 6.3, 'iris_108': 5.8, 'iris_109': 6.1, 'iris_110': 5.1, 'iris_111': 5.3, 'iris_112': 5.5, 'iris_113': 5.0, 'iris_114': 5.1, 'iris_115': 5.3, 'iris_116': 5.5, 'iris_117': 6.7, 'iris_118': 6.9, 'iris_119': 5.0, 'iris_120': 5.7, 'iris_121': 4.9, 'iris_122': 6.7, 'iris_123': 4.9, 'iris_124': 5.7, 'iris_125': 6.0, 'iris_126': 4.8, 'iris_127': 4.9, 'iris_128': 5.6, 'iris_129': 5.8, 'iris_130': 6.1, 'iris_131': 6.4, 'iris_132': 5.6, 
'iris_133': 5.1, 'iris_134': 5.6, 'iris_135': 6.1, 'iris_136': 5.6, 'iris_137': 5.5, 'iris_138': 4.8, 'iris_139': 5.4, 'iris_140': 5.6, 'iris_141': 5.1, 'iris_142': 5.1, 'iris_143': 5.9, 'iris_144': 5.7, 'iris_145': 5.2, 'iris_146': 5.0, 'iris_147': 5.2, 'iris_148': 5.4, 'iris_149': 5.1}, 'petal_width': {'iris_0': 0.2, 'iris_1': 0.2, 'iris_2': 0.2, 'iris_3': 0.2, 'iris_4': 0.2, 'iris_5': 0.4, 'iris_6': 0.3, 'iris_7': 0.2, 'iris_8': 0.2, 'iris_9': 0.1, 'iris_10': 0.2, 'iris_11': 0.2, 'iris_12': 0.1, 'iris_13': 0.1, 'iris_14': 0.2, 'iris_15': 0.4, 'iris_16': 0.4, 'iris_17': 0.3, 'iris_18': 0.3, 'iris_19': 0.3, 'iris_20': 0.2, 'iris_21': 0.4, 'iris_22': 0.2, 'iris_23': 0.5, 'iris_24': 0.2, 'iris_25': 0.2, 'iris_26': 0.4, 'iris_27': 0.2, 'iris_28': 0.2, 'iris_29': 0.2, 'iris_30': 0.2, 'iris_31': 0.4, 'iris_32': 0.1, 'iris_33': 0.2, 'iris_34': 0.2, 'iris_35': 0.2, 'iris_36': 0.2, 'iris_37': 0.1, 'iris_38': 0.2, 'iris_39': 0.2, 'iris_40': 0.3, 'iris_41': 0.3, 'iris_42': 0.2, 'iris_43': 0.6, 'iris_44': 0.4, 'iris_45': 0.3, 'iris_46': 0.2, 'iris_47': 0.2, 'iris_48': 0.2, 'iris_49': 0.2, 'iris_50': 1.4, 'iris_51': 1.5, 'iris_52': 1.5, 'iris_53': 1.3, 'iris_54': 1.5, 'iris_55': 1.3, 'iris_56': 1.6, 'iris_57': 1.0, 'iris_58': 1.3, 'iris_59': 1.4, 'iris_60': 1.0, 'iris_61': 1.5, 'iris_62': 1.0, 'iris_63': 1.4, 'iris_64': 1.3, 'iris_65': 1.4, 'iris_66': 1.5, 'iris_67': 1.0, 'iris_68': 1.5, 'iris_69': 1.1, 'iris_70': 1.8, 'iris_71': 1.3, 'iris_72': 1.5, 'iris_73': 1.2, 'iris_74': 1.3, 'iris_75': 1.4, 'iris_76': 1.4, 'iris_77': 1.7, 'iris_78': 1.5, 'iris_79': 1.0, 'iris_80': 1.1, 'iris_81': 1.0, 'iris_82': 1.2, 'iris_83': 1.6, 'iris_84': 1.5, 'iris_85': 1.6, 'iris_86': 1.5, 'iris_87': 1.3, 'iris_88': 1.3, 'iris_89': 1.3, 'iris_90': 1.2, 'iris_91': 1.4, 'iris_92': 1.2, 'iris_93': 1.0, 'iris_94': 1.3, 'iris_95': 1.2, 'iris_96': 1.3, 'iris_97': 1.3, 'iris_98': 1.1, 'iris_99': 1.3, 'iris_100': 2.5, 'iris_101': 1.9, 'iris_102': 2.1, 'iris_103': 1.8, 'iris_104': 2.2, 'iris_105': 2.1, 'iris_106': 1.7, 'iris_107': 1.8, 'iris_108': 1.8, 'iris_109': 2.5, 'iris_110': 2.0, 'iris_111': 1.9, 'iris_112': 2.1, 'iris_113': 2.0, 'iris_114': 2.4, 'iris_115': 2.3, 'iris_116': 1.8, 'iris_117': 2.2, 'iris_118': 2.3, 'iris_119': 1.5, 'iris_120': 2.3, 'iris_121': 2.0, 'iris_122': 2.0, 'iris_123': 1.8, 'iris_124': 2.1, 'iris_125': 1.8, 'iris_126': 1.8, 'iris_127': 1.8, 'iris_128': 2.1, 'iris_129': 1.6, 'iris_130': 1.9, 'iris_131': 2.0, 'iris_132': 2.2, 'iris_133': 1.5, 'iris_134': 1.4, 'iris_135': 2.3, 'iris_136': 2.4, 'iris_137': 1.8, 'iris_138': 1.8, 'iris_139': 2.1, 'iris_140': 2.4, 'iris_141': 2.3, 'iris_142': 1.9, 'iris_143': 2.3, 'iris_144': 2.5, 'iris_145': 2.3, 'iris_146': 1.9, 'iris_147': 2.0, 'iris_148': 2.3, 'iris_149': 1.8}})
# Python computation
start_time = time.time()
df_bicor__python = biweight_midcorrelation(X)
# R computation
df_bicor__r = biweight_midcorrelation_r_wrapper(X)
np.allclose(df_bicor__python, df_bicor__r)
Summary
One could make this computation approximately an order of magnitude faster (for the input you specified) with:
import numpy as np

def biweight_midcorrelation(arr):
    n, m = arr.shape
    arr = arr - np.median(arr, axis=0, keepdims=True)
    v = 1 - (arr / (9 * np.median(np.abs(arr), axis=0, keepdims=True))) ** 2
    arr = arr * v ** 2 * (v > 0)
    norms = np.sqrt(np.sum(arr ** 2, axis=0))
    return np.einsum('mi,mj->ij', arr, arr) / norms[:, None] / norms[None, :]
to be bridged to a Pandas dataframe by:
import pandas as pd

def corr_np2pd(df, func):
    return pd.DataFrame(func(np.array(df)), index=df.columns, columns=df.columns)
whose usage is:
corr_df = corr_np2pd(df, biweight_midcorrelation)
This could be made even faster by implementing the last computation with Numba.
Introduction
I am not quite sure why you expect broadcasting to be of help in the current code.
Did you perhaps mean vectorizing?
Anyway, I believe that it is possible to write faster code, and a vectorized version of your "old" approach would outperform your current approach.
This could be made even faster using Numba.
There are two practical approaches to your problem:
to manually compute the correlation matrix
to generate a correlation function to be passed to pd.DataFrame.corr()
When doing (1), explicit looping may not be avoidable without computing unnecessary parts of the correlation matrix.
When doing (2), it will be necessary to compute the auxiliary value of the computation for each (symmetric) pair of the 1D inputs (2 * comb(n, 2) times), as opposed to computing the auxiliary values only once for each of the 1D inputs (n times). For example, for the input specified in the question, one would need to perform n == 4 pre-computations, but, if done in pairwise fashion, this number becomes 2 * comb(4, 2) == 12.
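A quick check of that count (a toy calculation; math.comb assumes Python 3.8+):
from math import comb
n = 4
print(n, 2 * comb(n, 2))  # 4 12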
Let us see how we can push the performance in both cases.
Manually Computing the Correlation Matrix
Let us first define a function to serve as a Pandas-to-NumPy bridge:
import numpy as np
import pandas as pd

def corr_np2pd(df, func):
    return pd.DataFrame(func(np.array(df)), index=df.columns, columns=df.columns)
The function with explicit looping that now sits in the comments of your code belongs to this category; it is reported below as biweight_midcorrelation_pd_OP():
def biweight_midcorrelation_pd_OP(X):
    median = X.median()
    mad = (X - median).abs().median()
    U = (X - median) / (9 * mad)
    adjacency = np.square(1 - np.square(U)) * ((1 - U.abs()) > 0)
    estimator = (X - median) * adjacency
    bicor_matrix = np.empty((X.shape[1], X.shape[1]), dtype=float)
    for i, ac in enumerate(estimator):
        for j, bc in enumerate(estimator):
            a = estimator[ac]
            b = estimator[bc]
            c = (a * b).sum() / (
                np.sqrt(np.square(a).sum()) * np.sqrt(np.square(b).sum()))
            bicor_matrix[i, j] = c
            bicor_matrix[j, i] = c
    return pd.DataFrame(bicor_matrix, index=X.columns, columns=X.columns)
A slightly modified version of that, where the computation is done entirely in NumPy and which should be used with corr_np2pd(), reads:
def biweight_midcorrelation_OP(arr):
    n, m = arr.shape
    med = np.median(arr, axis=0, keepdims=True)
    mad = np.median(np.abs(arr - med), axis=0, keepdims=True)
    u = (arr - med) / (9 * mad)
    adj = ((1 - u ** 2) ** 2) * ((1 - np.abs(u)) > 0)
    est = (arr - med) * adj
    result = np.empty((m, m))
    for i in range(m):
        for j in range(m):
            a = est[:, i]
            b = est[:, j]
            c = (a * b).sum() / (
                np.sqrt(np.sum(a ** 2)) * np.sqrt(np.sum(b ** 2)))
            result[i, j] = result[j, i] = c
    return result
Now, this has some points of improvement:
the intermediate computations can be reduced
the final nested loop could be made more efficient
This last point can be improved in two ways:
by only computing the symmetric indices once, resulting in biweight_midcorrelation_np()
by writing it in vectorized form, resulting in biweight_midcorrelation_npv()
def biweight_midcorrelation_np(arr):
    n, m = arr.shape
    arr = arr - np.median(arr, axis=0, keepdims=True)
    v = 1 - (arr / (9 * np.median(np.abs(arr), axis=0, keepdims=True))) ** 2
    arr = arr * v ** 2 * (v > 0)
    norms = np.sqrt(np.sum(arr ** 2, axis=0))
    result = np.empty((m, m))
    np.fill_diagonal(result, 1.0)
    for i, j in zip(*np.triu_indices(m, 1)):
        result[i, j] = result[j, i] = \
            np.sum(arr[:, i] * arr[:, j]) / norms[i] / norms[j]
    return result

def biweight_midcorrelation_npv(arr):
    n, m = arr.shape
    arr = arr - np.median(arr, axis=0, keepdims=True)
    v = 1 - (arr / (9 * np.median(np.abs(arr), axis=0, keepdims=True))) ** 2
    arr = arr * v ** 2 * (v > 0)
    norms = np.sqrt(np.sum(arr ** 2, axis=0))
    return np.einsum('mi,mj->ij', arr, arr) / norms[:, None] / norms[None, :]
The first one will be fast as long as m is small, because of the explicit looping.
The second one will generally be fast, but it seems inefficient to compute some of the entries of the matrix twice.
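As a side note, the einsum in biweight_midcorrelation_npv() is just a Gram matrix of the weighted, centered columns, so its last line could equivalently be written with a plain matrix product (same result, arguably easier to read):
# inside biweight_midcorrelation_npv(), equivalent to the einsum call:
return (arr.T @ arr) / norms[:, None] / norms[None, :]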
To overcome both issues, one could rewrite the final loop with Numba:
import numba as nb

@nb.jit
def _biweight_midcorrelation_triu_nb(n, m, est, norms, result):
    for i in range(m):
        for j in range(i + 1, m):
            x = 0
            for k in range(n):
                x += est[k, i] * est[k, j]
            result[i, j] = result[j, i] = x / norms[i] / norms[j]

def biweight_midcorrelation_nb(arr):
    n, m = arr.shape
    arr = arr - np.median(arr, axis=0, keepdims=True)
    v = 1 - (arr / (9 * np.median(np.abs(arr), axis=0, keepdims=True))) ** 2
    arr = arr * v ** 2 * (v > 0)
    norms = np.sqrt(np.sum(arr ** 2, axis=0))
    result = np.empty((m, m))
    np.fill_diagonal(result, 1.0)
    _biweight_midcorrelation_triu_nb(n, m, arr, norms, result)
    return result
Pairwise Correlation Function
A slightly modified version of your newly proposed approach belongs to this category:
def pairwise_biweight_midcorrelation_OP(a, b):
    a_median = np.median(a)
    b_median = np.median(b)
    a_mad = np.median(np.abs(a - a_median))
    b_mad = np.median(np.abs(b - b_median))
    u_a = (a - a_median) / (9 * a_mad)
    u_b = (b - b_median) / (9 * b_mad)
    adj_a = (1 - u_a ** 2) ** 2 * ((1 - np.abs(u_a)) > 0)
    adj_b = (1 - u_b ** 2) ** 2 * ((1 - np.abs(u_b)) > 0)
    a = (a - a_median) * adj_a
    b = (b - b_median) * adj_b
    return np.sum(a * b) / (np.sqrt(np.sum(a ** 2)) * np.sqrt(np.sum(b ** 2)))
This may be written a bit more concisely, using simplifications similar to those above, resulting in:
def pairwise_biweight_midcorrelation_opt(a, b):
    a = a - np.median(a)
    b = b - np.median(b)
    v_a = 1 - (a / (9 * np.median(np.abs(a)))) ** 2
    v_b = 1 - (b / (9 * np.median(np.abs(b)))) ** 2
    a = a * v_a ** 2 * (v_a > 0)
    b = b * v_b ** 2 * (v_b > 0)
    return np.sum(a * b) / (np.sqrt(np.sum(a ** 2)) * np.sqrt(np.sum(b ** 2)))
The last operation performs a summation over a and b three times, but it could actually be done in a single loop, which can again be made fast with Numba:
@nb.jit
def pairwise_biweight_midcorrelation_nb(a, b):
    n = a.size
    a = a - np.median(a)
    b = b - np.median(b)
    v_a = 1 - (a / (9 * np.median(np.abs(a)))) ** 2
    v_b = 1 - (b / (9 * np.median(np.abs(b)))) ** 2
    a = (v_a > 0) * a * v_a ** 2
    b = (v_b > 0) * b * v_b ** 2
    s_ab = s_aa = s_bb = 0
    for i in range(n):
        s_ab += a[i] * b[i]
        s_aa += a[i] * a[i]
        s_bb += b[i] * b[i]
    return s_ab / np.sqrt(s_aa) / np.sqrt(s_bb)
But there is no simple way of avoiding performing the pre-computations 2 * comb(n, 2) times instead of n times.
The other side of the story is that this class of approaches requires less memory, as only two 1D arrays are considered at each iteration.
Testing
For the suggested input:
df = X  # the same 150x4 iris dataframe defined in the question
we obtain (here result denotes a reference bicor matrix, e.g. the output of biweight_midcorrelation_r_wrapper() from the question):
print(np.all(np.isclose(biweight_midcorrelation_pd_OP(df), result)))
# True
print(np.all(np.isclose(corr_np2pd(df, biweight_midcorrelation_OP), result)))
# True
print(np.all(np.isclose(corr_np2pd(df, biweight_midcorrelation_np), result)))
# True
print(np.all(np.isclose(corr_np2pd(df, biweight_midcorrelation_npv), result)))
# True
print(np.all(np.isclose(corr_np2pd(df, biweight_midcorrelation_nb), result)))
# True
print(np.all(np.isclose(df.corr(method=pairwise_biweight_midcorrelation_OP), result)))
# True
print(np.all(np.isclose(df.corr(method=pairwise_biweight_midcorrelation_opt), result)))
# True
print(np.all(np.isclose(df.corr(method=pairwise_biweight_midcorrelation_nb), result)))
# True
Benchmarks
%timeit biweight_midcorrelation_pd_OP(df)
# 10 loops, best of 3: 22.1 ms per loop
%timeit corr_np2pd(df, biweight_midcorrelation_OP)
# 1000 loops, best of 3: 682 µs per loop
%timeit corr_np2pd(df, biweight_midcorrelation_np)
# 1000 loops, best of 3: 422 µs per loop
%timeit corr_np2pd(df, biweight_midcorrelation_npv)
# 1000 loops, best of 3: 341 µs per loop
%timeit corr_np2pd(df, biweight_midcorrelation_nb)
# 1000 loops, best of 3: 325 µs per loop
%timeit df.corr(method=pairwise_biweight_midcorrelation_OP)
# 100 loops, best of 3: 1.96 ms per loop
%timeit df.corr(method=pairwise_biweight_midcorrelation_opt)
# 100 loops, best of 3: 1.83 ms per loop
%timeit df.corr(method=pairwise_biweight_midcorrelation_nb)
# 1000 loops, best of 3: 506 µs per loop
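One practical caveat (an assumption about the setup above): the first call to a lazily compiled @nb.jit function includes compilation time, so the Numba timings only hold after a warm-up call, e.g.:
corr_np2pd(df, biweight_midcorrelation_nb)  # first call triggers JIT compilation
%timeit corr_np2pd(df, biweight_midcorrelation_nb)  # then measure the steady-state speed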
These results indicate that the Numba-based approach is the fastest, closely followed by the NumPy-vectorized version of your original approach.
Note that by going from a Pandas-based computation to a pure NumPy-based approach (even with explicit looping) we get an almost 30x speed-up.
And vectorizing the two for loops buys us another approx. 2x factor.
The pd.DataFrame.corr()-based approaches are, when not using Numba, approx. 4x slower than your original approach rewritten in NumPy, so be careful even if you do not see explicit looping!
The Numba accelerated pairwise_biweight_midcorrelation_nb() gives a significant boost to this family of approaches, but it cannot possibly avoid the overhead of the pre-computations.
Final warning: all these benchmarks should be taken with a grain of salt!
(EDITED to include a Numba-based approach to use with pd.DataFrame.corr()).
With a copy-n-paste of your X:
In [26]: X
Out[26]:
sepal_length sepal_width petal_length petal_width
iris_0 5.1 3.5 1.4 0.2
iris_1 4.9 3.0 1.4 0.2
iris_2 4.7 3.2 1.3 0.2
iris_3 4.6 3.1 1.5 0.2
iris_4 5.0 3.6 1.4 0.2
... ... ... ... ...
iris_145 6.7 3.0 5.2 2.3
iris_146 6.3 2.5 5.0 1.9
iris_147 6.5 3.0 5.2 2.0
iris_148 6.2 3.4 5.4 2.3
iris_149 5.9 3.0 5.1 1.8
[150 rows x 4 columns]
and using it:
In [29]: X.corr(method=_biweight_midcorrelation)
Out[29]:
sepal_length sepal_width petal_length petal_width
sepal_length 1.000000 -0.134780 0.831958 0.818575
sepal_width -0.134780 1.000000 -0.430312 -0.374034
petal_length 0.831958 -0.430312 1.000000 0.952285
petal_width 0.818575 -0.374034 0.952285 1.000000
In [30]: X.corr?
In [31]: _biweight_midcorrelation(X['sepal_length'],X['sepal_width'])
Out[31]: -0.13477989268659313
In [32]: _biweight_midcorrelation(X['sepal_length'],X['petal_length'])
Out[32]: 0.831958204443503
In _biweight_midcorrelation(a, b), a and b are Series of the same size. So all their derived arrays have the same shape, and (a_item * b_item) just works element-wise (the rules of broadcasting apply trivially to two 1D arrays of equal length). I don't see any need for 'outer products'.
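A tiny illustration of that point (toy values):
import numpy as np
a_item = np.array([1.0, 2.0, 3.0])
b_item = np.array([4.0, 5.0, 6.0])
print((a_item * b_item).sum())  # 32.0 - plain element-wise product and sum, no outer product needed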
Problem
For a computational engineering model, I want to do a grid search over all feasible parameter combinations. Each parameter has a certain possible range, e.g. (0 … 100), and the parameter combinations must fulfil the condition a + b + c = 100. An example:
ranges = {
'a': (95, 99),
'b': (1, 4),
'c': (1, 2)}
increment = 1.0
target = 100.0
So the combinations that fulfil the condition a+b+c=100 are:
[(95, 4, 1), (95, 3, 2), (96, 2, 2), (96, 3, 1), (97, 1, 2), (97, 2, 1), (98, 1, 1)]
This algorithm should run with any number of parameters, range lengths, and increments.
My solutions (so far)
The solutions I have come up with all brute-force the problem: they calculate all combinations and then discard the ones that do not fulfil the given condition:
import itertools

import numpy as np
import pandas as pd

def solution1(ranges, increment, target):
    combinations = []
    for parameter in ranges:
        combinations.append(list(np.arange(ranges[parameter][0], ranges[parameter][1], increment)))
        # np.arange() is exclusive of the upper bound, let's fix that
        if combinations[-1][-1] != ranges[parameter][1]:
            combinations[-1].append(ranges[parameter][1])
    combinations = list(itertools.product(*combinations))
    df = pd.DataFrame(combinations, columns=ranges.keys())
    # using np.isclose() so that the algorithm works for floats
    return df[np.isclose(df.sum(axis=1), target)]
Since I ran into RAM problems with solution1(), I used itertools.product as an iterator.
def solution2(ranges, increment, target):
    combinations = []
    for parameter in ranges:
        combinations.append(list(np.arange(ranges[parameter][0], ranges[parameter][1], increment)))
        # np.arange() is exclusive of the upper bound, let's fix that
        if combinations[-1][-1] != ranges[parameter][1]:
            combinations[-1].append(ranges[parameter][1])
    result = []
    for combination in itertools.product(*combinations):
        # using np.isclose() so that the algorithm works for floats
        if np.isclose(sum(combination), target):
            result.append(combination)
    df = pd.DataFrame(result, columns=ranges.keys())
    return df
However, this quickly takes a few days to compute, so both solutions are not viable for a large number of parameters and ranges. For instance, one set that I am trying to solve is (the already unpacked combinations variable):
[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0], [22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], [0.0, 1.0, 2.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0], [0.0]]
This results in memory use of >40 GB for solution1() and calculation time >400 hours for solution2().
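The size of that set explains these numbers: the number of candidate tuples itertools.product has to enumerate is the product of the list lengths (a back-of-the-envelope check; math.prod assumes Python 3.8+):
from math import prod
lengths = [24, 67, 24, 13, 10, 3, 11, 9, 1, 33, 1]
print(prod(lengths))  # 49171224960, i.e. roughly 4.9e10 combinations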
Question
Do you see a solution that is either faster or more intelligent, i.e. not trying to brute-force the problem?
P.S.: I am not 100% sure if this question would be a better fit on one of the other Stack Exchange sites. Please suggest in the comments if you think it should be moved and I will delete it here.
Here is a recursive solution:
a = [95, 100]
b = [1, 4]
c = [1, 2]
Params = (a, b, c)
def GetValidParamValues(Params, constraintSum, prevVals):
    validParamValues = []
    if len(Params) == 1:
        if Params[0][0] <= constraintSum <= Params[0][1]:
            validParamValues.append([constraintSum])
        for v in validParamValues:
            print(prevVals + v)
        return
    sumOfLowParams = sum(Params[x][0] for x in range(1, len(Params)))
    sumOfHighParams = sum(Params[x][1] for x in range(1, len(Params)))
    lowEnd = max(Params[0][0], constraintSum - sumOfHighParams)
    highEnd = min(Params[0][1], constraintSum - sumOfLowParams) + 1
    if len(Params) == 2:
        for av in range(lowEnd, highEnd):
            bv = constraintSum - av
            if bv <= Params[1][1]:
                validParamValues.append([av, bv])
        for v in validParamValues:
            print(prevVals + v)
        return
    for av in range(lowEnd, highEnd):
        nextPrevVals = prevVals + [av]
        subParams = Params[1:]
        GetValidParamValues(subParams, constraintSum - av, nextPrevVals)

GetValidParamValues(Params, 100, [])
The idea is that if there were 2 parameters, a and b, we could list all the valid pairs by stepping through the values of a and taking (ai, S - ai), just checking whether S - ai is a valid value for b.
This is improved upon by calculating ahead of time which values of ai will make S - ai a valid value for b, so we never check values that don't work.
When the number of parameters is more than 2, we can again look at every valid value of ai, and we know the sum of the other numbers must be S - ai. So the only thing we need is every possible way for the other numbers to add up to S - ai, which is the same problem with one fewer parameter. So by using recursion we can get it to go all the way down to size 2 and solve it; the generator sketch below illustrates the same idea.
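Here is the same idea written as a compact recursive generator (a sketch of mine, assuming integer ranges given as inclusive (low, high) pairs as above):

def valid_combinations(params, total):
    # yield all tuples, one value per (low, high) range, summing to total
    if len(params) == 1:
        low, high = params[0]
        if low <= total <= high:
            yield (total,)
        return
    low, high = params[0]
    rest = params[1:]
    rest_low = sum(p[0] for p in rest)
    rest_high = sum(p[1] for p in rest)
    # prune: only try values that leave a reachable remainder for the rest
    for v in range(max(low, total - rest_high), min(high, total - rest_low) + 1):
        for tail in valid_combinations(rest, total - v):
            yield (v,) + tail

print(list(valid_combinations([(95, 99), (1, 4), (1, 2)], 100)))
# [(95, 3, 2), (95, 4, 1), (96, 2, 2), (96, 3, 1), (97, 1, 2), (97, 2, 1), (98, 1, 1)]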
Edit: due to errors in my code, I updated it with my older, working code.
I get a list of speed recordings from a database, and I want to find the max speed in that list. Sounds easy enough, but I have some requirements for a max speed to count:
If the max speed is over a certain level, it has to have more than a certain number of records to be recognized as the maximum speed. The reason for this logic is that I want the max speed under normal conditions, not just an error or a one-time occurrence. I also have a constraint that a speed has to be over a certain limit to be counted, for the same reason.
Here is the example on a speed array:
v = [8.0, 1.3, 0.7, 0.8, 0.9, 1.1, 14.9, 14.0, 14.1, 14.2, 14.3, 13.8, 13.9, 13.7, 13.6, 13.5, 13.4, 15.7, 15.8, 15.0, 15.3, 15.4, 15.5, 15.6, 15.2, 12.8, 12.7, 12.6, 8.7, 8.8, 8.6, 9.0, 8.5, 8.4, 8.3, 0.1, 0.0, 16.4, 16.5, 16.7, 16.8, 17.0, 17.1, 17.8, 17.7, 17.6, 17.4, 17.5, 17.3, 17.9, 18.2, 18.3, 18.1, 18.0, 18.4, 18.5, 18.6, 19.0, 19.1, 18.9, 19.2, 19.3, 19.9, 20.1, 19.8, 20.0, 19.7, 19.6, 19.5, 20.2, 20.3, 18.7, 18.8, 17.2, 16.9, 11.5, 11.2, 11.3, 11.4, 7.1, 12.9, 14.4, 13.1, 13.2, 12.5, 12.1, 12.2, 13.0, 0.2, 3.6, 7.4, 4.6, 4.5, 4.3, 4.0, 9.4, 9.6, 9.7, 5.8, 5.7, 7.3, 2.1, 0.4, 0.3, 16.1, 11.9, 12.0, 11.7, 11.8, 10.0, 10.1, 9.8, 15.1, 14.7, 14.8, 10.2, 10.3, 1.2, 9.9, 1.9, 3.4, 14.6, 0.6, 5.1, 5.2, 7.5, 19.4, 10.7, 10.8, 10.9, 0.5, 16.3, 16.2, 16.0, 16.6, 12.4, 11.0, 1.7, 1.6, 2.4, 11.6, 3.9, 3.8, 14.5, 11.1]
This is my code to find what I define as the true maximum speed:
import numpy as np
from collections import Counter

def true_max_speed(speeds):  # hypothetical wrapper name: the returns below need an enclosing function
    while max(speeds) > 30:
        speeds.remove(max(speeds))
    nwsp = []
    for s in speeds:
        nwsp.append(np.floor(s))
    count = Counter(nwsp)
    while speeds and max(speeds) > 14 and count[np.floor(max(speeds))] < 10:
        speeds.remove(max(speeds))
    while speeds and max(speeds) < 5:
        speeds.remove(max(speeds))
    if speeds:
        print max(speeds)
        return max(speeds)
    else:
        return False
Result with v as shown over: 19.9
The reason that I make nwsp is that it doesn't matter to me if, for example, 19.6 is only found 9 times: if other numbers inside the same integer, for example 19.7, are found as well, then 19.6 will be valid.
How can I rewrite/optimize this code so the selection process is quicker? I already removed the max(speeds) and instead sorted the list and referenced the largest element using speeds[-1].
Sorry for not adding any unit to my speeds.
Your code is slow because you call max and remove over and over again, and each of those calls costs time proportional to the length of the list. Any reasonable solution will be much faster.
If you know that False can't happen, then this suffices:
speeds = [8.0, 1.3, 0.7, 0.8, 0.9, 1.1, 14.9, 14.0, 14.1, 14.2, 14.3, 13.8, 13.9, 13.7, 13.6, 13.5, 13.4, 15.7, 15.8, 15.0, 15.3, 15.4, 15.5, 15.6, 15.2, 12.8, 12.7, 12.6, 8.7, 8.8, 8.6, 9.0, 8.5, 8.4, 8.3, 0.1, 0.0, 16.4, 16.5, 16.7, 16.8, 17.0, 17.1, 17.8, 17.7, 17.6, 17.4, 17.5, 17.3, 17.9, 18.2, 18.3, 18.1, 18.0, 18.4, 18.5, 18.6, 19.0, 19.1, 18.9, 19.2, 19.3, 19.9, 20.1, 19.8, 20.0, 19.7, 19.6, 19.5, 20.2, 20.3, 18.7, 18.8, 17.2, 16.9, 11.5, 11.2, 11.3, 11.4, 7.1, 12.9, 14.4, 13.1, 13.2, 12.5, 12.1, 12.2, 13.0, 0.2, 3.6, 7.4, 4.6, 4.5, 4.3, 4.0, 9.4, 9.6, 9.7, 5.8, 5.7, 7.3, 2.1, 0.4, 0.3, 16.1, 11.9, 12.0, 11.7, 11.8, 10.0, 10.1, 9.8, 15.1, 14.7, 14.8, 10.2, 10.3, 1.2, 9.9, 1.9, 3.4, 14.6, 0.6, 5.1, 5.2, 7.5, 19.4, 10.7, 10.8, 10.9, 0.5, 16.3, 16.2, 16.0, 16.6, 12.4, 11.0, 1.7, 1.6, 2.4, 11.6, 3.9, 3.8, 14.5, 11.1]
from collections import Counter
count = Counter(map(int, speeds))
print max(s for s in speeds
if 5 <= s <= 30 and (s <= 14 or count[int(s)] >= 10))
If the False case can happen, this would be one way:
from collections import Counter
count = Counter(map(int, speeds))
valids = [s for s in speeds
if 5 <= s <= 30 and (s <= 14 or count[int(s)] >= 10)]
print max(valids) if valids else False
Or sort and use next, which can take your False as default:
count = Counter(map(int, speeds))
print next((s for s in reversed(sorted(speeds))
if 5 <= s <= 30 and (s <= 14 or count[int(s)] >= 10)),
False)
Instead of Counter, you could also use groupby:
from itertools import *
groups = (list(group) for _, group in groupby(reversed(sorted(speeds)), int))
print next((s[0] for s in groups
if 5 <= s[0] <= 30 and (s[0] <= 14 or len(s) >= 10)),
False)
Just in case all of these look odd to you, here's one close to your original. Just looking at the speeds from fastest to slowest and returning the first that matches the requirements:
def f(speeds):
    count = Counter(map(int, speeds))
    for speed in reversed(sorted(speeds)):
        if 5 <= speed <= 30 and (speed <= 14 or count[int(speed)] >= 10):
            return speed
    return False
Btw, your definition of "the true maximum speed" seems rather odd to me. How about just looking at a certain percentile? Maybe like this:
print sorted(speeds)[len(speeds) * 9 // 10]
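Equivalently with NumPy (my assumption that NumPy is acceptable here; note that np.percentile interpolates by default, so it can differ slightly from the index-based version above):
import numpy as np
print np.percentile(speeds, 90)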
I'm not sure if this is faster, but it is shorter, and I think it achieves your requirements. It uses Counter.
from collections import Counter
import math

def valid(item):
    speed, count = item
    return speed <= 30 and (speed <= 13 or count >= 10)

speeds = [4, 3, 1, 3, 4, 5, 6, 7, 14, 16, 18, 19, 20, 34, 5, 4, 3, 2, 12, 58, 14, 14, 14]
speeds = map(math.floor, speeds)
counts = Counter(speeds)
max_valid_speed = max(filter(valid, counts.items()))
Result: max_valid_speed == (12,1)
Using your sort idea, we can start at the end of the list with the numbers less than 30, returning the first number that matches the criteria, or False:
from collections import Counter

def f(speeds):
    # get speeds that satisfy the range
    rev = [speed for speed in speeds if 5 <= speed < 30]
    rev.sort(reverse=True)
    c = Counter(int(v) for v in rev)
    for speed in rev:
        # will hit the highest numbers first,
        # so return the first that matches
        if speed > 14 and c[int(speed)] > 9 or speed < 15:
            return speed
    # we did not find any speed that matched our requirements
    return False
Output for your list v:
In [70]: f(v)
Out[70]: 19.9
Without sorting you could use a dict; which is best will depend on what your data is like. This works for all cases, including an empty list:
from collections import defaultdict

def f_dict(speeds):
    d = defaultdict(lambda: defaultdict(lambda: 0, {}))
    for speed in speeds:
        key = int(speed)
        d[key]["count"] += 1
        if speed > d[key]["speed"]:
            d[key]["speed"] = speed
    filt = max(filter(lambda x: (15 <= x[0] < 30 and
                                 x[1]["count"] > 9 or x[0] < 15), d.items()),
               default=False)
    return filt[1]["speed"] if filt else False
Output:
In [95]: f_dict(v)
Out[95]: 19.9