I have a DataFrame with the following schema:
+--------------------+--------+-------+-------+-----------+--------------------+
| userid|datadate| runid|variant|device_type| prediction|
+--------------------+--------+-------+-------+-----------+--------------------+
|0001d15b-e2da-4f4...|20220111|1196752| 1| Mobile| 0.8827571312010658|
|00021723-2a0d-497...|20220111|1196752| 1| Mobile| 0.30763173370229735|
|00021723-2a0d-497...|20220111|1196752| 0| Mobile| 0.5336206154783815|
I would like to perform the following operation:
For each "runid" and each "device_type", do some calculations involving variant==1 and variant==0, including a resampling loop.
The ultimate goal is to store these calculations in another DataFrame.
So in a naive approach the code would look like this:
for runid in df.select('runid').distinct().rdd.flatMap(list).collect():
    for device in ["Mobile", "Desktop"]:
        a_variant = df.filter((df.runid == runid) & (df.device_type == device) & (df.variant == 0))
        b_variant = df.filter((df.runid == runid) & (df.device_type == device) & (df.variant == 1))
        ## do some more calculations here
        # bootstrap loop:
        for samp in range(100):
            sampled_vector_a = a_variant.select("prediction").sample(withReplacement=True, fraction=1.0, seed=123)
            sampled_vector_b = b_variant.select("prediction").sample(withReplacement=True, fraction=1.0, seed=123)
            ## do some more calculations here
        ## do some more calculations here
        ## store calculations in a new DataFrame
Currently the process is too slow.
How can I optimize this process by utilizing Spark in the best way?
Thanks!
Here is a way to sample from each group in a dataframe after applying groupBy.
from pyspark.sql import SparkSession
import pandas as pd

spark = SparkSession.builder.appName("Demo").getOrCreate()

data = [["uid1","runid1",1,"Mobile",0.8],["uid2","runid1",1,"Mobile",0.3],
        ["uid3","runid1",0,"Mobile",0.5],["uid4","runid2",0,"Mobile",0.7],
        ["uid5","runid2",0,"Mobile",0.9]]
columns = ["userid","runid","variant","device_type","prediction"]

df = spark.createDataFrame(data, columns)
df.show()
# +------+------+-------+-----------+----------+
# |userid| runid|variant|device_type|prediction|
# +------+------+-------+-----------+----------+
# | uid1|runid1| 1| Mobile| 0.8|
# | uid2|runid1| 1| Mobile| 0.3|
# | uid3|runid1| 0| Mobile| 0.5|
# | uid4|runid2| 0| Mobile| 0.7|
# | uid5|runid2| 0| Mobile| 0.9|
# +------+------+-------+-----------+----------+
Define a sampling function that is going to be called by applyInPandas. The function my_sample extracts one sample (row) from each group's dataframe:
def my_sample(key, df):
    x = df.sample(n=1)
    return x
applyInPandas also needs a schema for its output; since the function returns the whole group, the output has the same fields as df:
from pyspark.sql.types import *

schema = StructType([StructField('userid', StringType()),
                     StructField('runid', StringType()),
                     StructField('variant', LongType()),
                     StructField('device_type', StringType()),
                     StructField('prediction', DoubleType())])
Just to check, group the data; there are three groups:
df.groupby("runid", "device_type", "variant").mean("prediction").show()
# +------+-----------+-------+---------------+
# | runid|device_type|variant|avg(prediction)|
# +------+-----------+-------+---------------+
# |runid1| Mobile| 0| 0.5|
# |runid1| Mobile| 1| 0.55|
# |runid2| Mobile| 0| 0.8|
# +------+-----------+-------+---------------+
Now apply my_sample to each group using applyInPandas:
df.groupby("runid","device_type","variant").applyInPandas(my_sample, schema=schema).show()
# +------+------+-------+-----------+----------+
# |userid| runid|variant|device_type|prediction|
# +------+------+-------+-----------+----------+
# | uid3|runid1| 0| Mobile| 0.5|
# | uid2|runid1| 1| Mobile| 0.3|
# | uid4|runid2| 0| Mobile| 0.7|
# +------+------+-------+-----------+----------+
Note: I used applyInPandas since pyspark.sql.GroupedData.apply is deprecated.
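For the bootstrap use case in the question, the entire resampling loop can also be moved inside the grouped function, so each (runid, device_type) group is processed in a single pass on the executors. A minimal sketch, assuming the demo df above and a made-up mean-difference statistic (replace it with your own calculations):
import pandas as pd
from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType

# one output row per bootstrap sample per (runid, device_type) group
boot_schema = StructType([StructField('runid', StringType()),
                          StructField('device_type', StringType()),
                          StructField('sample_id', LongType()),
                          StructField('mean_diff', DoubleType())])  # hypothetical statistic

def bootstrap_group(key, pdf):
    runid, device = key
    a = pdf.loc[pdf.variant == 0, 'prediction']
    b = pdf.loc[pdf.variant == 1, 'prediction']
    rows = []
    for samp in range(100):
        sampled_a = a.sample(frac=1.0, replace=True)
        sampled_b = b.sample(frac=1.0, replace=True)
        rows.append((runid, device, samp, float(sampled_b.mean() - sampled_a.mean())))
    return pd.DataFrame(rows, columns=['runid', 'device_type', 'sample_id', 'mean_diff'])

boot_df = df.groupby('runid', 'device_type').applyInPandas(bootstrap_group, schema=boot_schema)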
I have a df like this one:
df = spark.createDataFrame(
[("1", "Apple", "cat"), ("2", "2.", "house"), ("3", "<strong>text</strong>", "HeLlo 2.5")],
["id", "text1", "text2"])
+---+---------------------+---------+
| id| text1| text2|
+---+---------------------+---------+
| 1| Apple| cat|
| 2| 2.| house|
| 3|<strong>text</strong>|HeLlo 2.5|
+---+---------------------+---------+
and multiple functions to clean text, like
from lxml import html, etree  # assuming lxml for the HTML parsing
import re

def remove_html_tags(text):
    document = html.fromstring(text)
    return " ".join(etree.XPath("//text()")(document))

def lowercase(text):
    return text.lower()

def remove_wrong_dot(text):
    return re.sub(r'(?<!\d)[.,;:]|[.,;:](?!\d)', ' ', text)
and a list of columns to clean
COLS = ["text1", "text2"]
I would like to apply the functions to the columns in the list and also keep the original text
+---+---------------------+-----------+---------+-----------+
| id| text1|text1_clean| text2|text2_clean|
+---+---------------------+-----------+---------+-----------+
| 1| Apple| apple| cat| cat|
| 2| 2.| 2| house| house|
| 3|<strong>text</strong>| text|HeLlo 2.5| hello 2.5|
+---+---------------------+-----------+---------+-----------+
I already have an approach using UDF but it is not very efficient. I've been trying something like:
rdds = []
for col in COLS:
    rdd = df.rdd.map(lambda x: (x[col], lowercase(x[col])))
    rdds.append(rdd.collect())
return df
My idea would be to join all the rdds in the list, but I don't know how efficient this would be or how to chain more functions.
I appreciate any ideas or suggestions.
EDIT: Not all transformations can be done with regexp_replace. For example, the text can include nested HTML tags, in which case a simple replace wouldn't work; also, I don't want to replace all dots, only those at the beginning or end of substrings.
Spark built-in functions can do all the transformations you want:
from pyspark.sql import functions as F

cols = ["text1", "text2"]

for c in cols:
    df = (df
          .withColumn(f'{c}_clean', F.lower(c))
          .withColumn(f'{c}_clean', F.regexp_replace(f'{c}_clean', '<[^>]+>', ''))
          .withColumn(f'{c}_clean', F.regexp_replace(f'{c}_clean', r'(?<!\d)[.,;:]|[.,;:](?!\d)', ''))
          )
+---+--------------------+---------+-----------+-----------+
| id| text1| text2|text1_clean|text2_clean|
+---+--------------------+---------+-----------+-----------+
| 1| Apple| cat| apple| cat|
| 2| 2.| house| 2| house|
| 3|<strong>text</str...|HeLlo 2.5| text| hello 2.5|
+---+--------------------+---------+-----------+-----------+
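If some of the cleaning genuinely needs the original Python logic (e.g. the nested HTML case mentioned in the edit), a pandas UDF is usually a faster fallback than a row-at-a-time UDF or an RDD map. A rough sketch, assuming the three cleaning functions from the question are in scope:
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType

@pandas_udf(StringType())
def clean_text(s: pd.Series) -> pd.Series:
    # chain the question's helpers over a whole batch of values at once
    return s.map(lambda t: remove_wrong_dot(lowercase(remove_html_tags(t))))

for c in COLS:
    df = df.withColumn(f'{c}_clean', clean_text(c))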
I have the below list of dictionaries:
results = [
    {
        "type": "check_datatype",
        "kwargs": {
            "table": "cars", "column_name": "vin", "d_type": "string"
        }
    },
    {
        "type": "check_emptystring",
        "kwargs": {
            "table": "cars", "column_name": "vin"
        }
    },
    {
        "type": "check_null",
        "kwargs": {
            "table": "cars", "columns": ["vin", "index"]
        }
    }
]
I want to create two different PySpark DataFrames with the below schemas -
The args_id column in the results table will be the same when we have a unique pair of (type, kwargs). This JSON has to be run on a daily basis, and hence if it finds the same pair of (type, kwargs) again, it should give the same args_id value.
So far, I have written this code:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Window
check_type_results = [[elt['type']] for elt in results]
checkColumns = ['type']
spark = SparkSession.builder.getOrCreate()
checkResultsDF = spark.createDataFrame(data=check_type_results, schema=checkColumns)
checkResultsDF = checkResultsDF.withColumn("time", F.current_timestamp())
checkResultsDF = checkResultsDF.withColumn("args_id", F.row_number().over(Window.orderBy(F.monotonically_increasing_id())))
checkResultsDF.printSchema()
Now, with my code, I always get args_id in increasing order, which is correct for the first run. But if I run the JSON again the next day (or even on the same day) and some pair of (type, kwargs) appears that has already come before, I should use the same args_id for that pair.
If some pair (type, kwargs) has no entry in the Arguments table, only then will I insert into the Arguments table; if the pair (type, kwargs) already exists there, no insert should happen.
Once these two dataframes are filled properly, I want to load them into separate delta tables.
The hashcode column in the Arguments table is a unique identifier for each "kwargs".
Issues
Your schema is a bit incomplete. A more detailed schema will let you take advantage of more Spark features. See the solutions below using Spark SQL and PySpark. Instead of window functions that require ordered partitions, you may take advantage of the table-generating array functions such as explode and posexplode available in Spark SQL. As for writing to your delta tables, see the Delta Batch Writes reference at the end.
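For reference, posexplode emits each array element together with its position; that position is what supplies a stable args_id per entry in the queries below. A minimal illustration, assuming the sparkSession created in the Setup below:
sparkSession.sql("SELECT posexplode(array('a','b','c'))").show()
# +---+---+
# |pos|col|
# +---+---+
# |  0|  a|
# |  1|  b|
# |  2|  c|
# +---+---+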
Solution 1 : Using Spark SQL
Setup
from pyspark.sql.types import ArrayType,StructType, StructField, StringType, MapType
from pyspark.sql import Row, SparkSession
sparkSession = SparkSession.builder.appName("Demo").getOrCreate()
Schema Definition
Your sample record is an array of structs/objects where kwargs is a MapType with optional keys. NB: the True flag marks each field as nullable, which should help when there are missing keys or entries with different formats.
schema = StructType([
    StructField("entry", ArrayType(
        StructType([
            StructField("type", StringType(), True),
            StructField("kwargs", MapType(StringType(), StringType()), True)
        ])
    ), True)
])
Reproducible Example
result_entry = [
    {
        "type": "check_datatype",
        "kwargs": {
            "table": "cars", "column_name": "vin", "d_type": "string"
        }
    },
    {
        "type": "check_emptystring",
        "kwargs": {
            "table": "cars", "column_name": "vin"
        }
    },
    {
        "type": "check_null",
        "kwargs": {
            "table": "cars", "columns": ["vin", "index"]
        }
    }
]
df_results = sparkSession.createDataFrame([Row(entry=result_entry)],schema=schema)
df_results.createOrReplaceTempView("df_results")
df_results.show()
Results
+--------------------+
| entry|
+--------------------+
|[{check_datatype,...|
+--------------------+
Results Table Generation
I've used current_date to capture the current date; however, you may change this based on your pipeline.
results_table = sparkSession.sql("""
WITH raw_results as (
SELECT
posexplode(entry),
current_date as time
FROM
df_results
)
SELECT
col.type as Type,
time,
pos as arg_id
FROM
raw_results
""")
results_table.show()
Results
+-----------------+----------+------+
| Type| time|arg_id|
+-----------------+----------+------+
| check_datatype|2021-03-31| 0|
|check_emptystring|2021-03-31| 1|
| check_null|2021-03-31| 2|
+-----------------+----------+------+
Arguments Table Generation
args_table = sparkSession.sql("""
WITH raw_results as (
SELECT
posexplode(entry)
FROM
df_results
),
raw_arguments AS (
SELECT
explode(col.kwargs),
pos as args_id
FROM
raw_results
),
raw_arguments_before_array_check AS (
SELECT
args_id,
key as bac_key,
value as bac_value
FROM
raw_arguments
),
raw_arguments_after_array_check AS (
SELECT
args_id,
bac_key,
bac_value,
posexplode(split(regexp_replace(bac_value,"[\\\[\\\]]",""),","))
FROM
raw_arguments_before_array_check
)
SELECT
args_id,
bac_key as key,
col as value,
CASE
WHEN bac_value LIKE '[%' THEN pos
ELSE NULL
END as list_index,
abs(hash(args_id, bac_key,col,pos)) as hashcode
FROM
raw_arguments_after_array_check
""")
args_table.show()
Results
+-------+-----------+------+----------+----------+
|args_id| key| value|list_index| hashcode|
+-------+-----------+------+----------+----------+
| 0| d_type|string| null| 216841494|
| 0|column_name| vin| null| 502458545|
| 0| table| cars| null|1469121505|
| 1|column_name| vin| null| 604007568|
| 1| table| cars| null| 784654488|
| 2| columns| vin| 0|1503105124|
| 2| columns| index| 1| 454389776|
| 2| table| cars| null| 858757332|
+-------+-----------+------+----------+----------+
Solution 2: Using UDF
You may also define user-defined functions with your already-implemented Python logic and apply them with Spark.
Setup
We will define the functions that create the results and arguments tables here. I have chosen to write them as generator functions, but this is optional.
result_entry = [
    {
        "type": "check_datatype",
        "kwargs": {
            "table": "cars", "column_name": "vin", "d_type": "string"
        }
    },
    {
        "type": "check_emptystring",
        "kwargs": {
            "table": "cars", "column_name": "vin"
        }
    },
    {
        "type": "check_null",
        "kwargs": {
            "table": "cars", "columns": ["vin", "index"]
        }
    }
]
import json
result_entry_str = json.dumps(result_entry)
result_entry_str
def extract_results_table(entry, current_date=None):
    if current_date is None:
        from datetime import date
        current_date = str(date.today())
    if type(entry) == str:
        import json
        entry = json.loads(entry)
    for arg_id, arg in enumerate(entry):
        yield {
            "Type": arg["type"],
            "time": current_date,
            "args_id": arg_id
        }
def extract_arguments_table(entry):
    if type(entry) == str:
        import json
        entry = json.loads(entry)
    for arg_id, arg in enumerate(entry):
        if "kwargs" in arg:
            for arg_entry in arg["kwargs"]:
                orig_key, orig_value = arg_entry, arg["kwargs"][arg_entry]
                if type(orig_value) == list:
                    for list_index, value in enumerate(orig_value):
                        yield {
                            "args_id": arg_id,
                            "key": orig_key,
                            "value": value,
                            "list_index": list_index,
                            "hash_code": hash((arg_id, orig_key, value, list_index))
                        }
                else:
                    yield {
                        "args_id": arg_id,
                        "key": orig_key,
                        "value": orig_value,
                        "list_index": None,
                        "hash_code": hash((arg_id, orig_key, orig_value, "null"))
                    }
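Since these are plain Python generators, they can be sanity-checked locally before being wired into UDFs, for example (dates and hash values will vary per run):
list(extract_results_table(result_entry_str))
# e.g. [{'Type': 'check_datatype', 'time': '2021-03-31', 'args_id': 0}, ...]
list(extract_arguments_table(result_entry_str))[:2]
# e.g. [{'args_id': 0, 'key': 'table', 'value': 'cars', 'list_index': None, 'hash_code': ...}, ...]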
Pyspark Setup
from pyspark.sql.functions import udf, col, explode
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, ArrayType

results_table_schema = ArrayType(StructType([
    StructField("Type", StringType(), True),
    StructField("time", StringType(), True),
    StructField("args_id", IntegerType(), True)
]), True)

arguments_table_schema = ArrayType(StructType([
    StructField("args_id", IntegerType(), True),
    StructField("key", StringType(), True),
    StructField("value", StringType(), True),
    StructField("list_index", IntegerType(), True),
    # NB: the generator yields the key "hash_code"; because this field is named "hash",
    # that column comes back null in the output further below. Align the names to get real hashes.
    StructField("hash", StringType(), True)
]), True)

extract_results_table_udf = udf(lambda entry, current_date=None: [*extract_results_table(entry, current_date)], results_table_schema)
extract_arguments_table_udf = udf(lambda entry: [*extract_arguments_table(entry)], arguments_table_schema)

# this is useful if you intend to use your functions in spark-sql
sparkSession.udf.register('extract_results_table', extract_results_table_udf)
sparkSession.udf.register('extract_arguments_table', extract_arguments_table_udf)
Spark Data Frame
df_results_1 = sparkSession.createDataFrame([Row(entry=result_entry_str)],schema="entry string")
df_results_1.createOrReplaceTempView("df_results_1")
df_results_1.show()
Extracting Results Table
# Using Spark SQL
sparkSession.sql("""
WITH results_table AS (
select explode(extract_results_table(entry)) as entry FROM df_results_1
)
SELECT entry.* from results_table
""").show()
# Just python
df_results_1.select(
explode(extract_results_table_udf(df_results_1.entry)).alias("entry")
).selectExpr("entry.*").show()
Output
+-----------------+----------+-------+
| Type| time|args_id|
+-----------------+----------+-------+
| check_datatype|2021-03-31| 0|
|check_emptystring|2021-03-31| 1|
| check_null|2021-03-31| 2|
+-----------------+----------+-------+
+-----------------+----------+-------+
| Type| time|args_id|
+-----------------+----------+-------+
| check_datatype|2021-03-31| 0|
|check_emptystring|2021-03-31| 1|
| check_null|2021-03-31| 2|
+-----------------+----------+-------+
Extracting Arguments Table
# Using spark sql
sparkSession.sql("""
WITH arguments_table AS (
select explode(extract_arguments_table(entry)) as entry FROM df_results_1
)
SELECT entry.* from arguments_table
""").show()
# Just python
df_results_1.select(
explode(extract_arguments_table_udf(df_results_1.entry)).alias("entry")
).selectExpr("entry.*").show()
Output
+-------+-----------+------+----------+----+
|args_id| key| value|list_index|hash|
+-------+-----------+------+----------+----+
| 0| table| cars| null|null|
| 0|column_name| vin| null|null|
| 0| d_type|string| null|null|
| 1| table| cars| null|null|
| 1|column_name| vin| null|null|
| 2| table| cars| null|null|
| 2| columns| vin| 0|null|
| 2| columns| index| 1|null|
+-------+-----------+------+----------+----+
+-------+-----------+------+----------+----+
|args_id| key| value|list_index|hash|
+-------+-----------+------+----------+----+
| 0| table| cars| null|null|
| 0|column_name| vin| null|null|
| 0| d_type|string| null|null|
| 1| table| cars| null|null|
| 1|column_name| vin| null|null|
| 2| table| cars| null|null|
| 2| columns| vin| 0|null|
| 2| columns| index| 1|null|
+-------+-----------+------+----------+----+
Reference
Spark SQL Functions
Delta Batch Writes
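For the "insert only new (type, kwargs) pairs" requirement, a Delta MERGE keyed on the hashcode column is one option. A minimal sketch, assuming the delta-spark package is available, the target tables already exist, and hypothetical storage paths:
from delta.tables import DeltaTable

# results can simply be appended on each daily run (hypothetical path)
results_table.write.format("delta").mode("append").save("/mnt/delta/results")

# arguments: insert only rows whose hashcode is not already present (hypothetical path)
target = DeltaTable.forPath(sparkSession, "/mnt/delta/arguments")
(target.alias("t")
 .merge(args_table.alias("s"), "t.hashcode = s.hashcode")
 .whenNotMatchedInsertAll()
 .execute())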
I'm trying to replicate the following SAS code in PySpark:
PROC RANK DATA = aud_baskets OUT = aud_baskets_ranks GROUPS=10 TIES=HIGH;
BY customer_id;
VAR expenditure;
RANKS basket_rank;
RUN;
The idea is to rank all expenditures under each customer_id block. The data would look like this:
+-----------+--------------+-----------+
|customer_id|transaction_id|expenditure|
+-----------+--------------+-----------+
| A| 1| 34|
| A| 2| 90|
| B| 1| 89|
| A| 3| 6|
| B| 2| 8|
| B| 3| 7|
| C| 1| 96|
| C| 2| 9|
+-----------+--------------+-----------+
In PySpark, I tried this:
from pyspark.sql import Window
from pyspark.sql.functions import col, ntile

spendWindow = Window.partitionBy('customer_id').orderBy(col('expenditure').asc())
aud_baskets_ranks = aud_baskets.withColumn('basket_rank', ntile(10).over(spendWindow))
The problem is that PySpark doesn't let the user change the way it handles ties, like SAS does (as far as I know). I need to set this behavior in PySpark so that tied values are moved up to the next tier each time one of those edge cases occurs, as opposed to dropping them to the rank below.
Or is there a way to custom-write this approach?
Use dense_rank: it will give the same rank in case of ties, and the next rank will not be skipped.
The ntile function splits the records of each partition into n parts; in your case, n is 10.
from pyspark.sql import Window
from pyspark.sql.functions import col, dense_rank

spendWindow = Window.partitionBy('customer_id').orderBy(col('expenditure').asc())
aud_baskets_ranks = aud_baskets.withColumn('basket_rank', dense_rank().over(spendWindow))
Try the following code. It was generated by an automated tool called SPROCKET. It should take care of ties.
from pyspark.sql import Window
from pyspark.sql.functions import asc, expr, rank

df = aud_baskets
for (colToRank, rankedName) in zip(['expenditure'], ['basket_rank']):
    wA = Window.orderBy(asc(colToRank))
    df_w_rank = df.withColumn('raw_rank', rank().over(wA))
    ties = df_w_rank.groupBy('raw_rank').count().filter("""count > 1""")
    df_w_rank = (df_w_rank.join(ties, ['raw_rank'], 'left')
                 .withColumn(rankedName, expr("""case when count is not null
                                                 then (raw_rank + count - 1)
                                                 else raw_rank end""")))
    rankedNameGroup = rankedName
    n = df_w_rank.count()
    df_with_rank_groups = df_w_rank.withColumn(
        rankedNameGroup,
        expr("""FLOOR({rankedName}*{k}/({n}+1))""".format(k=10, n=n, rankedName=rankedName)))
    df = df_with_rank_groups
aud_baskets_ranks = df_with_rank_groups.drop('raw_rank', 'count')
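For comparison, here is a window-based sketch of the same TIES=HIGH decile idea, partitioned by customer_id as in the SAS BY statement. It is only a sketch using the column names from the question, not a drop-in replacement:
from pyspark.sql import Window
from pyspark.sql import functions as F

w_ord = Window.partitionBy('customer_id').orderBy(F.col('expenditure').asc())
w_ties = Window.partitionBy('customer_id', 'expenditure')
w_all = Window.partitionBy('customer_id')

# rank() assigns the lowest rank to ties; adding (tie count - 1) moves them to the highest rank (TIES=HIGH),
# and FLOOR(rank * 10 / (n + 1)) buckets the rows into 10 groups as PROC RANK GROUPS=10 does
aud_baskets_ranks = (aud_baskets
    .withColumn('low_rank', F.rank().over(w_ord))
    .withColumn('high_rank', F.col('low_rank') + F.count('*').over(w_ties) - 1)
    .withColumn('n', F.count('*').over(w_all))
    .withColumn('basket_rank', F.floor(F.col('high_rank') * 10 / (F.col('n') + 1)))
    .drop('low_rank', 'high_rank', 'n'))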
I wrote a PySpark implementation that reads row over row to incrementally (and recursively) multiply a column value in sequence. Due to platform limitations on our side, I need to convert this to Scala now without a UDAF. I looked at this implementation, but that one takes longer as the number of year_months grows, since it needs as many temp tables as there are year_months.
There are around 100 year_months and 70 departments, giving a total of 7000 rows in this dataframe. We need to take the starting value (the first year_month in the sequence) for each department and multiply it with the next row's value. The resulting product then needs to be multiplied with the next row, and so on.
Example data:
department, productivity_ratio, year_month
101,1.00,2013-01-01
101,0.98,2013-02-01
101,1.01,2013-03-01
101,0.99,2013-04-01
...
102,1.00,2013-01-01
102,1.02,2013-02-01
102,0.96,2013-03-01
...
Expected result:
department,productivity_ratio,year_month,chained_productivity_ratio
101,1.00,2013-01-01,1.00
101,0.98,2013-02-01,0.98 (1.00*0.98)
101,1.01,2013-03-01,0.9898 (1.00*0.98*1.01)
101,0.99,2013-04-01,0.9799 (1.00*0.98*1.01*0.99)
...
102,1.00,2013-01-01,1.00 (reset to 1.00 as starting point as department name changed in sequence)
102,1.02,2013-02-01,1.02 (1.00*1.02)
102,0.96,2013-03-01,0.9792 (1.00*1.02*0.96)
...
Is there any way to implement this faster in Scala, either by converting this into a loop over departments and treating productivity_ratio as a sequence to multiply with the previous value, or by changing the dataframe into a different data structure to avoid distributed sequencing problems?
Existing pyspark code:
%pyspark
import pandas as pd
import numpy as np
import StringIO
inputParquet = "s3://path/to/parquet/files/"
inputData = spark.read.parquet(inputParquet)
inputData.printSchema()
root
|-- department: string
|-- productivity_ratio: double
|-- year_month: date
inputSorted=inputData.sort('department', 'year_month')
inputSortedNotnull=inputSorted.dropna()
finalInput=inputSortedNotnull.toPandas()
prev_dept = 999
prev_productivity_ratio = 1
new_productivity_chained = []
for t in finalInput.itertuples():
    if prev_dept == t[1]:
        new_productivity_chained.append(t[2] * prev_productivity_ratio)
        prev_productivity_ratio = t[2] * prev_productivity_ratio
    else:
        prev_productivity_ratio = 1
        new_productivity_chained.append(prev_productivity_ratio)
    prev_dept = t[1]

productivityChained = finalInput.assign(chained_productivity=new_productivity_chained)
You can use a window function and compute exp(sum(log(<column>))) over it to calculate chained_productivity_ratio. Since all the functions used are Spark built-in functions, the performance will be great!
Example:
In Pyspark:
df.show()
#+----------+------------------+----------+
#|department|productivity_ratio|year_month|
#+----------+------------------+----------+
#| 101| 1.00|2013-01-01|
#| 101| 0.98|2013-02-01|
#| 101| 1.01|2013-03-01|
#| 101| 0.99|2013-04-01|
#| 102| 1.00|2013-01-01|
#| 102| 1.02|2013-02-01|
#| 102| 0.96|2013-03-01|
#+----------+------------------+----------+
from pyspark.sql.functions import *
from pyspark.sql import Window
w = Window.partitionBy("department").orderBy("year_month")
df.withColumn("chained_productivity_ratio",exp(sum(log(col("productivity_ratio"))).over(w))).show()
#+----------+------------------+----------+--------------------------+
#|department|productivity_ratio|year_month|chained_productivity_ratio|
#+----------+------------------+----------+--------------------------+
#| 101| 1.00|2013-01-01| 1.0|
#| 101| 0.98|2013-02-01| 0.98|
#| 101| 1.01|2013-03-01| 0.9898|
#| 101| 0.99|2013-04-01| 0.9799019999999999|
#| 102| 1.00|2013-01-01| 1.0|
#| 102| 1.02|2013-02-01| 1.02|
#| 102| 0.96|2013-03-01| 0.9792|
#+----------+------------------+----------+--------------------------+
In Scala:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._
val w = Window.partitionBy("department").orderBy("year_month")
df.withColumn("chained_productivity_ratio",exp(sum(log(col("productivity_ratio"))).over(w))).show()
//+----------+------------------+----------+--------------------------+
//|department|productivity_ratio|year_month|chained_productivity_ratio|
//+----------+------------------+----------+--------------------------+
//| 101| 1.00|2013-01-01| 1.0|
//| 101| 0.98|2013-02-01| 0.98|
//| 101| 1.01|2013-03-01| 0.9898|
//| 101| 0.99|2013-04-01| 0.9799019999999999|
//| 102| 1.00|2013-01-01| 1.0|
//| 102| 1.02|2013-02-01| 1.02|
//| 102| 0.96|2013-03-01| 0.9792|
//+----------+------------------+----------+--------------------------+
I have written a data preprocessing codes in Pandas UDF in PySpark. I'm using lambda function to extract a part of the text from all the records of a column.
Here is how my code looks like:
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("string", PandasUDFType.SCALAR)
def get_first_name(col):
    return col.apply(lambda x: x.split(',')[-1] if len(x.split(',')) > 0 else x)

df = df.withColumn('X', get_first_name(df.Y))
This is working fine and giving the desired results. But I need to write the same piece of logic in Spark equivalent code. Is there a way to do it? Thanks.
I think one function substring_index is enough for this particular task:
from pyspark.sql.functions import substring_index
df = spark.createDataFrame([(x,) for x in ['f,l', 'g', 'a,b,cd']], ['c1'])
df.withColumn('c2', substring_index('c1', ',', -1)).show()
+------+---+
| c1| c2|
+------+---+
| f,l| l|
| g| g|
|a,b,cd| cd|
+------+---+
Given the following DataFrame df:
df.show()
# +-------------+
# | BENF_NME|
# +-------------+
# | Doe, John|
# | Foo|
# |Baz, Quux,Bar|
# +-------------+
You can simply use regexp_extract() to select the first name:
from pyspark.sql.functions import regexp_extract
df.withColumn('First_Name', regexp_extract(df.BENF_NME, r'(?:.*,\s*)?(.*)', 1)).show()
# +-------------+----------+
# | BENF_NME|First_Name|
# +-------------+----------+
# | Doe, John| John|
# | Foo| Foo|
# |Baz, Quux,Bar| Bar|
# +-------------+----------+
If you don't care about possible leading spaces, substring_index() provides a simple alternative to your original logic:
from pyspark.sql.functions import substring_index
df.withColumn('First_Name', substring_index(df.BENF_NME, ',', -1)).show()
# +-------------+----------+
# | BENF_NME|First_Name|
# +-------------+----------+
# | Doe, John| John|
# | Foo| Foo|
# |Baz, Quux,Bar| Bar|
# +-------------+----------+
In this case the first row's First_Name has a leading space:
df.withColumn(...).collect()[0]
# Row(BENF_NME=u'Doe, John', First_Name=u' John')
If you still want to use a custom function, you need to create a user-defined function (UDF) using udf():
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
get_first_name = udf(lambda s: s.split(',')[-1], StringType())
df.withColumn('First_Name', get_first_name(df.BENF_NME)).show()
# +-------------+----------+
# | BENF_NME|First_Name|
# +-------------+----------+
# | Doe, John| John|
# | Foo| Foo|
# |Baz, Quux,Bar| Bar|
# +-------------+----------+
Note that UDFs are slower than the built-in Spark functions, especially Python UDFs.
You can do the same using when to implement if-then-else logic:
First split the column, then compute its size. If the size is greater than 0, take the last element from the split array. Otherwise, return the original column.
from pyspark.sql.functions import split, size, when
def get_first_name(col):
    col_split = split(col, ',')
    split_size = size(col_split)
    return when(split_size > 0, col_split[split_size - 1]).otherwise(col)
As an example, suppose you had the following DataFrame:
df.show()
#+---------+
#| BENF_NME|
#+---------+
#|Doe, John|
#| Madonna|
#+---------+
You can call the new function just as before:
df = df.withColumn('First_Name', get_first_name(df.BENF_NME))
df.show()
#+---------+----------+
#| BENF_NME|First_Name|
#+---------+----------+
#|Doe, John| John|
#| Madonna| Madonna|
#+---------+----------+