Skip to content

Commit fedb012

Browse files
committed
Reapply "[SPARK-54218][PYTHON][SQL][TESTS] Add integrated tests for Scalar Pandas Iterator UDF"
This reverts commit 81be5fb.
1 parent 81be5fb commit fedb012

File tree

3 files changed

+100
-5
lines changed

3 files changed

+100
-5
lines changed

sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
249249
binaryPythonDataSource
250250
}
251251

252-
private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) {
252+
private lazy val pandasScalarFunc: Array[Byte] = if (shouldTestPandasUDFs) {
253253
var binaryPandasFunc: Array[Byte] = null
254254
withTempPath { path =>
255255
Process(
@@ -272,6 +272,29 @@ object IntegratedUDFTestUtils extends SQLHelper {
272272
throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
273273
}
274274

275+
private lazy val pandasScalarIterFunc: Array[Byte] = if (shouldTestPandasUDFs) {
276+
var binaryPandasFunc: Array[Byte] = null
277+
withTempPath { path =>
278+
Process(
279+
Seq(
280+
pythonExec,
281+
"-c",
282+
"from pyspark.sql.types import StringType; " +
283+
"from pyspark.serializers import CloudPickleSerializer; " +
284+
s"f = open('$path', 'wb');" +
285+
"f.write(CloudPickleSerializer().dumps((" +
286+
"lambda it: (x.apply(lambda v: None if v is None else str(v)) for x in it), " +
287+
"StringType())))"),
288+
None,
289+
"PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
290+
binaryPandasFunc = Files.readAllBytes(path.toPath)
291+
}
292+
assert(binaryPandasFunc != null)
293+
binaryPandasFunc
294+
} else {
295+
throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
296+
}
297+
275298
private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestPandasUDFs) {
276299
var binaryPandasFunc: Array[Byte] = null
277300
withTempPath { path =>
@@ -1380,7 +1403,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
13801403
private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction(
13811404
name = name,
13821405
func = SimplePythonFunction(
1383-
command = pandasFunc.toImmutableArraySeq,
1406+
command = pandasScalarFunc.toImmutableArraySeq,
13841407
envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
13851408
pythonIncludes = List.empty[String].asJava,
13861409
pythonExec = pythonExec,
@@ -1410,6 +1433,60 @@ object IntegratedUDFTestUtils extends SQLHelper {
14101433
val prettyName: String = "Scalar Pandas UDF"
14111434
}
14121435

1436+
/**
1437+
* A Scalar Iterator Pandas UDF that takes one column, casts into string, executes the
1438+
* Python native function, and casts back to the type of input column.
1439+
*
1440+
* Virtually equivalent to:
1441+
*
1442+
* {{{
1443+
* from pyspark.sql.functions import pandas_udf, PandasUDFType
1444+
*
1445+
* df = spark.range(3).toDF("col")
1446+
* scalar_iter_udf = pandas_udf(
1447+
* lambda it: map(lambda x: x.apply(lambda v: str(v)), it),
1448+
* "string",
1449+
* PandasUDFType.SCALAR_ITER)
1450+
* casted_col = scalar_iter_udf(df.col.cast("string"))
1451+
* casted_col.cast(df.schema["col"].dataType)
1452+
* }}}
1453+
*/
1454+
case class TestScalarIterPandasUDF(
1455+
name: String,
1456+
returnType: Option[DataType] = None) extends TestUDF {
1457+
private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction(
1458+
name = name,
1459+
func = SimplePythonFunction(
1460+
command = pandasScalarIterFunc.toImmutableArraySeq,
1461+
envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
1462+
pythonIncludes = List.empty[String].asJava,
1463+
pythonExec = pythonExec,
1464+
pythonVer = pythonVer,
1465+
broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
1466+
accumulator = null),
1467+
dataType = StringType,
1468+
pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
1469+
udfDeterministic = true) {
1470+
1471+
override def builder(e: Seq[Expression]): Expression = {
1472+
assert(e.length == 1, "Defined UDF only has one column")
1473+
val expr = e.head
1474+
val rt = returnType.getOrElse {
1475+
assert(expr.resolved, "column should be resolved to use the same type " +
1476+
"as input. Try df(name) or df.col(name)")
1477+
expr.dataType
1478+
}
1479+
val pythonUDF = new PythonUDFWithoutId(
1480+
super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF])
1481+
Cast(pythonUDF, rt)
1482+
}
1483+
}
1484+
1485+
def apply(exprs: Column*): Column = udf(exprs: _*)
1486+
1487+
val prettyName: String = "Scalar Pandas Iterator UDF"
1488+
}
1489+
14131490
/**
14141491
* A Grouped Aggregate Pandas UDF that takes one column, executes the
14151492
* Python native function calculating the count of the column using pandas.
@@ -1606,6 +1683,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
16061683
def registerTestUDF(testUDF: TestUDF, session: classic.SparkSession): Unit = testUDF match {
16071684
case udf: TestPythonUDF => session.udf.registerPython(udf.name, udf.udf)
16081685
case udf: TestScalarPandasUDF => session.udf.registerPython(udf.name, udf.udf)
1686+
case udf: TestScalarIterPandasUDF => session.udf.registerPython(udf.name, udf.udf)
16091687
case udf: TestGroupedAggPandasUDF => session.udf.registerPython(udf.name, udf.udf)
16101688
case udf: TestScalaUDF =>
16111689
val registry = session.sessionState.functionRegistry

sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ import org.apache.spark.util.Utils
120120
* - Scalar Pandas UDF test case with a Scalar Pandas UDF registered as the name 'udf'
121121
* iff Python executable, pyspark, pandas and pyarrow are available.
122122
*
123+
* - Scalar Iterator Pandas UDF test case with a Scalar Iterator Pandas UDF registered
124+
* as the name 'udf' iff Python executable, pyspark, pandas and pyarrow are available.
125+
*
123126
* Therefore, UDF test cases should have single input and output files but are executed by four
124127
* different types of UDFs. See 'udf/udf-inner-join.sql' as an example.
125128
*
@@ -193,6 +196,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
193196
s"pandas and/or pyarrow were not available in [$pythonExec].") {
194197
/* Do nothing */
195198
}
199+
case udfTestCase: SQLQueryTestSuite#UDFTest
200+
if udfTestCase.udf.isInstanceOf[TestScalarIterPandasUDF] && !shouldTestPandasUDFs =>
201+
ignore(s"${testCase.name} is skipped because pyspark," +
202+
s"pandas and/or pyarrow were not available in [$pythonExec].") {
203+
/* Do nothing */
204+
}
196205
case udfTestCase: SQLQueryTestSuite#UDFTest
197206
if udfTestCase.udf.isInstanceOf[TestGroupedAggPandasUDF] &&
198207
!shouldTestPandasUDFs =>
@@ -397,6 +406,10 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
397406
if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestPandasUDFs =>
398407
s"${testCase.name}${System.lineSeparator()}" +
399408
s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}"
409+
case udfTestCase: SQLQueryTestSuite#UDFTest
410+
if udfTestCase.udf.isInstanceOf[TestScalarIterPandasUDF] && shouldTestPandasUDFs =>
411+
s"${testCase.name}${System.lineSeparator()}" +
412+
s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}"
400413
case udfTestCase: SQLQueryTestSuite#UDFTest
401414
if udfTestCase.udf.isInstanceOf[TestGroupedAggPandasUDF] &&
402415
shouldTestPandasUDFs =>
@@ -446,12 +459,14 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
446459
// Create test cases of test types that depend on the input filename.
447460
val newTestCases: Seq[TestCase] = if (file.getAbsolutePath.startsWith(
448461
s"$inputFilePath${File.separator}udf${File.separator}postgreSQL")) {
449-
Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).map { udf =>
462+
Seq(TestScalaUDF("udf"), TestPythonUDF("udf"),
463+
TestScalarPandasUDF("udf"), TestScalarIterPandasUDF("udf")).map { udf =>
450464
UDFPgSQLTestCase(
451465
s"$testCaseName - ${udf.prettyName}", absPath, resultFile, udf)
452466
}
453467
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udf")) {
454-
Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).map { udf =>
468+
Seq(TestScalaUDF("udf"), TestPythonUDF("udf"),
469+
TestScalarPandasUDF("udf"), TestScalarIterPandasUDF("udf")).map { udf =>
455470
UDFTestCase(
456471
s"$testCaseName - ${udf.prettyName}", absPath, resultFile, udf)
457472
}

sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,12 @@ class ContinuousSuite extends ContinuousSuiteBase {
278278
s"Result set ${results.toSet} are not a superset of $expected!")
279279
}
280280

281-
Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).foreach { udf =>
281+
Seq(TestScalaUDF("udf"), TestPythonUDF("udf"),
282+
TestScalarPandasUDF("udf"), TestScalarIterPandasUDF("udf")).foreach { udf =>
282283
test(s"continuous mode with various UDFs - ${udf.prettyName}") {
283284
assume(
284285
shouldTestPandasUDFs && udf.isInstanceOf[TestScalarPandasUDF] ||
286+
shouldTestPandasUDFs && udf.isInstanceOf[TestScalarIterPandasUDF] ||
285287
shouldTestPythonUDFs && udf.isInstanceOf[TestPythonUDF] ||
286288
udf.isInstanceOf[TestScalaUDF])
287289

0 commit comments

Comments
 (0)