Flatten nested JSON string columns into tabular format - python

I am currently trying to flatten data in a Databricks table. Since some of the columns are deeply nested and are of 'String' type, I couldn't use the explode function.
My current dataframe looks like this:
display(df)
Row 1:
  account:     {"id":"1","name":"ABC","type":null}
  applied:     2500.00
  applylist:   {"apply":[{"total":20.00,"applyDate":"2021-07-13T07:00:00Z","currency":"USA","apply":true,"discAmt":null,"line":0,"type":"Invoice","amount":200.0},{"total":25.00,"applyDate":"2021-07-15T07:00:00Z","currency":"USA","apply":true,"discAmt":null,"line":0,"type":"Invoice","amount":25.0}],"replaceAll":false}
  aracct:      {"internalId":"121","name":"CMS","type":null}
  Internal Id: 101
Row 2:
  account:     {"id":"2","name":"DEF","type":null}
  applied:     1500.00
  applylist:   {"apply":[{"total":30.00,"applyDate":"2021-08-13T07:00:00Z","currency":"USA","apply":true,"discAmt":null,"line":0,"type":"Invoice","amount":250.0},{"total":35.00,"applyDate":"2021-09-13T07:00:00Z","currency":"USA","apply":true,"discAmt":null,"line":0,"type":"Invoice","amount":350.0}],"replaceAll":false}
  aracct:      {"internalId":"121","name":"BMS","type":null}
  Internal Id: 102
My dataframe schema looks like this:
df.printSchema()
|-- account: string (nullable = true)
|-- applied: decimal(38,6) (nullable = true)
|-- applylist: string (nullable = true)
|-- aracct: string (nullable = true)
How can I flatten the above table and store each record in a tabular (non-nested) format?
Expected Output:
account.id | account.name | account.type | applied | applylist.apply.total | applylist.apply.applydate | applylist.apply.currency | applylist.apply.apply | applylist.apply.discamount | applylist.apply.line | applylist.apply.type | applylist.apply.amount | applylist.replaceAll
1 | ABC | null | 2500.00 | 20.00 | 2021-07-13T07:00:00Z | USA | true | null | 0 | Invoice | 200.0 | false
2 | DEF | null | 1500.00 | 30.00 | 2021-08-13T07:00:00Z | USA | true | null | 0 | Invoice | 250.0 | false
This is my Scala code:
import org.apache.spark.sql.functions._
import spark.implicits._
val df = spark.sql("select * from ns_db_integration.transaction")
display(df.select($"applied" as "Applied", $"applylist", explode($"account"))
.withColumn("Account.Id" ,$"col.id")
.withColumn("Account.Name",$"col.name")
.withColumn("Account.Type",$"col.type").drop($"col")
.select($"*",$"applylist.*")
.drop($"applylist")
.select($"*",explode($"apply"))
.drop($"apply")
.withColumn("Total",$"col.total")
.withColumn("ApplyDate",$"col.applyDate")
.drop($"col")
)
The Scala code above fails with an error. I also tried the json_tuple function in PySpark, which didn't work as I expected: every applylist value comes back null (json_tuple only extracts top-level keys, so a nested path like apply.total finds no matching key).
from pyspark.sql.functions import json_tuple,from_json,get_json_object, explode,col
df.select(col("applied"),json_tuple(col("applylist"),"apply.total","apply.applyDate","apply.currency","apply.apply")) \
.toDF("applied","applylist.apply.total","applylist.apply.applyDate","applylist.apply.currency","applylist.apply.apply") \
.show(truncate=False)
Output of Pyspark Code
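(Side note: get_json_object can reach the nested fields with a JSONPath-style expression, though only one array element at a time. A rough sketch against the same df, hard-coding the first apply entry:)
from pyspark.sql.functions import col, get_json_object

# pull fields of the first element of the "apply" array out of the JSON string;
# this does not scale to an arbitrary number of array entries, which is why the
# from_json approach in the answer below is preferable
df.select(
    col("applied"),
    get_json_object("applylist", "$.apply[0].total").alias("total"),
    get_json_object("applylist", "$.apply[0].applyDate").alias("applyDate"),
    get_json_object("applylist", "$.apply[0].currency").alias("currency"),
    get_json_object("applylist", "$.apply[0].apply").alias("apply"),
).show(truncate=False)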

Using PySpark, see the logic below:
Input Data
str1 = """account applied applylist aracct
{"id":"1","name":"ABC","type":null} 2500.00 {"apply":[{"total":20.00,"applyDate":"2021-07-13T07:00:00Z","currency":"USA","apply":true,"discAmt":null,"line":0,"type":"Invoice","amount":200.0}],"replaceAll":false} {"internalId":"121","name":"CMS","type":null}
{"id":"2","name":"DEF","type":null} 1500.00 {"apply":[{"total":30.00,"applyDate":"2021-08-13T07:00:00Z","currency":"USA","apply":true,"discAmt":null,"line":0,"type":"Invoice","amount":250.0}],"replaceAll":false} {"internalId":"121","name":"BMS","type":null}"""
import pandas as pd
from io import StringIO
pdf = pd.read_csv(StringIO(str1), sep = '\t')
df = spark.createDataFrame(pdf)
df.show(truncate=False)
+-----------------------------------+-------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------+
|account |applied|applylist |aracct |
+-----------------------------------+-------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------+
|{"id":"1","name":"ABC","type":null}|2500.0 |{"apply":[{"total":20.00,"applyDate":"2021-07-13T07:00:00Z","currency":"USA","apply":true,"discAmt":null,"line":0,"type":"Invoice","amount":200.0}],"replaceAll":false}|{"internalId":"121","name":"CMS","type":null}|
|{"id":"2","name":"DEF","type":null}|1500.0 |{"apply":[{"total":30.00,"applyDate":"2021-08-13T07:00:00Z","currency":"USA","apply":true,"discAmt":null,"line":0,"type":"Invoice","amount":250.0}],"replaceAll":false}|{"internalId":"121","name":"BMS","type":null}|
+-----------------------------------+-------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------+
Required Output
from pyspark.sql.functions import *
from pyspark.sql.types import *
schema_account = StructType([
    StructField("id", StringType(), True),
    StructField("name", StringType(), True),
    StructField("type", StringType(), True)
])

df1 = (
    df.select(from_json(col("account"), schema_account).alias("account"),
              "applied",
              # alias the parsed map explicitly so it can be referenced as "entries" below
              from_json(col("applylist"), MapType(StringType(), StringType())).alias("entries"))
      .select("account.*", "applied", "entries.apply", "entries.replaceAll")
      .select("id", "name", "type", "applied",
              from_json(col("apply"), ArrayType(MapType(StringType(), StringType()))).alias("apply"),
              "replaceAll")
      .select("id", "name", "type", "applied", explode("apply").alias("apply"), "replaceAll")
      .select("id", "name", col("type").alias("type1"), "applied", explode("apply"), "replaceAll")
      .groupBy("id", "name", "type1", "applied", "replaceAll").pivot("key").agg(first("value"))
      .withColumnRenamed("id", "account.id")
      .withColumnRenamed("name", "account.name")
      .withColumnRenamed("type1", "account.type")
      .withColumnRenamed("total", "applylist.apply.total")
      .withColumnRenamed("applyDate", "applylist.apply.applyDate")
      .withColumnRenamed("currency", "applylist.apply.currency")
      .withColumnRenamed("apply", "applylist.apply.apply")
      .withColumnRenamed("discAmt", "applylist.apply.discAmt")
      .withColumnRenamed("line", "applylist.apply.line")
      .withColumnRenamed("type", "applylist.apply.type")
      .withColumnRenamed("amount", "applylist.apply.amount")
)
df1.select("`account.id`" ,"`account.name`" ,"`account.type`" ,"applied" ,"`applylist.apply.total`" ,"`applylist.apply.applyDate`" ,"`applylist.apply.currency`" ,"`applylist.apply.apply`" ,"`applylist.apply.discAmt`" ,"`applylist.apply.line`" ,"`applylist.apply.type`" ,"`applylist.apply.amount`" ,"`replaceAll`").show(truncate=False)
+----------+------------+------------+-------+---------------------+-------------------------+------------------------+---------------------+-----------------------+--------------------+--------------------+----------------------+----------+
|account.id|account.name|account.type|applied|applylist.apply.total|applylist.apply.applyDate|applylist.apply.currency|applylist.apply.apply|applylist.apply.discAmt|applylist.apply.line|applylist.apply.type|applylist.apply.amount|replaceAll|
+----------+------------+------------+-------+---------------------+-------------------------+------------------------+---------------------+-----------------------+--------------------+--------------------+----------------------+----------+
|1 |ABC |null |2500.0 |20.0 |2021-07-13T07:00:00Z |USA |true |null |0 |Invoice |200.0 |false |
|2 |DEF |null |1500.0 |30.0 |2021-08-13T07:00:00Z |USA |true |null |0 |Invoice |250.0 |false |
+----------+------------+------------+-------+---------------------+-------------------------+------------------------+---------------------+-----------------------+--------------------+--------------------+----------------------+----------+
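An alternative sketch that skips the MapType/pivot round-trip: give applylist an explicit struct schema, explode the parsed apply array, and expand the struct with .*. It reuses schema_account from above; the applylist field names and types below are guessed from the sample JSON, so adjust them to the real data.
from pyspark.sql.functions import col, explode, from_json
from pyspark.sql.types import (ArrayType, BooleanType, DoubleType, IntegerType,
                               StringType, StructField, StructType)

# schema guessed from the sample applylist JSON
schema_applylist = StructType([
    StructField("apply", ArrayType(StructType([
        StructField("total", DoubleType(), True),
        StructField("applyDate", StringType(), True),
        StructField("currency", StringType(), True),
        StructField("apply", BooleanType(), True),
        StructField("discAmt", DoubleType(), True),
        StructField("line", IntegerType(), True),
        StructField("type", StringType(), True),
        StructField("amount", DoubleType(), True)
    ])), True),
    StructField("replaceAll", BooleanType(), True)
])

flat = (
    df.select(from_json("account", schema_account).alias("account"),
              "applied",
              from_json("applylist", schema_applylist).alias("applylist"))
      # one output row per entry of the apply array
      .select(col("account.id").alias("account_id"),
              col("account.name").alias("account_name"),
              col("account.type").alias("account_type"),
              "applied",
              explode("applylist.apply").alias("apply"),
              col("applylist.replaceAll").alias("replaceAll"))
      # expand the exploded struct into individual columns
      .select("account_id", "account_name", "account_type", "applied", "apply.*", "replaceAll")
)
flat.show(truncate=False)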

Related

Flatten Map Type in Pyspark

I have a dataframe as below
+-------------+--------------+----+-------+-----------------------------------------------------------------------------------+
|empId |organization |h_cd|status |additional |
+-------------+--------------+----+-------+-----------------------------------------------------------------------------------+
|FTE:56e662f |CATENA |0 |CURRENT|{hr_code -> 84534, bgc_val -> 170187, interviewPanel -> 6372, meetingId -> 3671} |
|FTE:633e7bc |Data Science |0 |CURRENT|{hr_code -> 21036, bgc_val -> 170187, interviewPanel -> 764, meetingId -> 577} |
|FTE:d9badd2 |CATENA |0 |CURRENT|{hr_code -> 60696, bgc_val -> 88770} |
+-------------+--------------+----+-------+-----------------------------------------------------------------------------------+
I wanted to flatten it and create a dataframe as below -
+-------------+--------------+----+-------+------------+------------+-------------------+---------------+
|empId |organization |h_cd|status |hr_code |bgc_val |interviewPanel | meetingId |
+-------------+--------------+----+-------+------------+------------+-------------------+---------------+
|FTE:56e662f |CATENA |0 |CURRENT|84534 |170187 |6372 |3671 |
|FTE:633e7bc |Data Science |0 |CURRENT|21036 |170187 |764 |577 |
|FTE:d9badd2 |CATENA |0 |CURRENT|60696 |88770 |Null |Null |
+-------------+--------------+----+-------+------------+------------+-------------------+---------------+
My existing logic is as below
new_df = df.rdd.map(lambda x: (x.empId, x.h_cd, x.status, x.additional["hr_code"], x.additional["bgc_val"], x.additional["interviewPanel"], x.additional["meetingId"], x.category)) \
    .toDF(["emp_id","h_cd","status","hr_code","bgc_val","interview_panel","meeting_id","category"])
However, when I use this logic to create new_df and try to write the dataframe, I run into the error
org.apache.spark.api.python.PythonException: 'KeyError: 'interviewPanel''
This is because there is no additional.interviewPanel key in the map-type field for empId FTE:d9badd2.
Can someone suggest the best way to handle this, simply adding null to the dataframe when the key:val is not present in the map-type field?
Thanks in advance!!
You just need to use the getField function on the map column:
from pyspark.sql.functions import col

df = spark.createDataFrame([
    ("FTE:56e662f", "CATENA", 0, "CURRENT",
     {"hr_code": 84534, "bgc_val": 170187, "interviewPanel": 6372, "meetingId": 3671}),
    ("FTE:633e7bc", "Data Science", 0, "CURRENT",
     {"hr_code": 21036, "bgc_val": 170187, "interviewPanel": 764, "meetingId": 577}),
    ("FTE:d9badd2", "CATENA", 0, "CURRENT",
     {"hr_code": 60696, "bgc_val": 88770})],
    ["empId", "organization", "h_cd", "status", "additional"])

df.select("empId", "organization", "h_cd", "status",
          col("additional").getField("hr_code").alias("hr_code"),
          col("additional").getField("bgc_val").alias("bgc_val"),
          col("additional").getField("interviewPanel").alias("interviewPanel"),
          col("additional").getField("meetingId").alias("meetingId")
          ).show(truncate=False)
+-----------+------------+----+-------+-------+-------+--------------+---------+
|empId |organization|h_cd|status |hr_code|bgc_val|interviewPanel|meetingId|
+-----------+------------+----+-------+-------+-------+--------------+---------+
|FTE:56e662f|CATENA |0 |CURRENT|84534 |170187 |6372 |3671 |
|FTE:633e7bc|Data Science|0 |CURRENT|21036 |170187 |764 |577 |
|FTE:d9badd2|CATENA |0 |CURRENT|60696 |88770 |null |null |
+-----------+------------+----+-------+-------+-------+--------------+---------+
Just use getItem on the column. E.g.
import pyspark.sql.functions as F

df.select("*", F.col("additional").getItem("meetingId").alias("meetingId"))
You can also collect the key names in a list to avoid hardcoding them (useful when there are many keys).
allKeys = df.select(F.explode("additional")).select(F.collect_set("key").alias("key")).first().asDict().get("key")
df.select("*", *[F.col("additional").getItem(key).alias(key) for key in allKeys]).show()
Apache Spark Scala answer; you can translate it to PySpark as well (a rough translation is sketched after the output below).
AFAIK you need to explode, group, and pivot by key, like the example below:
import org.apache.spark.sql.functions._
val df= Seq(
( "FTE:56e662f", "CATENA", 0, "CURRENT", Map("hr_code" -> 84534, "bgc_val" -> 170187, "interviewPanel" -> 6372, "meetingId" -> 3671) ),
( "FTE:633e7bc", "Data Science", 0, "CURRENT", Map("hr_code" -> 21036, "bgc_val" -> 170187, "interviewPanel" -> 764, "meetingId" -> 577) ),
( "FTE:d9badd2", "CATENA", 0, "CURRENT", Map("hr_code" -> 60696, "bgc_val" -> 88770) )).toDF("empId", "organization", "h_cd", "status", "additional")
val explodeddf = df.select($"empId", $"organization", $"h_cd", $"status", explode($"additional"))
val grpdf = explodeddf.groupBy($"empId", $"organization", $"h_cd", $"status").pivot("key").agg(first("value"))
val finaldf = grpdf.selectExpr("empId", "organization", "h_cd", "status", "hr_code","bgc_val","interviewPanel", "meetingId")
finaldf.show
Output :
+-----------+------------+----+-------+-------+-------+--------------+---------+
| empId|organization|h_cd| status|bgc_val|hr_code|interviewPanel|meetingId|
+-----------+------------+----+-------+-------+-------+--------------+---------+
|FTE:633e7bc|Data Science| 0|CURRENT| 170187| 21036| 764| 577|
|FTE:d9badd2| CATENA| 0|CURRENT| 88770| 60696| null| null|
|FTE:56e662f| CATENA| 0|CURRENT| 170187| 84534| 6372| 3671|
+-----------+------------+----+-------+-------+-------+--------------+---------+
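A rough PySpark equivalent of the Scala logic above (a sketch, assuming the question's df with columns empId, organization, h_cd, status and the map column additional):
from pyspark.sql.functions import explode, first

# explode the map into (key, value) rows, then pivot the keys back into columns
exploded_df = df.select("empId", "organization", "h_cd", "status", explode("additional"))
grp_df = (exploded_df
          .groupBy("empId", "organization", "h_cd", "status")
          .pivot("key")
          .agg(first("value")))
grp_df.select("empId", "organization", "h_cd", "status",
              "hr_code", "bgc_val", "interviewPanel", "meetingId").show()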
You can use .* to expand a struct column into individual field columns (note this applies to StructType columns; for a MapType column like additional, use getField/getItem as shown above):
df.select("empId", "organization", "h_cd", "status", "additional.*")

How can I separate data where the number of fields does not match in each row using Spark

I am trying to filter out records that do not have the expected number of fields in a row. Below is my code:
no_of_rows_in_each_column = 3
delimiter = ","
input.csv
emp_id,emp_name,salary
1,"siva
Prasad",100
2,pavan,200,extra
3,prem,300
4,john
Expected output dataframes
Correct_input_data_frame
emp_id,emp_name,salary
1,"siva Prasad",100
3,prem,300
wrong_file.csv (it is a file):
emp_id,emp_name,salary,no_of_fields
2,pavan,200,extra,4 fields in row 3 fields expected
4,john, 2 fields in row 3 expected
I tried this; it seems to read the file, but the len() function does not work on rows:
input_df = (spark.read
.option("multiline", "true")
.option("quote", '"')
.option("header", "true")
.option("escape", "\\")
.option("escape", '"')
.csv('input.csv')
)
correct = input_df.(filter(len(row{}) = 3)
wrong_data = input_df.(filter(len(row{})<>3)
Add DROPMALFORMED mode to filter out "bad" lines:
df = (spark.read
.option("multiline", "true")
.option("mode", "DROPMALFORMED")
.option("ignoreLeadingWhiteSpace", False)
.option("quote", '"')
.option("header", "true")
.option("escape", "\\")
.option("escape", '"')
.csv('input.csv')
)
df.show()
+----------+-----------------+------+
| emp_id| emp_name|salary|
+----------+-----------------+------+
| 1|siva \n Prasad| 100|
| 3| prem| 300|
+----------+-----------------+------+
You should specify the schema; then you can use the columnNameOfCorruptRecord option.
I've implemented it using Scala, but the Python implementation should be similar (a rough PySpark sketch follows the output below).
val df = spark.read
.schema("emp_id Long, emp_name String, salary Long, corrupted_record String")
.option("columnNameOfCorruptRecord", "corrupted_record")
.option("multiline", "true")
.option("ignoreLeadingWhiteSpace", false)
.option("quote", "\"")
.option("header", "true")
.option("escape", "\\")
.csv("input.csv")
df.show()
The result is:
+------+------------+------+-----------------+
|emp_id| emp_name|salary| corrupted_record|
+------+------------+------+-----------------+
| 1|siva\nPrasad| 100| null|
| 2| pavan| 200|2,pavan,200,extra|
| 3| prem| 300| null|
| 4| john| null| 4,john|
+------+------------+------+-----------------+
Now, it is pretty straightforward to filter correct and wrong data:
val correctDF = df.filter(col("corrupted_record").isNull)
val wrongDF = df.filter(col("corrupted_record").isNotNull)
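A rough PySpark version of the same idea (a sketch, using the same schema and reader options as the Scala snippet):
from pyspark.sql.functions import col

df = (spark.read
      .schema("emp_id LONG, emp_name STRING, salary LONG, corrupted_record STRING")
      .option("columnNameOfCorruptRecord", "corrupted_record")
      .option("multiline", "true")
      .option("ignoreLeadingWhiteSpace", False)
      .option("quote", '"')
      .option("header", "true")
      .option("escape", "\\")
      .csv("input.csv"))

# rows whose corrupted_record is null parsed cleanly; the rest carry the raw line
correct_df = df.filter(col("corrupted_record").isNull()).drop("corrupted_record")
wrong_df = df.filter(col("corrupted_record").isNotNull())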

Pyspark - Python Set Same Timezone

I am reading some parquet files with timezone GMT-4:
from pyspark.sql import SparkSession

def get_spark():
    spark = SparkSession.builder.getOrCreate()
    spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")
    spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInRead", "LEGACY")
    spark.conf.set("spark.sql.session.timeZone", "GMT-4")
    return spark
Showing the data from the file:
base_so.where(base_so.ID_NUM_CLIENTE == 2273).show()
+--------------+-----------+----------------+------------------+-------------------+-------------------+-------------------+
|ID_NUM_CLIENTE|NUM_TRAMITE|COD_TIPO_1 |COD_TIPO_2 | FECHA_TRAMITE| FECHA_INGRESO| FECHA_INICIO_PAGO|
+--------------+-----------+----------------+------------------+-------------------+-------------------+-------------------+
| 2273| 238171| X| NN |2005-10-25 00:00:00|2005-10-25 09:26:54|1995-05-03 00:00:00|
| 2273| 238171| X| NMP|2005-10-25 00:00:00|2005-10-25 09:26:54|1995-05-03 00:00:00|
+--------------+-----------+----------------+------------------+-------------------+-------------------+-------------------+
When I create a dataframe in a test, it does not keep the same date in the timestamp column:
from datetime import datetime
from decimal import Decimal
from pyspark.sql.types import DecimalType, StringType, StructField, StructType, TimestampType

spark = get_spark()
df_busqueda = spark.createDataFrame(
    data=[
        [Decimal(2273), Decimal(238171), "SO", datetime.strptime('2005-10-25 00:00:00', '%Y-%m-%d %H:%M:%S')],
    ],
    schema=StructType([
        StructField('ID_NUM_CLIENTE', DecimalType(), True),
        StructField('NUM_TRAMITE', DecimalType(), True),
        StructField('COD_TIPO_1', StringType(), True),
        StructField('FECHA_TRAMITE', TimestampType(), True),
    ]),
)
+--------------+-----------+----------------+-------------------+
|ID_NUM_CLIENTE|NUM_TRAMITE|COD_TIPO_1 | FECHA_TRAMITE|
+--------------+-----------+----------------+-------------------+
| 2273| 238171| SO|2005-10-24 23:00:00|
+--------------+-----------+----------------+-------------------+
How can I configure this better so that both the parquet data and the dataframes I create keep the same timezone?
You can set the timezone in the spark session.
Example:
For Spark 3.0+:
spark.sql("SET TIME ZONE 'America/New_York'").show()
//+--------------------------+----------------+
//|key |value |
//+--------------------------+----------------+
//|spark.sql.session.timeZone|America/New_York|
//+--------------------------+----------------+
spark.sql("select current_timestamp()").show()
//+--------------------------+
//|current_timestamp() |
//+--------------------------+
//|2021-08-25 16:23:16.096459|
//+--------------------------+
For Spark < 3.0:
spark.conf.set("spark.sql.session.timeZone", "UTC")
spark.sql("select current_timestamp()").show()
//+--------------------+
//| current_timestamp()|
//+--------------------+
//|2021-08-25 20:26:...|
//+--------------------+
# Import packages
import os, time
from dateutil import tz
Format the timestamp with the following snippet:
os.environ['TZ'] = 'GMT+4'
time.tzset()
time.strftime('%X %x %Z')
In my case, the files were being uploaded via NiFi and I had to set the same time zone in the bootstrap configuration:
spark-defaults.conf
spark.driver.extraJavaOptions -Duser.timezone=America/Santiago
spark.executor.extraJavaOptions -Duser.timezone=America/Santiago

Bitwise operations in pyspark, without using udf

I have a Spark dataframe as shown below:
+---------+---------------------------+
|country |sports |
+---------+---------------------------+
|India |[Cricket, Hockey, Football]|
|Sri Lanka|[Cricket, Football] |
+---------+---------------------------+
Each sport in the sports column is represented by a code:
sport_to_code_map = {
'Cricket' : 0x0001,
'Hockey' : 0x0002,
'Football' : 0x0004
}
Now I want to add a new column named sportsInt, which is the bitwise OR of the codes associated with each sport string in the above map, resulting in:
+---------+---------------------------+---------+
|country |sports |sportsInt|
+---------+---------------------------+---------+
|India |[Cricket, Hockey, Football]|7 |
|Sri Lanka|[Cricket, Football] |5 |
+---------+---------------------------+---------+
I know one way to do this would be using a UDF, something like this:
from pyspark.sql.types import IntegerType
import pyspark.sql.functions as F

def get_sport_to_code(sport_name):
    sport_to_code_map = {
        'Cricket': 0x0001,
        'Hockey': 0x0002,
        'Football': 0x0004
    }
    if sport_name not in sport_to_code_map:
        raise Exception(f'Unknown Sport: {sport_name}')
    return sport_to_code_map.get(sport_name)

def sport_to_code(sports):
    if not sports:
        return None
    code = 0x0000
    for sport in sports:
        code = code | get_sport_to_code(sport)
    return code

# declare an integer return type so the OR-ed code stays numeric
sport_to_code_udf = F.udf(sport_to_code, IntegerType())
df.withColumn('sportsInt', sport_to_code_udf('sports'))
But is there any way I could do this using Spark functions rather than a UDF?
From Spark 2.4+ we can use the aggregate higher-order function with the bitwise OR operator for this case.
Example:
from pyspark.sql.types import *
from pyspark.sql.functions import *
sport_to_code_map = {
'Cricket' : 0x0001,
'Hockey' : 0x0002,
'Football' : 0x0004
}
# creating a dataframe from the dictionary
lookup = spark.createDataFrame(list(zip(sport_to_code_map.keys(), sport_to_code_map.values())), ["key", "value"])
#sample dataframe
df.show(10,False)
#+---------+---------------------------+
#|country |sports |
#+---------+---------------------------+
#|India |[Cricket, Hockey, Football]|
#|Sri Lanka|[Cricket, Football] |
#+---------+---------------------------+
df1 = df.selectExpr("explode(sports) as key", "country")
df2 = (df1.join(lookup, ['key'], 'left')
          .groupBy("country")
          .agg(collect_list(col("key")).alias("sports"),
               collect_list(col("value")).alias("sportsInt")))
df2.withColumn("sportsInt", expr('aggregate(sportsInt, 0, (s, x) -> int(s) | int(x))')) \
   .show(10, False)
#+---------+---------------------------+---------+
#|country |sports |sportsInt|
#+---------+---------------------------+---------+
#|Sri Lanka|[Cricket, Football] |5 |
#|India |[Cricket, Hockey, Football]|7 |
#+---------+---------------------------+---------+
If you want to avoid a join for the lookup in the sport_to_code_map dict, then use .replace:
# converting dict values to strings
sport_to_code_map = {k: str(v) for k, v in sport_to_code_map.items()}
df1.replace(sport_to_code_map).show()
#+---+---------+
#|key| country|
#+---+---------+
#| 1| India|
#| 2| India|
#| 4| India|
#| 1|Sri Lanka|
#| 4|Sri Lanka|
#+---+---------+
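On Spark 3.1+ the same idea can also be written without the lookup join, using the Python transform/aggregate helpers over a literal map built from the dict (a sketch, assuming the original df with country and sports columns):
from itertools import chain
import pyspark.sql.functions as F

sport_to_code_map = {'Cricket': 0x0001, 'Hockey': 0x0002, 'Football': 0x0004}

# literal map<string,int> column built from the Python dict
code_map = F.create_map(*[F.lit(x) for x in chain(*sport_to_code_map.items())])

result = (df
          # translate each sport name to its numeric code
          .withColumn("codes", F.transform("sports", lambda s: F.element_at(code_map, s)))
          # fold the codes with a bitwise OR, starting from 0
          .withColumn("sportsInt", F.aggregate("codes", F.lit(0), lambda acc, c: acc.bitwiseOR(c)))
          .drop("codes"))
result.show(truncate=False)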

Get duplicates in collection list of dataframes in the same row in Apache Spark (pyspark 2.4)

In Spark, with PySpark, I have a dataframe with duplicates. I want to deduplicate them with multiple rules, such as email and mobile_phone.
This is my code in Python 3:
from pyspark.sql import Row
from pyspark.sql.functions import collect_list
df = sc.parallelize(
    [
        Row(raw_id='1001', first_name='adam', mobile_phone='0644556677', email='adam#gmail.fr'),
        Row(raw_id='2002', first_name='adam', mobile_phone='0644556688', email='adam#gmail.fr'),
        Row(raw_id='3003', first_name='momo', mobile_phone='0644556699', email='momo#gmail.fr'),
        Row(raw_id='4004', first_name='momo', mobile_phone='0644556600', email='mouma#gmail.fr'),
        Row(raw_id='5005', first_name='adam', mobile_phone='0644556688', email='adama#gmail.fr'),
        Row(raw_id='6006', first_name='rida', mobile_phone='0644556688', email='rida#gmail.fr')
    ]
).toDF()
My original dataframe is :
+--------------+----------+------------+------+
| email|first_name|mobile_phone|raw_id|
+--------------+----------+------------+------+
| adam#gmail.fr| adam| 0644556677| 1001|
| adam#gmail.fr| adam| 0644556688| 2002|
| momo#gmail.fr| momo| 0644556699| 3003|
|mouma#gmail.fr| momo| 0644556600| 4004|
|adama#gmail.fr| adam| 0644556688| 5005|
| rida#gmail.fr| rida| 0644556688| 6006|
+--------------+----------+------------+------+
Then, I apply my deduplication rules:
df_mobile = df \
.groupBy('mobile_phone') \
.agg(collect_list('raw_id').alias('raws'))
df_email = df \
.groupBy('email') \
.agg(collect_list('raw_id').alias('raws'))
This is the result I have:
df_mobile.select('raws').show(10, False)
+------------------+
|raws |
+------------------+
|[2002, 5005, 6006]|
|[1001] |
|[4004] |
|[3003] |
+------------------+
df_email.select('raws').show(10, False)
+------------+
|raws |
+------------+
|[3003] |
|[4004] |
|[1001, 2002]|
|[5005] |
|[6006] |
+------------+
So, the final result I want is to regroup the common elements of the raws column like this:
+------------------------+
|raws |
+------------------------+
|[3003] |
|[4004] |
|[2002, 5005, 6006, 1001]|
+------------------------+
Do you know how I can do this with PySpark?
Thank you very much!
So, as #pault hints at, you could model this as a graph where your original dataframe df is the list of vertices and df_email and df_mobile are lists of connected vertices. Unfortunately GraphX is not available for Python, but GraphFrames is!
GraphFrames has a function called Connected Components that returns the list of connected raw_ids (vertices). To use it we must do two things: raw_id must be called just id, and the edges must be source (src) and destination (dst) pairs, not simply lists of vertices.
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from graphframes import GraphFrame
spark = SparkSession \
    .builder \
    .appName("example") \
    .getOrCreate()
spark.sparkContext.setCheckpointDir("checkpoints")
# graphframes requires a checkpoint dir since v0.3.0
# https://graphframes.github.io/user-guide.html#connected-components
spark.sparkContext.setLogLevel("WARN")  # make it easier to see our output
vertices = spark.createDataFrame([
    ('1001', 'adam', '0644556677', 'adam#gmail.fr'),
    ('2002', 'adam', '0644556688', 'adam#gmail.fr'),
    ('3003', 'momo', '0644556699', 'momo#gmail.fr'),
    ('4004', 'momo', '0644556600', 'mouma#gmail.fr'),
    ('5005', 'adam', '0644556688', 'adama#gmail.fr'),
    ('6006', 'rida', '0644556688', 'rida#gmail.fr')
]).toDF("id", "first_name", "mobile_phone", "email")

mk_edges = udf(
    lambda a: [{'src': src, 'dst': dst} for (src, dst) in zip(a, a[-1:] + a[:-1])],
    returnType=ArrayType(StructType([
        StructField('src', StringType(), nullable=False),
        StructField('dst', StringType(), nullable=False)])))

def edges_by_group_key(df, group_key):
    return df.groupBy(group_key) \
        .agg(collect_list('id').alias('ids')) \
        .select(mk_edges('ids').alias('edges')) \
        .select(explode('edges').alias('edge')) \
        .select("edge.*")
mobileEdges = edges_by_group_key(vertices, 'mobile_phone')
print('mobile edges')
mobileEdges.show(truncate=False)
# mobile edges
# +----+----+
# |src |dst |
# +----+----+
# |2002|6006|
# |5005|2002|
# |6006|5005|
# |1001|1001|
# |4004|4004|
# |3003|3003|
# +----+----+
emailEdges = edges_by_group_key(vertices, 'email')
print('email edges')
emailEdges.show(truncate=False)
# email edges
# +----+----+
# |src |dst |
# +----+----+
# |3003|3003|
# |4004|4004|
# |1001|2002|
# |2002|1001|
# |5005|5005|
# |6006|6006|
# +----+----+
g = GraphFrame(vertices, mobileEdges.union(emailEdges))
result = g.connectedComponents()
print('connectedComponents')
result.select("id", "component") \
.groupBy("component") \
.agg(collect_list('id').alias('ids')) \
.select('ids').show(truncate=False)
# connectedComponents
# +------------------------+
# |ids |
# +------------------------+
# |[1001, 2002, 5005, 6006]|
# |[4004] |
# |[3003] |
# +------------------------+
There might be a cleverer way to do the union between the mobile and email dataframes, maybe deduplicate with distinct, but you get the idea.
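For example, the hinted-at dedup could be as simple as (a sketch):
# drop duplicate src/dst pairs before building the graph
edges = mobileEdges.union(emailEdges).distinct()
g = GraphFrame(vertices, edges)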
