- 01 - Launching PySpark from Jupyter Notebook
- 02 - Suppressing PySpark INFO logs
- 03 - Reading relational database tables from PySpark
- 04 - Filtering in PySpark
- 05 - Looping over column lists in PySpark
- 06 - Creating a DataFrame from an RDD
- 08 - Data preprocessing in PySpark
- 09 - Creating and transforming RDDs
- 10 - Nested JSON
- 11 - Converting str to TimestampType()
- 12 - Distinct counts: countDistinct and approxCountDistinct
- 13 - Adding columns in PySpark
- 14 - Filling and filtering data in PySpark
- 15 - Renaming Columns for PySpark DataFrame Aggregates
- 16 - Saving a PySpark DataFrame to HBase
- 17 - Using when/otherwise when adding columns
- 18 - Unzipping lists of tuples in a PySpark DataFrame
- 19 - Spark shuffle write
These notes collect problems I ran into during PySpark development on real projects, together with answers adapted from https://stackoverflow.com/. Corrections and criticism are welcome! I will keep updating as new issues come up.
01 - Launching PySpark from Jupyter Notebook
Installing Anaconda first is recommended.
If you are used to testing code in Jupyter Notebook, you can keep using it with PySpark; it only takes a few steps.
Edit the ~/.bashrc file:
nano ~/.bashrc
Add the following lines:
export PATH="/opt/anaconda3/bin:$PATH"
export PYSPARK_PYTHON="python3"
export PYSPARK_DRIVER_PYTHON="ipython3"
export PYSPARK_DRIVER_PYTHON_OPTS="notebook --NotebookApp.open_browser=False --allow-root --NotebookApp.ip='192.168.111.106' --NotebookApp.port=8889"
Reload the file:
source ~/.bashrc
Jupyter Notebook now acts as the PySpark driver. Launch pyspark:
[root@server106 pyspark_app]# pyspark
[TerminalIPythonApp] WARNING | Subcommand `ipython notebook` is deprecated and will be removed in future versions.
[TerminalIPythonApp] WARNING | You likely want to use `jupyter notebook`... continue in 5 sec. Press Ctrl-C to quit now.
[I 16:27:26.624 NotebookApp] Serving notebooks from local directory: /home/ydzhao/pyspark_app
[I 16:27:26.624 NotebookApp] 0 active kernels
[I 16:27:26.624 NotebookApp] The Jupyter Notebook is running at: http://192.168.111.106:8889/?token=b7a226911d027509bf277fe40641b3572608ae47b26983f7
[I 16:27:26.624 NotebookApp] Use Control-C to stop this server and shut down all kernels (twice to skip confirmation).
[C 16:27:26.625 NotebookApp]
Copy/paste this URL into your browser when you connect for the first time,
to login with a token:
http://192.168.111.106:8889/?token=b7a226911d027509bf277fe40641b3572608ae47b26983f7
02 - Suppressing PySpark INFO logs
Spark 1.6.2
log4j = sc._jvm.org.apache.log4j
log4j.LogManager.getRootLogger().setLevel(log4j.Level.ERROR)
# Keep the Spark console from printing too much information
def SetLogger(sc):
    logger = sc._jvm.org.apache.log4j
    logger.LogManager.getLogger("org").setLevel(logger.Level.ERROR)
    logger.LogManager.getLogger("akka").setLevel(logger.Level.ERROR)
    # note: getLogger("INFO")/getLogger("WARN") would just create loggers
    # literally named "INFO"/"WARN" and have no effect, so they are omitted
    logger.LogManager.getRootLogger().setLevel(logger.Level.ERROR)
Spark 2.0+
spark.sparkContext.setLogLevel('WARN')
The old way:
Copy conf/log4j.properties.template to conf/log4j.properties.
Edit log4j.properties and change
log4j.rootCategory=INFO, console to log4j.rootCategory=WARN, console
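For reference, the relevant line of conf/log4j.properties ends up looking like this (the rest of the template can stay unchanged):

```
# Set everything to be logged to the console at WARN instead of INFO
log4j.rootCategory=WARN, console
```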
The available log levels:
OFF (most specific, no logging)
FATAL (most specific, little data)
ERROR - Log only in case of Errors
WARN - Log only in case of Warnings or Errors
INFO (Default)
DEBUG - Log detailed steps (and all levels above)
TRACE (least specific, a lot of data)
ALL (least specific, all data)
pyspark.SparkContext.setLogLevel
from pyspark.sql import SparkSession
spark = SparkSession.builder.\
master('local').\
appName('foo').\
getOrCreate()
spark.sparkContext.setLogLevel('WARN')
03 - Reading relational database tables from PySpark
To access a database table from the pyspark shell, pass the JDBC driver jar when starting it.
Download ojdbc6-11.2.0.3.jar beforehand, then run:
pyspark --jars "/data/spark/ojdbc6-11.2.0.3.jar"
When submitting a job, add the jar the same way:
spark-submit --master yarn \
--deploy-mode cluster \
--num-executors 25 \
--executor-cores 2 \
--driver-memory 4g \
--executor-memory 4g \
--conf spark.broadcast.compress=true \
--jars "/data/spark/ojdbc6-11.2.0.3.jar" /home/project/test.py
04 - Filtering in PySpark
Use boolean column expressions: '&' for 'and', '|' for 'or', '~' for 'not'.
'~' for 'not':
rdd = sc.parallelize([(0,1), (0,1), (0,2), (1,2), (1,10), (1,20), (3,18), (3,18), (3,18)])
df = sqlContext.createDataFrame(rdd, ["id", "score"])
l = [10,18,20]
df.filter(~df.score.isin(l))
df.where(df.score.isin(l))
Output:
df.filter(~df.score.isin(l)) keeps: (0,1), (0,1), (0,2), (1,2)
df.where(df.score.isin(l)) keeps: (1,10), (1,20), (3,18), (3,18), (3,18)
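The same split can be checked in plain Python, which makes it easy to see which rows each filter keeps (a small sketch using the data above):

```python
data = [(0, 1), (0, 1), (0, 2), (1, 2), (1, 10), (1, 20), (3, 18), (3, 18), (3, 18)]
l = [10, 18, 20]

# ~df.score.isin(l): keep rows whose score is NOT in l
not_in = [t for t in data if t[1] not in l]
# df.score.isin(l): keep rows whose score IS in l
is_in = [t for t in data if t[1] in l]

print(not_in)  # [(0, 1), (0, 1), (0, 2), (1, 2)]
print(is_in)   # [(1, 10), (1, 20), (3, 18), (3, 18), (3, 18)]
```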
'&' for 'and' and '|' for 'or':
df1.filter((df1.LON == 0)|(df1.LAT == 0)).count()
24105
df1.filter((df1.LON == 0)&(df1.LAT == 0)).count()
14252
df1.filter(~df1.LON == 0).count()
AnalysisException: "cannot resolve 'NOT LON' due to data type mismatch:
argument 1 requires boolean type, however, 'LON' is of double type.;"
# ~ binds tighter than ==, so parenthesize the comparison: df1.filter(~(df1.LON == 0)).count()
df1.filter(~df1.LON.isNotNull()).count()
0
05 - Looping over column lists in PySpark
df.select([c for c in df.columns if c != 'id'])
# (number of missing values, row id)
df_miss.rdd.map(lambda row: (sum(c is None for c in row), row['id'])).sortByKey(ascending=False).collect()
df_miss.agg(*[(1.00 - (fn.count(c) / fn.count('*'))).alias(c + '_missing') for c in df_miss.columns]).show()
df_miss_no_income = df_miss.select([c for c in df_miss.columns if c != 'income'])
df_miss_no_income.agg(*[fn.mean(c).alias(c) for c in df_miss_no_income.columns if c != 'gender']).\
toPandas().to_dict('records')[0]
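The per-row lambda above is plain Python; here is what it computes on a few sample rows (hypothetical data, not the dataset used above):

```python
rows = [
    {'id': 1, 'height': 170, 'weight': None, 'age': None},
    {'id': 2, 'height': None, 'weight': 60, 'age': 25},
    {'id': 3, 'height': 180, 'weight': 80, 'age': 30},
]

# (number of missing values, row id), as in the rdd.map lambda above
miss = [(sum(v is None for v in row.values()), row['id']) for row in rows]
miss.sort(reverse=True)  # most-missing rows first, like sortByKey(ascending=False)
print(miss)  # [(2, 1), (1, 2), (0, 3)]
```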
06 - Creating a DataFrame from an RDD
# RDD to DataFrame
from pyspark.sql import types
fraud = sc.textFile('file:///root/ydzhao/PySpark/Chapter04/ccFraud.csv.gz')
header = fraud.first()
header
'"custID","gender","state","cardholder","balance","numTrans","numIntlTrans","creditLine","fraudRisk"'
fraud.take(3)
['"custID","gender","state","cardholder","balance","numTrans","numIntlTrans","creditLine","fraudRisk"',
'1,1,35,1,3000,4,14,2,0',
'2,2,2,1,0,9,0,18,0']
fraud.count()
10000001
(1) Drop the header row and cast every field to int; the result is still an RDD
fraud = fraud.filter(lambda row: row != header).map(lambda row: [int(x) for x in row.split(',')])
fraud.take(3)
[[1, 1, 35, 1, 3000, 4, 14, 2, 0],
[2, 2, 2, 1, 0, 9, 0, 18, 0],
[3, 2, 2, 1, 0, 27, 9, 16, 0]]
(2) Build the DataFrame schema
h[1:-1] strips the first and last character of each header token, i.e. the surrounding double quotes
schema = [
    types.StructField(h[1:-1], types.IntegerType(), True)
    for h in header.split(',')
]
schema = types.StructType(schema)
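A quick check of what h[1:-1] produces on the header line shown earlier:

```python
header = '"custID","gender","state","cardholder","balance","numTrans","numIntlTrans","creditLine","fraudRisk"'

# each token looks like '"custID"'; the slice drops the surrounding quotes
names = [h[1:-1] for h in header.split(',')]
print(names[:3])  # ['custID', 'gender', 'state']
```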
(3) Create the DataFrame
# spark2.0+
# fraud_df = spark.createDataFrame(fraud, schema)
# spark1.6.2
fraud_df = sqlContext.createDataFrame(fraud, schema)
fraud_df.printSchema()
root
|-- custID: integer (nullable = true)
|-- gender: integer (nullable = true)
|-- state: integer (nullable = true)
|-- cardholder: integer (nullable = true)
|-- balance: integer (nullable = true)
|-- numTrans: integer (nullable = true)
|-- numIntlTrans: integer (nullable = true)
|-- creditLine: integer (nullable = true)
|-- fraudRisk: integer (nullable = true)
08 - Data preprocessing in PySpark
Read data into an RDD, run RDD actions, create a DataFrame from the RDD, read an Oracle table, add a column, cast types, filter, join, aggregate, drop nulls, and sort (based on Spark 1.6.2).
rdd = sc.textFile('/sh/signaling/2016/10/22/TRAFF_20161022232800.txt.lzo')
rdd1 = rdd.map(lambda x:x.split(",")).map(lambda line:(line[2]+line[3],1)).reduceByKey(lambda x,y:x+y)
df1 = sqlContext.createDataFrame(rdd1,['LACCELL','NUM'])
df = sqlContext.read.format("jdbc").options(url="jdbc:oracle:thin:@192.168.111.107:1521:orcl"
, driver = "oracle.jdbc.driver.OracleDriver"
, dbtable = "TBL_LAC_CELL_TAZ"
, user="shanghai_base"
, password="shanghai_base").load()
from pyspark.sql import types
df2 = df.\
withColumn('LACCELL',df.LAC_CELL.cast(types.LongType())).\
select('LACCELL','TAZID').filter(df.TAZID>100000)
df3 = df1.\
join(df2,df1.LACCELL == df2.LACCELL,'right').\
select(df2.TAZID.alias('GRID_1KM'),df1.LACCELL,df1.NUM).\
dropna(how='any').\
orderBy(["GRID_1KM", "LACCELL"], ascending=[1, 0])
df3.show()
09 - Creating and transforming RDDs
# 1 Read a txt file into an RDD, transform it, and save the result
rdd = sc.textFile('/nj/signaling/2016/07/11/TRAFF_20160711235500.txt')
rdd_save = sc.textFile('/nj/signaling/2016/07/11').\
filter(lambda line: int(line.split(",")[0], 16) % 2 ==0).\
saveAsTextFile("/nj/signaling2/2016/07/11")
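The filter keeps lines whose first field, read as a hexadecimal number, is even. In plain Python (the sample lines are made up):

```python
lines = ["1A,foo,bar", "1B,foo,bar", "2C,foo,bar"]

# int(x, 16) parses the field as hex: 0x1A = 26 (even), 0x1B = 27 (odd), 0x2C = 44 (even)
kept = [ln for ln in lines if int(ln.split(",")[0], 16) % 2 == 0]
print(kept)  # ['1A,foo,bar', '2C,foo,bar']
```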
# 2 Parse records with a custom function
file = sc.textFile('/GPS/FeiTian/2017/07/*')
def line2record(line):
    segs = line.split(" ")
    return (segs[0], segs[1], segs[2], segs[5])
file.map(line2record).take(10)
[('2017-07-27', '23:59:57', '21582', '0'),
('2017-07-27', '23:59:58', '133386', '0'),
('2017-07-27', '22:59:00', '130387', '0'),
('2017-07-27', '23:59:57', '125899', '0'),
('2017-07-27', '22:59:58', '142358', '0'),
('2017-07-27', '23:59:58', '110065', '0'),
('2017-07-27', '22:59:58', '136810', '0'),
('2017-07-27', '23:59:58', '139889', '1'),
('2017-07-27', '23:59:57', '19877', '2'),
('2017-07-27', '23:56:54', '32764', '0')]
# 3 Restructure fields so they can be combined into new ones
rdd = sc.textFile('/GPS/FeiTian/2017/07/201707302338.dt')
rdd_split = rdd.map(lambda x: x.split(" ")).\
    map(lambda segs: (segs[2],
                      (segs[0] + " " + segs[1],
                       segs[3],
                       segs[4],
                       segs[5],
                       segs[6],
                       segs[7],
                       segs[8])))
rdd_split.take(3)
[('33712',
('2017-07-30 23:37:55', '121.74655', '31.051933', '0', '0', '0.0', '05')),
('135249',
('2017-07-30 23:37:56', '121.4169', '31.2813', '0', '6', '11.0', '02')),
('11292',
('2017-07-30 23:37:20', '121.5148', '31.1553', '0', '2', '0.0', '02'))
]
rdd_groupbykey = rdd_split.groupByKey().map(lambda x : (x[0], list(x[1])))
rdd_groupbykey.take(5)
[('21664',
[('2017-08-30 23:38:29', '121.77598', '30.958338', '0', '4', '0.0', '04')]),
('1114',
[('2017-07-30 23:38:33', '121.58', '31.3436', '1', '7', '0.0', '02')]),
('172072',
[('2017-07-30 23:38:39', '121.4904', '31.351997', '1', '2', '0.0', '05')]),
('160125',
[('2017-07-30 23:38:00', '121.344795', '31.19619', '0', '6', '0.0', '05'),
('2017-07-30 23:38:40', '121.344795', '31.19619', '0', '6', '0.0', '05')]),
('135655',
[('2017-07-30 23:38:28', '121.4507', '31.3406', '1', '0', '0.0', '02')])
]
rdd_groupbykey.map(lambda x:x[1]).take(5)
[[('2017-08-30 23:38:29', '121.77598', '30.958338', '0', '4', '0.0', '04')],
[('2017-07-30 23:38:33', '121.58', '31.3436', '1', '7', '0.0', '02')],
[('2017-07-30 23:38:39', '121.4904', '31.351997', '1', '2', '0.0', '05')],
[('2017-07-30 23:38:00', '121.344795', '31.19619', '0', '6', '0.0', '05'),
('2017-07-30 23:38:40', '121.344795', '31.19619', '0', '6', '0.0', '05')],
[('2017-07-30 23:38:28', '121.4507', '31.3406', '1', '0', '0.0', '02')]
]
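groupByKey followed by list() collects all values per key, the same thing a dict of lists does in plain Python (the records here are made up):

```python
from collections import defaultdict

pairs = [('160125', 'rec1'), ('1114', 'rec2'), ('160125', 'rec3')]

# same shape as groupByKey().map(lambda x: (x[0], list(x[1])))
grouped = defaultdict(list)
for k, v in pairs:
    grouped[k].append(v)

print(dict(grouped))  # {'160125': ['rec1', 'rec3'], '1114': ['rec2']}
```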
# 4 Parse '@@'-delimited records, casting fields as they are split
rdd = sc.textFile('file:///root/ydzhao/Two_passengers_and_one_danger/20171219_20171219135232.txt',4)
from datetime import datetime
rdd_split = rdd.map(lambda x: x.split("@@")).\
    map(lambda s: (datetime.strptime(s[0], "%Y-%m-%d %H:%M:%S"),
                   s[1],
                   s[2],
                   s[3],
                   int(s[4]) * 0.000001,
                   int(s[5]) * 0.000001,
                   int(s[6]),
                   int(s[7]),
                   int(s[8]),
                   s[9],
                   int(s[10]),
                   s[11],
                   s[12]))
from pyspark.sql.types import *
schema = StructType([
StructField("TimeStamp",TimestampType(),True),
StructField("VehicleID",StringType(),True),
StructField("VehiclePlateColor",StringType(),True),
StructField("MessageSeq",StringType(),True),
StructField("Lng",DoubleType(),True),
StructField("Lat",DoubleType(),True),
StructField("TerminalSpeed",IntegerType(),True),
StructField("DrivingSpeed",IntegerType(),True),
StructField("TotalMile",IntegerType(),True),
StructField("Direction",StringType(),True),
StructField("Altitude",IntegerType(),True),
StructField("StatusBit",StringType(),True),
StructField("AlarmStatus",StringType(),True)
])
df2 = sqlContext.createDataFrame(rdd_split,schema)
df2.registerTempTable("df2")
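The same parse on a single '@@'-delimited line in plain Python (the sample line is made up, but follows the field layout above):

```python
from datetime import datetime

line = "2017-12-19 13:52:32@@V001@@yellow@@0001@@121474500@@31230400@@60@@58@@1000@@090@@4@@00@@00"
s = line.split("@@")

record = (datetime.strptime(s[0], "%Y-%m-%d %H:%M:%S"),
          s[1], s[2], s[3],
          int(s[4]) * 0.000001,  # scale raw integer coordinates back to degrees
          int(s[5]) * 0.000001,
          int(s[6]), int(s[7]), int(s[8]),
          s[9], int(s[10]), s[11], s[12])

print(record[0].year, round(record[4], 4))  # 2017 121.4745
```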
10 - Nested JSON
Flattening a nested JSON structure:
df = sqlContext.read.json('file:///home/UnicomGSM/data/20170614103000.json')
df.printSchema()
root
|-- RoadSegState: struct (nullable = true)
| |-- DateTimeDelay: long (nullable = true)
| |-- Datetime: string (nullable = true)
| |-- Description: string (nullable = true)
| |-- IntersectionDelay: long (nullable = true)
| |-- IsRoadIntersection: string (nullable = true)
| |-- MobileNumber: long (nullable = true)
| |-- Number: long (nullable = true)
| |-- RoadSegID: string (nullable = true)
| |-- SigNumber: long (nullable = true)
| |-- Speed: double (nullable = true)
| |-- SpeedDiff: double (nullable = true)
| |-- State: string (nullable = true)
| |-- Time: double (nullable = true)
df1 = df.select([df.RoadSegState.DateTimeDelay.alias("DateTimeDelay"),
df.RoadSegState.DateTime.alias("DateTime"),
df.RoadSegState.Description.alias("Description"),
df.RoadSegState.IntersectionDelay.alias("IntersectionDelay"),
df.RoadSegState.IsRoadIntersection.alias("IsRoadIntersection"),
df.RoadSegState.MobileNumber.alias("MobileNumber"),
df.RoadSegState.Number.alias("Number"),
df.RoadSegState.RoadSegID.alias("RoadSegID"),
df.RoadSegState.SigNumber.alias("SigNumber"),
df.RoadSegState.Speed.alias("Speed"),
df.RoadSegState.SpeedDiff.alias("SpeedDiff"),
df.RoadSegState.State.alias("State"),
df.RoadSegState.Time.alias("Time")])
df1.printSchema()
root
|-- DateTimeDelay: long (nullable = true)
|-- DateTime: string (nullable = true)
|-- Description: string (nullable = true)
|-- IntersectionDelay: long (nullable = true)
|-- IsRoadIntersection: string (nullable = true)
|-- MobileNumber: long (nullable = true)
|-- Number: long (nullable = true)
|-- RoadSegID: string (nullable = true)
|-- SigNumber: long (nullable = true)
|-- Speed: double (nullable = true)
|-- SpeedDiff: double (nullable = true)
|-- State: string (nullable = true)
|-- Time: double (nullable = true)
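What the select above does is pull one level of nesting up to the top level. The same idea on a plain Python dict parsed with json (only a subset of the fields, for illustration); in Spark 2.x the one-liner df.select("RoadSegState.*") achieves the same flattening:

```python
import json

raw = '{"RoadSegState": {"Datetime": "20170614103000", "RoadSegID": "R42", "Speed": 36.5}}'
rec = json.loads(raw)

# promote the nested struct's fields to the top level
flat = {k: v for k, v in rec["RoadSegState"].items()}
print(flat["RoadSegID"], flat["Speed"])  # R42 36.5
```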
11 - Converting str to TimestampType()
import pyspark.sql.functions as func
from pyspark.sql.types import TimestampType
from datetime import datetime
df_y = sqlContext.read.json("/user/test.json")
udf_dt = func.udf(lambda x: datetime.strptime(x, '%Y%m%d%H%M%S'), TimestampType())
df = df_y.withColumn('datetime', udf_dt(df_y.date))
df_g = df_y.groupby(func.hour(df_y.date))
df_y.groupby(df_y.name).agg(func.countDistinct('address')).show()
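The UDF body is ordinary datetime.strptime; the '%Y%m%d%H%M%S' format expects a 14-digit compact timestamp:

```python
from datetime import datetime

# the same conversion the udf applies to each value of the date column
ts = datetime.strptime('20170614103000', '%Y%m%d%H%M%S')
print(ts)  # 2017-06-14 10:30:00
```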
12 - Distinct counts: countDistinct and approxCountDistinct
Exact counting with countDistinct
from pyspark.sql.functions import col, countDistinct
df.agg(*(countDistinct(col(c)).alias(c) for c in df.columns))
from pyspark.sql.functions import col, countDistinct
df.agg(countDistinct(col("colName")).alias("count")).show()
from pyspark.sql.functions import countDistinct
df1.agg(countDistinct(df1.msid).alias('count')).collect()
Approximate counting with approxCountDistinct
If you can trade a little accuracy for speed, use approxCountDistinct():
from pyspark.sql.functions import approxCountDistinct
df1.agg(approxCountDistinct(df1.msid).alias('count')).collect()
As expected, countDistinct performs noticeably worse than approxCountDistinct.
13 - Adding columns in PySpark
from pyspark.sql.functions import col, expr, when
Option 1: a SQL expression
new_column_1 = expr(
"""IF(fruit1 IS NULL OR fruit2 IS NULL, 3, IF(fruit1 = fruit2, 1, 0))"""
)
Option 2: when/otherwise; the clearest logic, recommended
new_column_2 = when(
col("fruit1").isNull() | col("fruit2").isNull(), 3
).when(col("fruit1") == col("fruit2"), 1).otherwise(0)
Option 3: coalesce; rarely used
from pyspark.sql.functions import coalesce, lit
new_column_3 = coalesce((col("fruit1") == col("fruit2")).cast("int"), lit(3))
df = sc.parallelize([
("orange", "apple"), ("kiwi", None), (None, "banana"),
("mango", "mango"), (None, None)
]).toDF(["fruit1", "fruit2"])
df.withColumn("new_column_1", new_column_1).\
withColumn("new_column_2", new_column_2).\
withColumn("new_column_3", new_column_3)
+------+------+------------+------------+------------+
|fruit1|fruit2|new_column_1|new_column_2|new_column_3|
+------+------+------------+------------+------------+
|orange| apple| 0| 0| 0|
| kiwi| null| 3| 3| 3|
| null|banana| 3| 3| 3|
| mango| mango| 1| 1| 1|
| null| null| 3| 3| 3|
+------+------+------------+------------+------------+
Option 4: a user-defined function
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import udf
def func(fruit1, fruit2):
    if fruit1 is None or fruit2 is None:
        return 3
    if fruit1 == fruit2:
        return 1
    return 0
func_udf = udf(func, IntegerType())
df = df.withColumn('new_column',func_udf(df['fruit1'], df['fruit2']))
14 - Filling and filtering data in PySpark
Input:
id  value
1   103
2   1504
3   1
from pyspark.sql.functions import lpad
>>> df.select('id',lpad(df['value'],4,'0').alias('value')).show()
+---+-----+
| id|value|
+---+-----+
| 1| 0103|
| 2| 1504|
| 3| 0001|
+---+-----+
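For string values, lpad(col, 4, '0') behaves like Python's str.rjust(4, '0') (or zfill for purely numeric strings):

```python
values = ['103', '1504', '1']

# left-pad each value to width 4 with '0', like lpad(df['value'], 4, '0')
padded = [v.rjust(4, '0') for v in values]
print(padded)  # ['0103', '1504', '0001']
```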
# filter/where
from pyspark.sql.functions import col
df.where(col("dt_mvmt").isNull())
df.where(col("dt_mvmt").isNotNull())
df.na.drop(subset=["dt_mvmt"])
df.filter(df.dt_mvmt.isNotNull()).count()
col_list = df.schema.names
# a bare generator won't work inside where(); AND the per-column conditions together
from functools import reduce
df_filtered = df.where(reduce(lambda a, b: a & b, [col(c) >= 10 for c in col_list]))
15 - Renaming Columns for PySpark DataFrame Aggregates
# Option 1: you must know the auto-generated column name SUM(money) in advance; brittle
df.groupBy("group")\
    .agg({"money": "sum"})\
    .withColumnRenamed("SUM(money)", "money")\
    .show(100)
# Option 2: recommended
import pyspark.sql.functions as func
df.groupBy("group")\
.agg(func.sum('money').alias('money'))\
.show(100)
# Select only the struct-typed columns
cols = [i.name for i in df.schema.fields if "StructType" in str(i.dataType)]
df.select(cols)
# Compute each row's share of the total:
df = sc.\
parallelize([(1,'female',233), (None,'female',314),(0,'female',81),(1, None, 342), (1,'male',109)]).\
toDF().\
withColumnRenamed("_1","survived").\
withColumnRenamed("_2","sex").\
withColumnRenamed("_3","count")
from pyspark.sql.functions import sum as sql_sum  # the built-in sum cannot aggregate a column
total = df.select("count").agg(sql_sum('count').alias('sum_count')).collect().pop()['sum_count']
result = df.withColumn('percent', (df['count']/total) * 100)
result.show()
+--------+------+-----+------------------+
|survived| sex|count| percent|
+--------+------+-----+------------------+
| 1|female| 233| 21.59406858202039|
| null|female| 314|29.101019462465246|
| 0|female| 81| 7.506950880444857|
| 1| null| 342| 31.69601482854495|
| 1| male| 109|10.101946246524559|
+--------+------+-----+------------------+
df = sc.parallelize([
("XYZ12", "B1, B44, B66", "A, C", 59),
("ZYY3 ", "B8, B3, B7", "J, Z", 66)
]).toDF(["dbn", "bus", "subway", "score"])
from pyspark.sql.functions import col, explode, split, trim
with_bus_exploded = df.withColumn("bus", explode(split("bus", ",")))
with_bus_trimmed = with_bus_exploded.withColumn("bus", trim(col("bus")))
+-----+---+------+-----+
| dbn|bus|subway|score|
+-----+---+------+-----+
|XYZ12| B1| A, C| 59|
|XYZ12|B44| A, C| 59|
|XYZ12|B66| A, C| 59|
|ZYY3 | B8| J, Z| 66|
|ZYY3 | B3| J, Z| 66|
|ZYY3 | B7| J, Z| 66|
+-----+---+------+-----+
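The split-then-trim step is easy to check in plain Python before reaching for explode:

```python
bus = "B1, B44, B66"

# split(",") leaves leading spaces, which is why trim() is needed afterwards
stops = [s.strip() for s in bus.split(",")]
print(stops)  # ['B1', 'B44', 'B66']
```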
16 - Saving a PySpark DataFrame to HBase
df = sc.parallelize([('a', '1.0'), ('b', '2.0')]).toDF(schema=['col0', 'col1'])
catalog = ''.join("""{
"table":{"namespace":"default", "name":"testtable"},
"rowkey":"key",
"columns":{
"col0":{"cf":"rowkey", "col":"key", "type":"string"},
"col1":{"cf":"cf", "col":"col1", "type":"string"}
}
}""".split())
# write to hbase
df.write \
.options(catalog=catalog) \
.format('org.apache.spark.sql.execution.datasources.hbase') \
.mode("overwrite") \
.option("zkUrl","host1,host2,host3:2181") \
.save()
# reading
df_read = spark.read.options(catalog=catalog).format('org.apache.spark.sql.execution.datasources.hbase').load()
df_read.show()
Pass the required jars through PYSPARK_SUBMIT_ARGS, e.g. in a Jupyter kernel spec:
"PYSPARK_SUBMIT_ARGS": "--master yarn \
--jars hbase_spark_jar/hbase-0.90.2.jar,\
/hbase_spark_jar/hbase-client-1.3.1.jar,\
hbase_spark_jar/spark-avro_2.11-3.0.1.jar,\
/hbase_spark_jar/hbase-spark-1.2.0-cdh5.7.3.jar,\
/hbase_spark_jar/shc-1.0.0-2.0-s_2.11.jar \
--files /etc/hbase/2.5.0.0-1245/0/hbase-site.xml \
--executor-memory 8g \
--executor-cores 4 \
--num-executors 4 \
pyspark-shell"
17 - Using when/otherwise when adding columns
when/otherwise plays the role of SQL's CASE WHEN ... END.
df:
 id  col1 col2 col3 col4
101     1    0    1    1
102     0    1    1    0
103     1    1    0    1
104     0    0    1    1
from pyspark.sql.functions import when, lit
# assumes df already has a col_check column (initialised to -1, as the output below shows)
def update_col_check(df, col_name):
    return df.withColumn('col_check', when(df[col_name] == 1, lit(col_name)).otherwise(df['col_check']))
update_col_check(df, 'col1').show()
+---+----+----+----+----+---------+
| id|col1|col2|col3|col4|col_check|
+---+----+----+----+----+---------+
|101| 1| 0| 1| 1| col1|
|102| 0| 1| 1| 0| -1|
|103| 1| 1| 0| 1| col1|
|104| 0| 0| 1| 1| -1|
+---+----+----+----+----+---------+
update_col_check(df, 'col2').show()
+---+----+----+----+----+---------+
| id|col1|col2|col3|col4|col_check|
+---+----+----+----+----+---------+
|101| 1| 0| 1| 1| -1|
|102| 0| 1| 1| 0| col2|
|103| 1| 1| 0| 1| col2|
|104| 0| 0| 1| 1| -1|
+---+----+----+----+----+---------+
from pyspark.sql.functions import when
df.select(df.name, when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
+-----+------------------------------------------------------------+
| name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
+-----+------------------------------------------------------------+
|Alice| -1|
| Bob| 1|
+-----+------------------------------------------------------------+
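The when/when/otherwise chain reads like a chained conditional expression in plain Python:

```python
def bucket(age):
    # equivalent of when(age > 4, 1).when(age < 3, -1).otherwise(0)
    return 1 if age > 4 else (-1 if age < 3 else 0)

print([bucket(a) for a in (2, 5, 3)])  # [-1, 1, 0]
```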
18 - Unzipping lists of tuples in a PySpark DataFrame
from pyspark.sql.functions import udf, col
# create the dataframe
df = sqlContext.createDataFrame(
[
(1, [('blue', 0.5),('red', 0.1),('green', 0.7)]),
(2, [('red', 0.9),('cyan', 0.5),('white', 0.4)])
],
('Topic', 'Tokens')
)
df.show()
+-----+-------------------------------------------+
|Topic| Tokens |
+-----+-------------------------------------------+
| 1| ('blue', 0.5),('red', 0.1),('green', 0.7)|
| 2| ('red', 0.9),('cyan', 0.5),('white', 0.4)|
+-----+-------------------------------------------+
from pyspark.sql.types import ArrayType, StringType, FloatType
def get_colors(l):
    return [x[0] for x in l]
def get_weights(l):
    return [x[1] for x in l]
# make udfs from the above functions - note the return types
get_colors_udf = udf(get_colors, ArrayType(StringType()))
get_weights_udf = udf(get_weights, ArrayType(FloatType()))
# use withColumn and apply the udfs
df.withColumn('Weights', get_weights_udf(col('Tokens')))\
.withColumn('Tokens', get_colors_udf(col('Tokens')))\
.select(['Topic', 'Tokens', 'Weights'])\
.show()
+-----+------------------+---------------+
|Topic| Tokens| Weights|
+-----+------------------+---------------+
| 1|[blue, red, green]|[0.5, 0.1, 0.7]|
| 2|[red, cyan, white]|[0.9, 0.5, 0.4]|
+-----+------------------+---------------+
Alternatively, because Tokens is an array of structs, the fields can be pulled out without a UDF:
root
|-- Topic: long (nullable = true)
|-- Tokens: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- _1: string (nullable = true)
| | |-- _2: double (nullable = true)
from pyspark.sql.functions import col
df.select(
col("Topic"),
col("Tokens._1").alias("Tokens"), col("Tokens._2").alias("weights")
).show()
+-----+------------------+---------------+
|Topic| Tokens| weights|
+-----+------------------+---------------+
| 1|[blue, red, green]|[0.5, 0.1, 0.7]|
| 2|[red, cyan, white]|[0.9, 0.5, 0.4]|
+-----+------------------+---------------+
cols = [
col("Tokens.{}".format(n)) for n in
df.schema["Tokens"].dataType.elementType.names]
df.select("Topic", *cols)
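Outside Spark, this "unzip" of a list of pairs is exactly what zip(*...) does:

```python
tokens = [('blue', 0.5), ('red', 0.1), ('green', 0.7)]

# transpose the list of pairs into a tuple of colors and a tuple of weights
colors, weights = zip(*tokens)
print(list(colors))   # ['blue', 'red', 'green']
print(list(weights))  # [0.5, 0.1, 0.7]
```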
19 - Spark shuffle write
When is a shuffle write needed?
Abstract the RDDs and their dependencies out of a job and split them into stages, and you get the full execution plan; the boundary between stages is the shuffle. The ShuffleMapTask of the earlier stage performs the shuffle write: it stores its output through the blockManager and reports the location metadata of that data to the MapOutputTracker component on the driver. The next stage then performs the shuffle read, using that location metadata to fetch the previous stage's output.
UnsafeShuffleWriter
UnsafeShuffleWriter maintains a ShuffleExternalSorter for external sorting: data is partially sorted, written out to disk, and merge-sorted globally at the end. Since SortShuffleWriter also sorts externally, what is the difference? Here records are sorted in memory, in ShuffleInMemorySorter, by their partition id only. The sorted records are serialized, compressed, and appended as one segment of a temporary file, and the seek offset of each partition's segment is recorded so that any single partition can later be read on its own; reading the stream back through decompression and deserialization recovers the records. The whole process keeps inserting data into ShuffleInMemorySorter, requesting more memory as needed; when no more memory can be obtained it spills to a file, and the spills are finally merged into one large file that is globally ordered by partition id.
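The core decision in a shuffle write is which partition each record's key maps to. A minimal sketch of hash partitioning (an illustration of the idea only, not Spark's actual HashPartitioner, which works on JVM hashCodes and handles negative hashes explicitly):

```python
def partition_for(key, num_partitions):
    # records with the same key always land in the same shuffle partition,
    # which is what allows reduceByKey to combine them in the next stage
    return hash(key) % num_partitions

keys = ["33712", "135249", "33712"]
parts = [partition_for(k, 4) for k in keys]
print(parts[0] == parts[2])  # True: identical keys map to the same partition
```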