by 司马顿 | 2021年12月9日 下午7:42
从mysql数据库读取记录,放到spark里进行分布式计算,是生产场景中常见的操作。
首先通过一个案例来展示,如何进行这种操作。我们假设有如下mysql表:
mysql> desc fruit;
+--------+-------------+------+-----+---------+----------------+
| Field | Type | Null | Key | Default | Extra |
+--------+-------------+------+-----+---------+----------------+
| id | int(11) | NO | PRI | NULL | auto_increment |
| name | varchar(32) | YES | | NULL | |
| number | int(11) | YES | | NULL | |
+--------+-------------+------+-----+---------+----------------+
3 rows in set (0.00 sec)
mysql> select * from fruit limit 5;
+----+-----------+--------+
| id | name | number |
+----+-----------+--------+
| 1 | peach | 1 |
| 2 | apricot | 2 |
| 3 | apple | 3 |
| 4 | haw | 1 |
| 5 | persimmon | 9 |
+----+-----------+--------+
5 rows in set (0.00 sec)
这个表很简单,每行一个记录,用来记录水果的名字和数量。我们要对水果进行分组,并按组统计每种水果的数量。如果使用SQL查询就很简单:
mysql> select name,sum(number) from fruit group by name;
+-----------------------+-------------+
| name | sum(number) |
+-----------------------+-------------+
| apple | 86 |
| apricot | 81 |
| areca nut | 139 |
| banana | 134 |
| bitter orange | 117 |
| blackberry | 121 |
| blueberry | 166 |
| cherry | 73 |
| coconut | 86 |
| crab apple | 75 |
| cumquat | 145 |
| flat peach | 125 |
| grape | 92 |
| greengage | 134 |
| guava | 82 |
| haw | 137 |
但现在要求使用spark,这主要是基于spark的分布式计算考虑,它可以将算力分摊到多个节点上,并行运行计算。在数据量非常庞大时,spark的性能优势是mysql无法取代的。spark有个dataframe数据结构,可以加载来自mysql的数据。dataframe有一系列高阶api,方便进行统计。
具体执行操作,假设你已经安装了spark,并且启动了spark的master和worker进程。接着下载mysql的JDBC驱动,这个文件在mysql官网下即可。将这个驱动文件,比如我这里是mysql-connector-java-8.0.27.jar,放到spark安装目录的jars文件夹下。然后以如下方式启动pyspark:
$ pyspark --jars mysql-connector-java-8.0.27.jar
进入到pyspark shell后,使用如下语法加载mysql数据:
mysql_df = (spark
.read
.format("jdbc")
.option("url", "jdbc:mysql://127.0.0.1:3306/spark")
.option("driver", "com.mysql.jdbc.Driver")
.option("dbtable", "fruit")
.option("user", "spark")
.option("password", "***")
.load())
上述语句里,第一句的spark是个实例化的spark session对象,登陆进入pyspark后,这个对象就自动创建了,可以直接使用。spark.read()表示读文件,它带的format()方法指定文件格式为JDBC,后面的option()系列方法表示数据库连接参数。最后的load()方法表示加载文件。
spark.read()可以读取很多种外部数据源,比如Mysql、Postgresql、JSON、CSV、Hive、Snowflake等等。表示方法都差不多。读取的内容位于mysql_df这个dataframe里。如下我们可以查看这个dataframe的相关属性。
>>> mysql_df.printSchema()
root
|-- id: integer (nullable = true)
|-- name: string (nullable = true)
|-- number: integer (nullable = true)
>>> mysql_df.show(3)
+---+-------+------+
| id| name|number|
+---+-------+------+
| 1| peach| 1|
| 2|apricot| 2|
| 3| apple| 3|
+---+-------+------+
only showing top 3 rows
如何基于这个dataframe进行统计呢?很简单,因为dataframe有好多高阶api,可以直接操作数据,比如group、filter、sum、order等。我们运行如下语句进行group和sum的统计。
>>> from pyspark.sql.functions import *
>>> mysql_df.groupBy("name").agg(sum("number")).show()
+-------------+-----------+
| name|sum(number)|
+-------------+-----------+
| strawberry| 116|
| orange| 104|
| haw| 137|
| grape| 92|
| apple| 86|
| mango| 102|
| crab apple| 75|
| flat peach| 125|
| tomato| 102|
| jackfruit| 92|
| greengage| 134|
| apricot| 81|
| honey peach| 82|
| cherry| 73|
| pear| 89|
|bitter orange| 117|
| pomegranate| 84|
| starfruit| 80|
| raspberry| 133|
| banana| 134|
+-------------+-----------+
only showing top 20 rows
这就可以了,得到的结果跟mysql的group、sum是一样的(未考虑排序)。
也可以用spark原生的针对RDD的低阶api来完成这个统计操作,低阶api就是map/reduce那几个方法。
首先要构建一个RDD对象,里面包含了数据内容。然后,运用map/reduce相关方法,操作这个RDD。简单演示如下。
>>> rdd = sc.parallelize([("apple", 2), ("orange", 5), ("lemon", 3),("apple", 3), ("orange", 1)])
>>> rdd.reduceByKey(lambda x,y:x+y).collect()
[('orange', 6), ('apple', 5), ('lemon', 3)]
这个RDD操作也是如此简单,直接用reduceByKey()就分类汇总了。但是,如果涉及到复杂的计算,比如要求每个水果种类出现的平均个数,那么低阶api就写起来很烦。按组求平均的RDD操作如下。
>>> rdd.mapValues(lambda x:(x,1)).reduceByKey(lambda x,y:(x[0]+y[0],x[1]+y[1])).mapValues(lambda x:x[0]/x[1]).collect()
[('orange', 3.0), ('apple', 2.5), ('lemon', 3.0)]
即使我很熟悉函数式编程,上述语法写起来也有些绕。然而,用dataframe的高阶api实现这种分组求平均的目的,就非常简单了。如下首先创建一个简单dataframe对象,接着对这个对象执行高阶api操作。
>>> df = spark.createDataFrame([("apple", 2), ("orange", 5), ("lemon", 3),("apple", 3), ("orange", 1)],["name","number"])
>>> df.groupBy("name").agg(avg("number")).show()
+------+-----------+
| name|avg(number)|
+------+-----------+
|orange| 3.0|
| apple| 2.5|
| lemon| 3.0|
+------+-----------+
如上对比可以看出,高阶api相对低阶api,真要简单太多。当然,除了高阶api操作dataframe,还可以把dataframe转成spark自己的数据表,然后对这个数据表执行SQL查询。也就是说,spark同样支持SQL api,大家可以执行熟悉的SQL查询操作。演示如下,结果跟在数据库里操作是一样的。
>>> mysql_df.write.saveAsTable("fruit_table")
>>> spark.sql("select name,sum(number) from fruit_table group by name").show()
+-------------+-----------+
| name|sum(number)|
+-------------+-----------+
| strawberry| 116|
| orange| 104|
| haw| 137|
| grape| 92|
| apple| 86|
| mango| 102|
| crab apple| 75|
| flat peach| 125|
| tomato| 102|
| jackfruit| 92|
| greengage| 134|
| apricot| 81|
| honey peach| 82|
| cherry| 73|
| pear| 89|
|bitter orange| 117|
| pomegranate| 84|
| starfruit| 80|
| raspberry| 133|
| banana| 134|
+-------------+-----------+
only showing top 20 rows
总结:我们可以使用RDD低阶api、dataframe高阶api、SQL api来操作spark里的数据,看大家熟悉哪个就用哪个。如果使用Java/Scala语言,还有个dataset api可以操作数据。对统计工作而言,spark的这些api真的是很有帮助。
Source URL: https://smart.postno.de/archives/3521
Copyright ©2025 司马顿的博客 unless otherwise noted.