黑马程序员技术交流社区

标题: 【深圳校区】Spark 2.4.0编程指南--Spark SQL UDF和UDAF [打印本页]

作者: 柠檬leung不酸    时间: 2019-1-3 10:17
标题: 【深圳校区】Spark 2.4.0编程指南--Spark SQL UDF和UDAF
Spark 2.4.0编程指南--Spark SQL UDF和UDAF

更多资源视频

<iframe width="800" height="500" src="//player.bilibili.com/player.html?aid=38193405&cid=67137841&page=4" scrolling="no" border="0" frameborder="no" framespacing="0" allowfullscreen="true"> </iframe>

文档前置条件技能标签UDF

用户定义函数(User-defined functions, UDFs)是大多数 SQL 环境的关键特性,用于扩展系统的内置功能。 UDF允许开发人员通过抽象其低级语言实现来在更高级语言(如SQL)中启用新功能。 Apache Spark 也不例外,并且提供了用于将 UDF 与 Spark SQL工作流集成的各种选项。

##示例

BaseSparkSession/**  * 得到SparkSession  * 首先 extends BaseSparkSession  * 本地: val spark = sparkSession(true)  * 集群:  val spark = sparkSession()  */class BaseSparkSession {  var appName = "sparkSession"  var master = "spark://standalone.com:7077" //本地模式:local     standalone:spark://master:7077  def sparkSession(): SparkSession = {    val spark = SparkSession.builder      .master(master)      .appName(appName)      .config("spark.eventLog.enabled","true")      .config("spark.history.fs.logDirectory","hdfs://standalone.com:9000/spark/log/historyEventLog")      .config("spark.eventLog.dir","hdfs://standalone.com:9000/spark/log/historyEventLog")      .getOrCreate()    spark.sparkContext.addJar("/opt/n_001_workspaces/bigdata/spark-scala-maven-2.4.0/target/spark-scala-maven-2.4.0-1.0-SNAPSHOT.jar")    //import spark.implicits._    spark  }  def sparkSession(isLocal:Boolean = false): SparkSession = {    if(isLocal){      master = "local"      val spark = SparkSession.builder        .master(master)        .appName(appName)        .getOrCreate()      //spark.sparkContext.addJar("/opt/n_001_workspaces/bigdata/spark-scala-maven-2.4.0/target/spark-scala-maven-2.4.0-1.0-SNAPSHOT.jar")      //import spark.implicits._      spark    }else{      val spark = SparkSession.builder        .master(master)        .appName(appName)        .config("spark.eventLog.enabled","true")        .config("spark.history.fs.logDirectory","hdfs://standalone.com:9000/spark/log/historyEventLog")        .config("spark.eventLog.dir","hdfs://standalone.com:9000/spark/log/historyEventLog")        .getOrCreate()     // spark.sparkContext.addJar("/opt/n_001_workspaces/bigdata/spark-scala-maven-2.4.0/target/spark-scala-maven-2.4.0-1.0-SNAPSHOT.jar")      //import spark.implicits._      spark    }  }  /**    * 得到当前工程的路径    * @return    */  def getProjectPath:String=System.getProperty("user.dir")}UDF (统计字段长度)/**  * 自定义匿名函数  * 功能: 得到某列数据长度的函数  */object Run extends BaseSparkSession{  def main(args: Array[String]): Unit = {    val spark = sparkSession(true)    val ds = spark.read.json("hdfs://standalone.com:9000/home/liuwen/data/employees.json")    ds.show()//    +-------+------+//    |   name|salary|//    +-------+------+//    |Michael|  3000|//    |   Andy|  4500|//    | Justin|  3500|//    |  Berta|  4000|//    +-------+------+    spark.udf.register("strLength",(str: String) => str.length())    ds.createOrReplaceTempView("employees")    spark.sql("select name,salary,strLength(name) as name_Length from employees").show()//    +-------+------+-----------+//    |   name|salary|name_Length|//    +-------+------+-----------+//    |Michael|  3000|          7|//    |   Andy|  4500|          4|//    | Justin|  3500|          6|//    |  Berta|  4000|          5|//    +-------+------+-----------+    spark.stop()  }}UDF (字段转成大写)import com.opensource.bigdata.spark.standalone.base.BaseSparkSession/**  * 自定义匿名函数  * 功能: 得到某列数据长度的函数  */object Run extends BaseSparkSession{  def main(args: Array[String]): Unit = {    val spark = sparkSession(true)    val ds = spark.read.json("hdfs://standalone.com:9000/home/liuwen/data/employees.json")    ds.show()//    +-------+------+//    |   name|salary|//    +-------+------+//    |Michael|  3000|//    |   Andy|  4500|//    | Justin|  3500|//    |  Berta|  4000|//    +-------+------+    import org.apache.spark.sql.functions._    val strUpper = udf((str: String) => str.toUpperCase())    import spark.implicits._    ds.withColumn("toUpperCase", strUpper($"name")).show//    +-------+------+-----------+//    |   name|salary|toUpperCase|//    +-------+------+-----------+//    |Michael|  3000|    MICHAEL|//    |   Andy|  4500|       ANDY|//    | Justin|  3500|     JUSTIN|//    |  Berta|  4000|      BERTA|//    +-------+------+-----------+    spark.stop()  }}UDAFcountpackage com.opensource.bigdata.spark.sql.n_08_spark_udaf.n_01_spark_udaf_countimport com.opensource.bigdata.spark.standalone.base.BaseSparkSessionimport org.apache.spark.sql.Rowimport org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}import org.apache.spark.sql.types._/**  * ).initialize()方法,初使使,即没数据时的值  * ).update() 方法把每一行的数据进行计算,放到缓冲对象中  * ).merge() 把每个分区,缓冲对象进行合并  * ).evaluate()计算结果表达式,把缓冲对象中的数据进行最终计算  */object Run2 extends BaseSparkSession{  object CustomerCount extends UserDefinedAggregateFunction{    //聚合函数的输入参数数据类型    def inputSchema: StructType = {      StructType(StructField("inputColumn",StringType) :: Nil)    }    //中间缓存的数据类型    def bufferSchema: StructType = {      StructType(StructField("sum",LongType)  :: Nil)    }    //最终输出结果的数据类型    def dataType: DataType = LongType    def deterministic: Boolean = true    //初始值,要是DataSet没有数据,就返回该值    def initialize(buffer: MutableAggregationBuffer): Unit = {      buffer(0) = 0L    }    /**      *      * @param buffer  相当于把当前分区的,每行数据都需要进行计算,计算的结果保存到buffer中      * @param input      */    def update(buffer: MutableAggregationBuffer, input: Row): Unit ={      if(!input.isNullAt(0)){        buffer(0) = buffer.getLong(0) + 1      }    }    /**      * 相当于把每个分区的数据进行汇总      * @param buffer1  分区一的数据      * @param buffer2  分区二的数据      */    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit={      buffer1(0) = buffer1.getLong(0) +buffer2.getLong(0)  //   salary    }    //计算最终的结果    def evaluate(buffer: Row): Long = buffer.getLong(0)  }  def main(args: Array[String]): Unit = {    val spark = sparkSession(true)    spark.udf.register("customerCount",CustomerCount)    val df = spark.read.json("hdfs://standalone.com:9000/home/liuwen/data/employees.json")    df.createOrReplaceTempView("employees")    val sqlDF = spark.sql("select customerCount(name)  as average_salary from employees  ")    df.show()//    +-------+------+//    |   name|salary|//    +-------+------+//    |Michael|  3000|//    |   Andy|  4500|//    | Justin|  3500|//    |  Berta|  4000|//    +-------+------+    sqlDF.show()//    +--------------+//    |average_salary|//    +--------------+//    |           4.0|//    +--------------+    spark.stop()  }}maxpackage com.opensource.bigdata.spark.sql.n_08_spark_udaf.n_03_spark_udaf_sumimport com.opensource.bigdata.spark.standalone.base.BaseSparkSessionimport org.apache.spark.sql.Rowimport org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}import org.apache.spark.sql.types._/**  * ).initialize()方法,初使使,即没数据时的值  * ).update() 方法把每一行的数据进行计算,放到缓冲对象中  * ).merge() 把每个分区,缓冲对象进行合并  * ).evaluate()计算结果表达式,把缓冲对象中的数据进行最终计算  */object Run extends BaseSparkSession{  object CustomerSum extends UserDefinedAggregateFunction{    //聚合函数的输入参数数据类型    def inputSchema: StructType = {      StructType(StructField("inputColumn",LongType) :: Nil)    }    //中间缓存的数据类型    def bufferSchema: StructType = {      StructType(StructField("sum",LongType) :: StructField("count",LongType) :: Nil)    }    //最终输出结果的数据类型    def dataType: DataType = LongType    def deterministic: Boolean = true    //初始值,要是DataSet没有数据,就返回该值    def initialize(buffer: MutableAggregationBuffer): Unit = {      buffer(0) = 0L    }    /**      *      * @param buffer  相当于把当前分区的,每行数据都需要进行计算,计算的结果保存到buffer中      * @param input      */    def update(buffer: MutableAggregationBuffer, input: Row): Unit ={      if(!input.isNullAt(0)){        buffer(0) =   buffer.getLong(0) + input.getLong(0)      }    }    /**      * 相当于把每个分区的数据进行汇总      * @param buffer1  分区一的数据      * @param buffer2  分区二的数据      */    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit={      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)    }    //计算最终的结果    def evaluate(buffer: Row): Long = buffer.getLong(0)  }  def main(args: Array[String]): Unit = {    val spark = sparkSession(true)    spark.udf.register("customerSum",CustomerSum)    val df = spark.read.json("hdfs://standalone.com:9000/home/liuwen/data/employees.json")    df.createOrReplaceTempView("employees")    val sqlDF = spark.sql("select customerSum(salary)  as average_salary from employees  ")    df.show//    +-------+------+//    |   name|salary|//    +-------+------+//    |Michael|  3000|//    |   Andy|  4500|//    | Justin|  3500|//    |  Berta|  4000|//    +-------+------+    sqlDF.show()//    +--------------+//    |average_salary|//    +--------------+//    |        15000|//    +--------------+    spark.stop()  }}averagepackage com.opensource.bigdata.spark.sql.n_08_spark_udaf.n_04_spark_udaf_averageimport com.opensource.bigdata.spark.standalone.base.BaseSparkSessionimport org.apache.spark.sql.Rowimport org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}import org.apache.spark.sql.types._object Run extends BaseSparkSession{  object MyAverage extends UserDefinedAggregateFunction{    //聚合函数的输入参数数据类型    def inputSchema: StructType = {      StructType(StructField("inputColumn",LongType) :: Nil)    }    //中间缓存的数据类型    def bufferSchema: StructType = {      StructType(StructField("sum",LongType) :: StructField("count",LongType) :: Nil)    }    //最终输出结果的数据类型    def dataType: DataType = DoubleType    def deterministic: Boolean = true    //初始值,要是DataSet没有数据,就返回该值    def initialize(buffer: MutableAggregationBuffer): Unit = {      buffer(0) = 0L      buffer(1) = 0L    }    /**      *      * @param buffer  相当于把当前分区的,每行数据都需要进行计算,计算的结果保存到buffer中      * @param input      */    def update(buffer: MutableAggregationBuffer, input: Row): Unit ={      if(!input.isNullAt(0)){        buffer(0) = buffer.getLong(0) + input.getLong(0)   // salary        buffer(1) = buffer.getLong(1) + 1  // count      }    }    /**      * 相当于把每个分区的数据进行汇总      * @param buffer1  分区一的数据      * @param buffer2  分区二的数据      */    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit={      buffer1(0) = buffer1.getLong(0) +buffer2.getLong(0)  //   salary      buffer1(1) = buffer1.getLong(1) +buffer2.getLong(1)  // count    }    //计算最终的结果    def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)  }  def main(args: Array[String]): Unit = {    val spark = sparkSession(true)    spark.udf.register("MyAverage",MyAverage)    val df = spark.read.json("hdfs://standalone.com:9000/home/liuwen/data/employees.json")    df.createOrReplaceTempView("employees")    val sqlDF = spark.sql("select MyAverage(salary)  as average_salary from employees  ")    sqlDF.show()    spark.stop()  }}group by maxpackage com.opensource.bigdata.spark.sql.n_08_spark_udaf.n_05_spark_udaf_groupby_maximport com.opensource.bigdata.spark.standalone.base.BaseSparkSessionimport org.apache.spark.sql.Rowimport org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}import org.apache.spark.sql.types._/**  * ).initialize()方法,初使使,即没数据时的值  * ).update() 方法把每一行的数据进行计算,放到缓冲对象中  * ).merge() 把每个分区,缓冲对象进行合并  * ).evaluate()计算结果表达式,把缓冲对象中的数据进行最终计算  */object Run extends BaseSparkSession{  object CustomerMax extends UserDefinedAggregateFunction{    //聚合函数的输入参数数据类型    def inputSchema: StructType = {      StructType(StructField("inputColumn",LongType) :: Nil)    }    //中间缓存的数据类型    def bufferSchema: StructType = {      StructType(StructField("sum",LongType) :: StructField("count",LongType) :: Nil)    }    //最终输出结果的数据类型    def dataType: DataType = LongType    def deterministic: Boolean = true    //初始值,要是DataSet没有数据,就返回该值    def initialize(buffer: MutableAggregationBuffer): Unit = {      buffer(0) = 0L    }    /**      *      * @param buffer  相当于把当前分区的,每行数据都需要进行计算,计算的结果保存到buffer中      * @param input      */    def update(buffer: MutableAggregationBuffer, input: Row): Unit ={      if(!input.isNullAt(0)){        if(input.getLong(0) > buffer.getLong(0)){          buffer(0) = input.getLong(0)        }      }    }    /**      * 相当于把每个分区的数据进行汇总      * @param buffer1  分区一的数据      * @param buffer2  分区二的数据      */    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit={      if( buffer2.getLong(0) >  buffer1.getLong(0)) buffer1(0) = buffer2.getLong(0)    }    //计算最终的结果    def evaluate(buffer: Row): Long = buffer.getLong(0)  }  def main(args: Array[String]): Unit = {    val spark = sparkSession(true)    spark.udf.register("customerMax",CustomerMax)    val df = spark.read.json("hdfs://standalone.com:9000/home/liuwen/data/employeesCN.json")    df.createOrReplaceTempView("employees")    val sqlDF = spark.sql("select gender,customerMax(salary)  as average_salary from employees group by gender  ")    df.show//    +------+----+------+//    |gender|name|salary|//    +------+----+------+//    |    男|小王| 30000|//    |    女|小丽| 50000|//    |    男|小军| 80000|//    |    女|小李| 90000|//    +------+----+------+    sqlDF.show()//    +------+--------------+//    |gender|average_salary|//    +------+--------------+//    |    男|       80000|//    |    女|       90000|//    +------+--------------+    spark.stop()  }}其它支持

转自 开源中国
地址 https://my.oschina.net/u/723009/blog/2989933






欢迎光临 黑马程序员技术交流社区 (http://bbs.itheima.com/) 黑马程序员IT技术论坛 X3.2