Step 1: Define a Scala UDF:
import org.apache.spark.sql.api.java.UDF1
import scala.collection.mutable

// Returns the median ("mid value") of an array of doubles.
class GetMidVal extends UDF1[mutable.WrappedArray[Double], Double] {
  override def call(arr: mutable.WrappedArray[Double]): Double = {
    val n = arr.length
    val arr_sorted = arr.sorted
    val mid: Int = (n - 1) / 2
    if (n % 2 == 1) {
      // Odd number of elements: the middle element is the median.
      arr_sorted(mid)
    } else {
      // Even number of elements: average the two middle elements.
      (arr_sorted(mid) + arr_sorted(mid + 1)) / 2
    }
  }
}
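As a quick sanity check of the intended behavior (a local Python sketch, not Spark code; the values are the four sample amounts from Step 2), the expected result is 5.5:

import statistics

# Median of the sample amounts used in the Java example below.
values = [3.0, 8.0, 10.0, 2.0]
print(statistics.median(values))  # (3.0 + 8.0) / 2 = 5.5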
Step 2: Use the UDF from Java Spark, which works fine:
public void callMidVal(SparkSession spark) {
    // Register the Scala UDF; note that the return type (DoubleType) is passed explicitly here.
    spark.udf().register("getMidVal", new GetMidVal(), DataTypes.DoubleType);

    StructType structType = new StructType();
    structType = structType.add("id", DataTypes.StringType, false);
    structType = structType.add("trx_amt", DataTypes.DoubleType, false);

    List<Row> nums = new ArrayList<Row>();
    nums.add(RowFactory.create("001", 3.0));
    nums.add(RowFactory.create("001", 8.0));
    nums.add(RowFactory.create("001", 10.0));
    nums.add(RowFactory.create("001", 2.0));

    Dataset<Row> df = spark.createDataFrame(nums, structType);
    df.groupBy("id").agg(
            functions.max("trx_amt"),
            functions.min("trx_amt"),
            functions.collect_list("trx_amt").alias("trx_amt_seq")
        )
        .withColumn("mid_val", functions.expr("getMidVal(trx_amt_seq)"))
        .show();
}
public static void main(String[] args) {
    SparkSession spark = SparkSession
        .builder()
        .appName("Java Spark SQL basic example")
        .master("local")
        .getOrCreate();
    // new Main().callConcat(spark);
    new Main().callMidVal(spark);
}
+---+------------+------------+--------------------+-------+
| id|max(trx_amt)|min(trx_amt)| trx_amt_seq|mid_val|
+---+------------+------------+--------------------+-------+
|001| 10.0| 2.0|[3.0, 8.0, 10.0, ...| 5.5|
+---+------------+------------+--------------------+-------+
Step 3: Use the UDF from PySpark, which raises an exception:
from pyspark.sql import functions as F

# Assumes an existing SparkSession named `spark`.
def call_mid_val():
    spark.udf.registerJavaFunction("getMidVal", "org.spark.udf.GetMidVal")
    data = [
        ("001", 3.0), ("001", 2.3), ("001", 1.5),
        ("001", 4.2), ("001", 9.6),
        ("001", 7.3)
    ]
    df = spark.createDataFrame(data, ['id', 'trx_amt']) \
        .groupBy("id").agg(F.collect_list("trx_amt").alias("trx_amt_seq"))
    df.show(truncate=False)
    df.printSchema()
    print(df.dtypes)
    print(df.schema)
    spark.createDataFrame(data, ['id', 'trx_amt']) \
        .groupBy("id").agg(F.collect_list("trx_amt").alias("trx_amt_seq")) \
        .select(F.expr("getMidVal(trx_amt_seq)")) \
        .show()
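For reference, registerJavaFunction also accepts an optional returnType argument, so the type can be declared explicitly on the Python side, mirroring the DataTypes.DoubleType passed to udf().register in the Java code (a sketch, assuming the class is on the driver classpath). My question is why the type is not inferred as Double when returnType is omitted:

from pyspark.sql.types import DoubleType

# Sketch: register the same Scala UDF with an explicit DoubleType return type.
spark.udf.registerJavaFunction("getMidVal", "org.spark.udf.GetMidVal", DoubleType())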
Error Info:
24/06/15 22:17:57 ERROR Executor: Exception in task 0.0 in stage 5.0 (TID 3)
org.apache.spark.SparkException: Failed to execute user defined function (UDFRegistration$$Lambda$3116/0x0000000801e57608: (array) => struct<>)
at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:177)
at org.apache.spark.sql.errors.QueryExecutionErrors.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateResultProjection$5(AggregationIterator.scala:260)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:96)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:32)
at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:365)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:890)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:890)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:136)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
at java.base/java.lang.Thread.run(Thread.java:1623)
Caused by: java.lang.IllegalArgumentException: The value (3.0) of the type (java.lang.Double) cannot be converted to struct<>
at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:267)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:241)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:106)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$.$anonfun$createToCatalystConverter$2(CatalystTypeConverters.scala:477)
… 18 more
The first line of the error shows the UDF being executed as (array) => struct<>, and the Double value 3.0 then fails to convert to struct<>. Why is the return type not Double?
My Spark dependency (pom.xml):
<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.11</artifactId>
        <version>2.3.3</version>
    </dependency>
</dependencies>
That is all the relevant info.