
Commit 03c880f

Support running nds_power over spark connect
Signed-off-by: Bobby Wang <[email protected]>
1 parent: cd9e29f
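For context, running nds_power over Spark Connect means the benchmark code receives a remote SparkSession with no client-side JVM, so the Py4J-backed APIs it used before this commit (sparkContext, setJobGroup, the task-failure listener) are unavailable. A minimal sketch of how such a session could be created; the endpoint URL is a placeholder and assumes a Spark Connect server plus a pyspark build with Connect support, none of which is part of this commit:

```python
# Sketch only: build a Spark Connect session that nds_power could run over.
# "sc://localhost:15002" is a placeholder endpoint, not taken from this commit.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# In this mode the client has no JVM-backed SparkContext, which is why the
# changes below route everything through the session instead.
print(spark.version)
```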

3 files changed: +62 lines, -22 lines


nds/PysparkBenchReport.py

Lines changed: 33 additions & 10 deletions
@@ -35,9 +35,8 @@
 import time
 import traceback
 from typing import Callable
-from pyspark.sql import SparkSession
 
-import python_listener
+from pyspark.sql import SparkSession
 
 class PysparkBenchReport:
     """Class to generate json summary report for a benchmark
@@ -57,6 +56,36 @@ def __init__(self, spark_session: SparkSession, query_name) -> None:
             'query': query_name,
         }
 
+    def _is_above_spark_4(self):
+        return self.spark_session.version >= "4.0.0"
+
+    def _register_python_listener(self):
+        # Register PythonListener
+        if self._is_above_spark_4():
+            # is_remote_only is added starting from 4.0.0
+            from pyspark import is_remote_only
+            if is_remote_only():
+                # We can't use Py4J in Spark Connect
+                return None
+
+        listener = None
+        try:
+            import python_listener
+            listener = python_listener.PythonListener()
+            listener.register()
+        except TypeError as e:
+            print("Not found com.nvidia.spark.rapids.listener.Manager", str(e))
+        return listener
+
+    def _get_spark_conf(self):
+        if self._is_above_spark_4():
+            from pyspark import is_remote_only
+            if is_remote_only():
+                return self.spark_session.conf.getAll
+
+        return self.spark_session.sparkContext._conf.getAll()
+
+
     def report_on(self, fn: Callable, warmup_iterations = 0, iterations = 1, *args):
         """Record a function for its running environment, running status etc. and exclude sentive
         information like tokens, secret and password Generate summary in dict format for it.
@@ -67,20 +96,14 @@ def report_on(self, fn: Callable, warmup_iterations = 0, iterations = 1, *args):
         Returns:
             dict: summary of the fn
         """
-        spark_conf = dict(self.spark_session.sparkContext._conf.getAll())
+        spark_conf = dict(self._get_spark_conf())
         env_vars = dict(os.environ)
         redacted = ["TOKEN", "SECRET", "PASSWORD"]
         filtered_env_vars = dict((k, env_vars[k]) for k in env_vars.keys() if not (k in redacted))
         self.summary['env']['envVars'] = filtered_env_vars
         self.summary['env']['sparkConf'] = spark_conf
         self.summary['env']['sparkVersion'] = self.spark_session.version
-        listener = None
-        try:
-            listener = python_listener.PythonListener()
-            listener.register()
-        except TypeError as e:
-            print("Not found com.nvidia.spark.rapids.listener.Manager", str(e))
-            listener = None
+        listener = self._register_python_listener()
         if listener is not None:
            print("TaskFailureListener is registered.")
         try:
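The new helpers keep the old behavior on classic sessions (Py4J listener registered, driver-side SparkConf captured) and degrade gracefully over Spark Connect (no listener, session-level conf). A hedged usage sketch, assuming it runs from the nds/ directory with a live session named spark; the query name and the SELECT 1 workload are illustrative only:

```python
# Sketch only: exercising PysparkBenchReport on either a classic or a
# Spark Connect session. Over Connect, _register_python_listener() returns
# None and the summary is still built from session-level conf.
from PysparkBenchReport import PysparkBenchReport

report = PysparkBenchReport(spark, "q_example")   # "q_example" is a made-up query name
summary = report.report_on(lambda: spark.sql("SELECT 1").collect())
print(summary['env']['sparkVersion'])             # populated by report_on, per the diff above
```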

nds/jvm_listener/pom.xml

Lines changed: 20 additions & 8 deletions
@@ -26,19 +26,35 @@
     <properties>
         <maven.compiler.source>8</maven.compiler.source>
         <maven.compiler.target>8</maven.compiler.target>
+        <spark.version>3.1.2</spark.version>
+        <scala.binary.version>2.12</scala.binary.version>
+        <scala.version>2.12.18</scala.version>
     </properties>
+
+    <profiles>
+        <profile>
+            <id>spark4</id>
+            <properties>
+                <spark.version>4.0.0</spark.version>
+                <scala.binary.version>2.13</scala.binary.version>
+                <maven.compiler.source>17</maven.compiler.source>
+                <maven.compiler.target>17</maven.compiler.target>
+                <scala.version>2.13.16</scala.version>
+            </properties>
+        </profile>
+    </profiles>
     <dependencies>
         <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core -->
         <dependency>
             <groupId>org.apache.spark</groupId>
-            <artifactId>spark-core_2.12</artifactId>
-            <version>3.1.2</version>
+            <artifactId>spark-core_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
         </dependency>
         <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql -->
         <dependency>
             <groupId>org.apache.spark</groupId>
-            <artifactId>spark-sql_2.12</artifactId>
-            <version>3.1.2</version>
+            <artifactId>spark-sql_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
             <scope>provided</scope>
         </dependency>
     </dependencies>
@@ -50,10 +66,6 @@
                 <groupId>org.apache.maven.plugins</groupId>
                 <artifactId>maven-compiler-plugin</artifactId>
                 <version>3.8.1</version>
-                <configuration>
-                    <source>1.8</source>
-                    <target>1.8</target>
-                </configuration>
             </plugin>
             <plugin>
                 <groupId>org.scala-tools</groupId>
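With the Spark and Scala versions lifted into properties and the hard-coded compiler configuration removed, the default build still targets Spark 3.1.2 / Scala 2.12 on Java 8, while activating the new profile (standard Maven profile selection, e.g. `mvn package -Pspark4`, not something added by this commit) builds the listener JAR against Spark 4.0.0 / Scala 2.13 with a Java 17 compiler target.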

nds/nds_power.py

Lines changed: 9 additions & 4 deletions
@@ -146,7 +146,7 @@ def setup_tables(spark_session, input_prefix, input_format, use_decimal, execution_time_list):
     Returns:
         execution_time_list: a list recording query execution time.
     """
-    spark_app_id = spark_session.sparkContext.applicationId
+    spark_app_id = spark_session.conf.get("spark.app.id")
     # Create TempView for tables
     for table_name in get_schemas(False).keys():
         start = int(time.time() * 1000)
@@ -331,7 +331,7 @@ def run_query_stream(input_prefix,
     if input_format == 'delta' and delta_unmanaged:
         # Register tables for Delta Lake. This is only needed for unmanaged tables.
         execution_time_list = register_delta_tables(spark_session, input_prefix, execution_time_list)
-    spark_app_id = spark_session.sparkContext.applicationId
+    spark_app_id = spark_session.conf.get("spark.app.id")
     if input_format != 'iceberg' and input_format != 'delta' and not hive_external:
         execution_time_list = setup_tables(spark_session, input_prefix, input_format, use_decimal,
                                            execution_time_list)
@@ -347,7 +347,9 @@
     power_start = int(time.time())
     for query_name, q_content in query_dict.items():
         # show query name in Spark web UI
-        spark_session.sparkContext.setJobGroup(query_name, query_name)
+        spark_session.conf.set("spark.job.description", query_name)
+        spark_session.conf.set("spark.jobGroup.id", query_name)
+        spark_session.conf.set("spark.job.interruptOnCancel", "false")
         print("====== Run {} ======".format(query_name))
         q_report = PysparkBenchReport(spark_session, query_name)
         summary = q_report.report_on(run_one_query,warmup_iterations,
@@ -374,10 +376,13 @@
         else:
             summary_prefix = os.path.join(json_summary_folder, '')
         q_report.write_summary(prefix=summary_prefix)
+        spark_session.conf.unset("spark.job.description")
+        spark_session.conf.unset("spark.jobGroup.id")
+        spark_session.conf.unset("spark.job.interruptOnCancel")
     power_end = int(time.time())
     power_elapse = int((power_end - power_start)*1000)
     if not keep_sc:
-        spark_session.sparkContext.stop()
+        spark_session.stop()
     total_time_end = time.time()
     total_elapse = int((total_time_end - total_time_start)*1000)
     print("====== Power Test Time: {} milliseconds ======".format(power_elapse))
