I am using PySpark and I have created my own Imputer class as a subclass of Estimator as follows:
from pyspark.ml import Estimator

class CustomImputer(Estimator):
    def __init__(self, inputCol, outputCol, strategy, group_column, impute_character=None):
        super(CustomImputer, self).__init__()
        ...

    def _fit(self, data):
        ...

    def transform(self, new_data):
        ...
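For reference, this is the Estimator/Model pattern I am trying to follow: _fit learns per-group fill values from the training data and returns a Model, and the Model's _transform adds the imputed column. The sketch below is only an illustration (the names and the mean-based strategy are examples, not my exact implementation):

from pyspark.ml import Estimator, Model
from pyspark.sql import functions as F

class SketchGroupImputer(Estimator):
    """Simplified sketch of a grouped imputer (illustrative only)."""

    def __init__(self, inputCol, outputCol, group_column):
        super(SketchGroupImputer, self).__init__()
        self.inputCol = inputCol
        self.outputCol = outputCol
        self.group_column = group_column

    def _fit(self, dataset):
        # Learn one fill value per group from the training data
        fill_values = (dataset
                       .groupBy(self.group_column)
                       .agg(F.mean(self.inputCol).alias("fill_value")))
        return SketchGroupImputerModel(self.inputCol, self.outputCol,
                                       self.group_column, fill_values)

class SketchGroupImputerModel(Model):
    """Model returned by the estimator; fills nulls with the learned group values."""

    def __init__(self, inputCol, outputCol, group_column, fill_values):
        super(SketchGroupImputerModel, self).__init__()
        self.inputCol = inputCol
        self.outputCol = outputCol
        self.group_column = group_column
        self.fill_values = fill_values

    def _transform(self, dataset):
        # Replace nulls in inputCol with the group's fill value; rows whose group
        # is missing from fill_values would still end up with a null here
        return (dataset
                .join(self.fill_values, on=self.group_column, how="left")
                .withColumn(self.outputCol,
                            F.coalesce(F.col(self.inputCol), F.col("fill_value")))
                .drop("fill_value"))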
Then I use this CustomImputer in a Pipeline and train it using cross-validation.
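Roughly, the setup looks like this (the classifier, the extra column names, and the parameter grid are illustrative placeholders, not my exact code):

from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Column names here are examples
imputer = CustomImputer(inputCol="familiarity", outputCol="familiarity_imputed",
                        strategy="mean", group_column="group")
assembler = VectorAssembler(inputCols=["familiarity_imputed"], outputCol="features")
lr = LogisticRegression(featuresCol="features", labelCol="label")

pipeline = Pipeline(stages=[imputer, assembler, lr])
evaluator = BinaryClassificationEvaluator(labelCol="label")
params = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()

cv = CrossValidator(estimator=pipeline, evaluator=evaluator,
                    estimatorParamMaps=params, numFolds=3)
trained_cv = cv.fit(train)
best_lr = trained_cv.bestModel

When I call cv.fit(train), I get this error: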
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
Cell In[38], line 13
10 #tvs = TrainValidationSplit(estimator=pipeline, evaluator=evaluator, estimatorParamMaps=params, trainRatio=0.7)
11 cv = CrossValidator(estimator=pipeline, evaluator=evaluator, estimatorParamMaps=params, numFolds=3)
---> 13 trained_cv = cv.fit(train)
14 best_lr = trained_cv.bestModel
File /usr/local/spark/python/pyspark/ml/base.py:205, in Estimator.fit(self, dataset, params)
203 return self.copy(params)._fit(dataset)
204 else:
--> 205 return self._fit(dataset)
206 else:
207 raise TypeError(
208 "Params must be either a param map or a list/tuple of param maps, "
209 "but got %s." % type(params)
210 )
File /usr/local/spark/python/pyspark/ml/tuning.py:847, in CrossValidator._fit(self, dataset)
841 train = datasets[i][0].cache()
843 tasks = map(
844 inheritable_thread_target,
845 _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
846 )
--> 847 for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
848 metrics_all[i][j] = metric
849 if collectSubModelsParam:
File /opt/conda/lib/python3.11/multiprocessing/pool.py:873, in IMapIterator.next(self, timeout)
871 if success:
872 return value
--> 873 raise value
File /opt/conda/lib/python3.11/multiprocessing/pool.py:125, in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
123 job, i, func, args, kwds = task
124 try:
--> 125 result = (True, func(*args, **kwds))
126 except Exception as e:
127 if wrap_exception and func is not _helper_reraises_exception:
File /usr/local/spark/python/pyspark/ml/tuning.py:847, in CrossValidator._fit.<locals>.<lambda>(f)
841 train = datasets[i][0].cache()
843 tasks = map(
844 inheritable_thread_target,
845 _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
846 )
--> 847 for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
848 metrics_all[i][j] = metric
849 if collectSubModelsParam:
File /usr/local/spark/python/pyspark/util.py:342, in inheritable_thread_target.<locals>.wrapped(*args, **kwargs)
340 assert SparkContext._active_spark_context is not None
341 SparkContext._active_spark_context._jsc.sc().setLocalProperties(properties)
--> 342 return f(*args, **kwargs)
File /usr/local/spark/python/pyspark/ml/tuning.py:118, in _parallelFitTasks.<locals>.singleTask()
113 index, model = next(modelIter)
114 # TODO: duplicate evaluator to take extra params from input
115 # Note: Supporting tuning params in evaluator need update method
116 # `MetaAlgorithmReadWrite.getAllNestedStages`, make it return
117 # all nested stages and evaluators
--> 118 metric = eva.evaluate(model.transform(validation, epm[index]))
119 return index, metric, model if collectSubModel else None
File /usr/local/spark/python/pyspark/ml/evaluation.py:111, in Evaluator.evaluate(self, dataset, params)
109 return self.copy(params)._evaluate(dataset)
110 else:
--> 111 return self._evaluate(dataset)
112 else:
113 raise TypeError("Params must be a param map but got %s." % type(params))
File /usr/local/spark/python/pyspark/ml/evaluation.py:148, in JavaEvaluator._evaluate(self, dataset)
146 self._transfer_params_to_java()
147 assert self._java_obj is not None
--> 148 return self._java_obj.evaluate(dataset._jdf)
File /usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py:1322, in JavaMember.__call__(self, *args)
1316 command = proto.CALL_COMMAND_NAME +\
1317     self.command_header +\
1318     args_command +\
1319     proto.END_COMMAND_PART
1321 answer = self.gateway_client.send_command(command)
-> 1322 return_value = get_return_value(
1323 answer, self.gateway_client, self.target_id, self.name)
1325 for temp_arg in temp_args:
1326 if hasattr(temp_arg, "_detach"):
File /usr/local/spark/python/pyspark/errors/exceptions/captured.py:179, in capture_sql_exception.<locals>.deco(*a, **kw)
177 def deco(*a: Any, **kw: Any) -> Any:
178 try:
--> 179 return f(*a, **kw)
180 except Py4JJavaError as e:
181 converted = convert_exception(e.java_exception)
File /usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/protocol.py:326, in get_return_value(answer, gateway_client, target_id, name)
324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
325 if answer[1] == REFERENCE_TYPE:
--> 326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.n".
328 format(target_id, ".", name), value)
329 else:
330 raise Py4JError(
331 "An error occurred while calling {0}{1}{2}. Trace:n{3}n".
332 format(target_id, ".", name, value))
Py4JJavaError: An error occurred while calling o1856.evaluate.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 486.0 failed 1 times, most recent failure: Lost task 0.0 in stage 486.0 (TID 1945) (813302be1d05 executor driver): org.apache.spark.SparkException: [FAILED_EXECUTE_UDF] Failed to execute user defined function (`VectorAssembler$$Lambda$4596/0x00000040021625c0`: (struct<familiarity_imputed:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>).
at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:198)
at org.apache.spark.sql.errors.QueryExecutionErrors.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:197)
at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)
at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:104)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
at org.apache.spark.scheduler.Task.run(Task.scala:141)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: org.apache.spark.SparkException: Encountered null while assembling a row with handleInvalid = "error". Consider
removing nulls from dataset or using handleInvalid = "keep" or "skip".
at org.apache.spark.ml.feature.VectorAssembler$.$anonfun$assemble$1(VectorAssembler.scala:291)
at org.apache.spark.ml.feature.VectorAssembler$.$anonfun$assemble$1$adapted(VectorAssembler.scala:260)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at org.apache.spark.ml.feature.VectorAssembler$.assemble(VectorAssembler.scala:260)
at org.apache.spark.ml.feature.VectorAssembler.$anonfun$transform$6(VectorAssembler.scala:143)
... 22 more
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2844)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2780)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2779)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2779)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1242)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1242)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1242)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3048)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2982)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2971)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:984)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2463)
at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1046)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:407)
at org.apache.spark.rdd.RDD.collect(RDD.scala:1045)
at org.apache.spark.RangePartitioner$.sketch(Partitioner.scala:320)
at org.apache.spark.RangePartitioner.<init>(Partitioner.scala:187)
at org.apache.spark.RangePartitioner.<init>(Partitioner.scala:167)
at org.apache.spark.rdd.OrderedRDDFunctions.$anonfun$sortByKey$1(OrderedRDDFunctions.scala:64)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:407)
at org.apache.spark.rdd.OrderedRDDFunctions.sortByKey(OrderedRDDFunctions.scala:63)
at org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.x$3$lzycompute(BinaryClassificationMetrics.scala:192)
at org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.x$3(BinaryClassificationMetrics.scala:181)
at org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.confusions$lzycompute(BinaryClassificationMetrics.scala:183)
at org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.confusions(BinaryClassificationMetrics.scala:183)
at org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.createCurve(BinaryClassificationMetrics.scala:275)
at org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.roc(BinaryClassificationMetrics.scala:106)
at org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.areaUnderROC(BinaryClassificationMetrics.scala:126)
at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate(BinaryClassificationEvaluator.scala:101)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:568)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: org.apache.spark.SparkException: [FAILED_EXECUTE_UDF] Failed to execute user defined function (`VectorAssembler$$Lambda$4596/0x00000040021625c0`: (struct<familiarity_imputed:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>).
at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:198)
at org.apache.spark.sql.errors.QueryExecutionErrors.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:197)
at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)
at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:104)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
at org.apache.spark.scheduler.Task.run(Task.scala:141)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
... 1 more
Caused by: org.apache.spark.SparkException: Encountered null while assembling a row with handleInvalid = "error". Consider
removing nulls from dataset or using handleInvalid = "keep" or "skip".
at org.apache.spark.ml.feature.VectorAssembler$.$anonfun$assemble$1(VectorAssembler.scala:291)
at org.apache.spark.ml.feature.VectorAssembler$.$anonfun$assemble$1$adapted(VectorAssembler.scala:260)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at org.apache.spark.ml.feature.VectorAssembler$.assemble(VectorAssembler.scala:260)
at org.apache.spark.ml.feature.VectorAssembler.$anonfun$transform$6(VectorAssembler.scala:143)
... 22 more
I have narrowed it down to this part of the error:
Caused by: org.apache.spark.SparkException: Encountered null while assembling a row with handleInvalid = "error". Consider
removing nulls from dataset or using handleInvalid = "keep" or "skip".
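As the message suggests, I could presumably work around this by building the assembler with a different handleInvalid setting, something like:

# Workaround suggested by the error message: keep the rows instead of failing
assembler = VectorAssembler(inputCols=["familiarity_imputed"],
                            outputCol="features", handleInvalid="keep")

However, I would rather understand why nulls show up at all.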
I have checked the results of the pipeline stages and there are no records with null values.
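The check I ran was roughly the following (again, the column name is illustrative):

from pyspark.sql import functions as F

# Apply the imputer to the training data and count remaining nulls in its output
imputed = imputer.fit(train).transform(train)
print(imputed.filter(F.col("familiarity_imputed").isNull()).count())  # prints 0 on my data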
The strange thing is that when I leave out the features produced by my custom imputer, everything works without any problem.
Does anyone have an idea why this is happening? Also, have I implemented my CustomImputer properly?