Scala 中累加器的创建与使用格式详解
1. 内置累加器的创建与使用格式
1.1 创建内置累加器
// 通过 SparkContext 创建
val acc = sc.longAccumulator("累加器名称") // Long 类型(默认初始值 0)
val accDouble = sc.doubleAccumulator("累加器名称") // Double 类型(初始值 0.0)
1.2 在任务中更新累加器
// 只能在行动操作(如 foreach、collect)中更新累加器
rdd.foreach { element =>if (满足条件) {acc.add(1) // 累加整数accDouble.add(5.5) // 累加浮点数}
}
1.3 在 Driver 端读取结果
println(s"累加器结果: ${acc.value}")
2. 自定义累加器的创建与使用格式
2.1 定义自定义累加器类
import org.apache.spark.util.AccumulatorV2// 定义输入类型和输出类型
class CustomAccumulator extends AccumulatorV2[输入类型, 输出类型] {private var _value: 输出类型 = 初始值// 判断累加器是否为空override def isZero: Boolean = _value == 初始值// 创建副本override def copy(): AccumulatorV2[输入类型, 输出类型] = {val newAcc = new CustomAccumulatornewAcc._value = this._valuenewAcc}// 重置累加器override def reset(): Unit = {_value = 初始值}// 添加元素(Executor 调用)override def add(v: 输入类型): Unit = {// 自定义累加逻辑(如将 v 合并到 _value)_value += v}// 合并其他累加器的值(Driver 调用)override def merge(other: AccumulatorV2[输入类型, 输出类型]): Unit = {_value += other.value}// 获取最终结果override def value: 输出类型 = _value
}
2.2 注册并使用自定义累加器
// 创建实例并注册
val customAcc = new CustomAccumulator()
sc.register(customAcc, "自定义累加器名称(可选)")// 在行动操作中更新
rdd.foreach { element =>customAcc.add(元素)
}// 读取结果
println(s"自定义累加器结果: ${customAcc.value}")
3. 完整示例:统计单词长度分布
3.1 代码实现
import org.apache.spark.{SparkConf, SparkContext}object WordLengthAccumulatorDemo {def main(args: Array[String]): Unit = {val conf = new SparkConf().setAppName("WordLengthAccumulator").setMaster("local[*]")val sc = new SparkContext(conf)// 创建内置累加器val shortWordAcc = sc.longAccumulator("ShortWords") // 统计短单词(长度 <=3)val longWordAcc = sc.longAccumulator("LongWords") // 统计长单词(长度 >3)// 读取数据并处理val textRDD = sc.textFile("hdfs://path/to/textfile.txt")textRDD.flatMap(_.split(" ")).foreach { word =>if (word.nonEmpty) {if (word.length <= 3) shortWordAcc.add(1)else longWordAcc.add(1)}}// 输出结果println(s"短单词数量: ${shortWordAcc.value}")println(s"长单词数量: ${longWordAcc.value}")sc.stop()}
}
3.2 输出示例
短单词数量: 120
长单词数量: 350
4. 关键注意事项
注意事项 | 正确做法 |
---|---|
只能在行动操作中更新累加器 | 确保在 foreach 、collect 等行动操作中调用 add() ,而非 map 、filter 等转换操作。 |
避免多次计算 RDD | 对 RDD 调用 persist() 或 cache() ,防止重复计算导致累加器重复累加。 |
自定义累加器需注册 | 通过 sc.register() 注册自定义累加器,否则可能引发序列化错误。 |
合并逻辑必须幂等 | 确保 merge() 方法正确处理重复数据(如集合合并用 addAll )。 |
5. 自定义累加器示例:统计唯一单词
5.1 定义累加器
import org.apache.spark.util.AccumulatorV2
import scala.collection.mutable.HashSetclass UniqueWordsAccumulator extends AccumulatorV2[String, HashSet[String]] {private val _words = HashSet[String]()override def isZero: Boolean = _words.isEmptyoverride def copy(): AccumulatorV2[String, HashSet[String]] = {val newAcc = new UniqueWordsAccumulatornewAcc._words ++= this._wordsnewAcc}override def reset(): Unit = _words.clear()override def add(word: String): Unit = _words.add(word)override def merge(other: AccumulatorV2[String, HashSet[String]]): Unit = {_words ++= other.value}override def value: HashSet[String]] = _words
}
5.2 使用累加器
val uniqueWordsAcc = new UniqueWordsAccumulator()
sc.register(uniqueWordsAcc, "UniqueWords")val wordsRDD = sc.parallelize(List("apple", "banana", "apple", "orange"))
wordsRDD.foreach(word => uniqueWordsAcc.add(word))println(s"去重后的单词: ${uniqueWordsAcc.value.mkString(", ")}")
// 输出: apple, banana, orange
总结
-
创建格式:
-
内置累加器:
sc.longAccumulator("name")
-
自定义累加器:继承
AccumulatorV2
并实现方法,然后注册sc.register(acc)
-
-
使用格式:
-
在行动操作中调用
add()
-
通过
value
属性在 Driver 端读取结果
-
-
核心原则:
只在行动操作中更新累加器,避免重复计算和序列化问题。