UserDefinedAggregateFunction을 사용할 수 있습니다. 아래 코드는 1.6.2에서 테스트되었습니다.
먼저 UserDefinedAggregateFunction을 확장하는 클래스를 만듭니다.
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class ModeUDAF extends UserDefinedAggregateFunction{
override def dataType: DataType = StringType
override def inputSchema: StructType = new StructType().add("input", StringType)
override def deterministic: Boolean = true
override def bufferSchema: StructType = new StructType().add("mode", MapType(StringType, LongType))
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = Map.empty[Any, Long]
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val buff0 = buffer.getMap[Any, Long](0)
val inp = input.get(0)
buffer(0) = buff0.updated(inp, buff0.getOrElse(inp, 0L) + 1L)
}
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
val mp1 = buffer1.getMap[Any, Long](0)
val mp2 = buffer2.getMap[Any, Long](0)
buffer1(0) = mp1 ++ mp2.map { case (k, v) => k -> (v + mp1.getOrElse(k, 0L)) }
}
override def evaluate(buffer: Row): Any = {
lazy val st = buffer.getMap[Any, Long](0).toStream
val mode = st.foldLeft(st.head){case (e, s) => if (s._2 > e._2) s else e}
mode._1
}
}
이후에는 다음과 같이 데이터 프레임과 함께 사용할 수 있습니다.
val modeColumnList = List("some", "column", "names") // or df.columns.toList
val modeAgg = new ModeUDAF()
val aggCols = modeColumnList.map(c => modeAgg(df(c)))
val aggregatedModeDF = df.agg(aggCols.head, aggCols.tail: _*)
aggregatedModeDF.show()
또한 최종 데이터 프레임에서 .collect를 사용하여 결과를 스칼라 데이터 구조로 수집 할 수 있습니다.
참고 :이 솔루션의 성능은 입력 열의 카디널리티에 따라 다릅니다.
감사합니다. 카디널리티가 낮을 때만 울리는 것을 볼 수 있습니다. 각 카테고리에 1,2,3 값만있는이 생성 된 데이터에서이 방법을 사용합니다.이 방법은 매우 느립니다. –