0

spark 1.5에서 collect_set에 해당하는 함수를 말해 줄 수 있습니까?collect_set equivalent spark 1.5 UDAF 메소드 검증

collect_set (col (name))과 비슷한 결과를 얻는 방법이 있습니까? 올바른

class CollectSetFunction[T](val colType: DataType) extends UserDefinedAggregateFunction { 

    def inputSchema: StructType = 
    new StructType().add("inputCol", colType) 

    def bufferSchema: StructType = 
    new StructType().add("outputCol", ArrayType(colType)) 

    def dataType: DataType = ArrayType(colType) 

    def deterministic: Boolean = true 

    def initialize(buffer: MutableAggregationBuffer): Unit = { 
    buffer.update(0, new scala.collection.mutable.ArrayBuffer[T]) 
    } 

    def update(buffer: MutableAggregationBuffer, input: Row): Unit = { 
    val list = buffer.getSeq[T](0) 
    if (!input.isNullAt(0)) { 
     val sales = input.getAs[T](0) 
     buffer.update(0, list:+sales) 
    } 
    } 

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 
    buffer1.update(0, buffer1.getSeq[T](0).toSet ++ buffer2.getSeq[T](0).toSet) 
    } 

    def evaluate(buffer: Row): Any = { 
    buffer.getSeq[T](0) 
    } 
} 

답변

2

그것은 코드 조회 :

이 올바른 방법입니다. 또한, 1.6.2에서 로컬 모드로 테스트했고 같은 결과를 얻었습니다 (아래 참조). DataFrame API를 사용하는 간단한 대안을 모르겠습니다. RDD를 사용하면 매우 간단하며 데이터 프레임이 완전히 구현되지 않아 RDD API를 1.5로 우회하는 경우가 있습니다.

scala> val rdd = sc.parallelize((1 to 10)).map(x => (x%5,x)) 
scala> rdd.groupByKey.mapValues(_.toSet.toList)).toDF("k","set").show 
+---+-------+ 
| k| set| 
+---+-------+ 
| 0|[5, 10]| 
| 1| [1, 6]| 
| 2| [2, 7]| 
| 3| [3, 8]| 
| 4| [4, 9]| 
+---+-------+ 

그리고 당신은 그것을 고려하고 싶은 경우, (imroved 수) 초기 버전이 될 수있는 다른 집계를 만들고 싶어

def collectSet(df: DataFrame, k: Column, v: Column) = df 
    .select(k.as("k"),v.as("v")) 
    .map(r => (r.getInt(0),r.getInt(1))) 
    .groupByKey() 
    .mapValues(_.toSet.toList) 
    .toDF("k","v") 

를 수행하면, 당신은 할 수 없습니다 조인을 피하십시오.


scala> val df = sc.parallelize((1 to 10)).toDF("v").withColumn("k", pmod('v,lit(5))) 
df: org.apache.spark.sql.DataFrame = [v: int, k: int] 

scala> val csudaf = new CollectSetFunction[Int](IntegerType) 

scala> df.groupBy('k).agg(collect_set('v),csudaf('v)).show 
+---+--------------+---------------------+ 
| k|collect_set(v)|CollectSetFunction(v)| 
+---+--------------+---------------------+ 
| 0|  [5, 10]|    [5, 10]| 
| 1|  [1, 6]|    [1, 6]| 
| 2|  [2, 7]|    [2, 7]| 
| 3|  [3, 8]|    [3, 8]| 
| 4|  [4, 9]|    [4, 9]| 
+---+--------------+---------------------+ 

테스트 2 :

scala> val df = sc.parallelize((1 to 100000)).toDF("v").withColumn("k", floor(rand*10)) 
df: org.apache.spark.sql.DataFrame = [v: int, k: bigint] 

scala> df.groupBy('k).agg(collect_set('v).as("a"),csudaf('v).as("b")) 
     .groupBy('a==='b).count.show 
+-------+-----+                 
|(a = b)|count| 
+-------+-----+ 
| true| 10| 
+-------+-----+