diff --git a/nds-h/nds_h_power.py b/nds-h/nds_h_power.py
index 47a3ded..fd3f075 100644
--- a/nds-h/nds_h_power.py
+++ b/nds-h/nds_h_power.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -186,6 +186,8 @@ def run_query_stream(input_prefix,
                      query_dict,
                      time_log_output_path,
                      sub_queries,
+                     warmup_iterations,
+                     iterations,
                      input_format,
                      output_path=None,
                      keep_sc=False,
@@ -237,7 +239,10 @@ def run_query_stream(input_prefix,
         spark_session.sparkContext.setJobGroup(query_name, query_name)
         print("====== Run {} ======".format(query_name))
         q_report = PysparkBenchReport(spark_session, query_name)
-        summary = q_report.report_on(run_one_query, spark_session,
+        summary = q_report.report_on(run_one_query,
+                                     warmup_iterations,
+                                     iterations,
+                                     spark_session,
                                      q_content,
                                      query_name,
                                      output_path,
@@ -346,6 +351,14 @@ def load_properties(filename):
                         default='parquet')
     parser.add_argument('--property_file',
                         help='property file for Spark configuration.')
+    parser.add_argument('--warmup_iterations',
+                        type=int,
+                        help='Number of warmup iterations for each query.',
+                        default=0)
+    parser.add_argument('--iterations',
+                        type=int,
+                        help='Number of iterations for each query.',
+                        default=1)
     args = parser.parse_args()
     query_dict = gen_sql_from_stream(args.query_stream_file)
     run_query_stream(args.input_prefix,
@@ -353,6 +366,8 @@ def load_properties(filename):
                      query_dict,
                      args.time_log,
                      args.sub_queries,
+                     args.warmup_iterations,
+                     args.iterations,
                      args.input_format,
                      args.output_prefix,
                      args.keep_sc,
diff --git a/nds/PysparkBenchReport.py b/nds/PysparkBenchReport.py
index 154e3e5..068e295 100644
--- a/nds/PysparkBenchReport.py
+++ b/nds/PysparkBenchReport.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 #
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -57,7 +57,7 @@ def __init__(self, spark_session: SparkSession, query_name) -> None:
             'query': query_name,
         }
 
-    def report_on(self, fn: Callable, *args):
+    def report_on(self, fn: Callable, warmup_iterations = 0, iterations = 1, *args):
         """Record a function for its running environment, running status etc. and exclude sentive
         information like tokens, secret and password Generate summary in dict format for it.
 
@@ -84,28 +84,41 @@ def report_on(self, fn: Callable, *args):
         if listener is not None:
             print("TaskFailureListener is registered.")
         try:
-            start_time = int(time.time() * 1000)
-            fn(*args)
-            end_time = int(time.time() * 1000)
-            if listener and len(listener.failures) != 0:
-                self.summary['queryStatus'].append("CompletedWithTaskFailures")
-            else:
-                self.summary['queryStatus'].append("Completed")
+            # warmup
+            for i in range(0, warmup_iterations):
+                fn(*args)
         except Exception as e:
-            # print the exception to ease debugging
-            print('ERROR BEGIN')
+            print('ERROR WHILE WARMUP BEGIN')
             print(e)
             traceback.print_tb(e.__traceback__)
-            print('ERROR END')
-            end_time = int(time.time() * 1000)
-            self.summary['queryStatus'].append("Failed")
-            self.summary['exceptions'].append(str(e))
-        finally:
-            self.summary['startTime'] = start_time
-            self.summary['queryTimes'].append(end_time - start_time)
-            if listener is not None:
-                listener.unregister()
-            return self.summary
+            print('ERROR WHILE WARMUP END')
+
+        start_time = int(time.time() * 1000)
+        self.summary['startTime'] = start_time
+        # run the query
+        for i in range(0, iterations):
+            try:
+                start_time = int(time.time() * 1000)
+                fn(*args)
+                end_time = int(time.time() * 1000)
+                if listener and len(listener.failures) != 0:
+                    self.summary['queryStatus'].append("CompletedWithTaskFailures")
+                else:
+                    self.summary['queryStatus'].append("Completed")
+            except Exception as e:
+                # print the exception to ease debugging
+                print('ERROR BEGIN')
+                print(e)
+                traceback.print_tb(e.__traceback__)
+                print('ERROR END')
+                end_time = int(time.time() * 1000)
+                self.summary['queryStatus'].append("Failed")
+                self.summary['exceptions'].append(str(e))
+            finally:
+                self.summary['queryTimes'].append(end_time - start_time)
+        if listener is not None:
+            listener.unregister()
+        return self.summary
 
     def write_summary(self, prefix=""):
         """_summary_
diff --git a/nds/nds_power.py b/nds/nds_power.py
index f7f0f99..eef10b0 100644
--- a/nds/nds_power.py
+++ b/nds/nds_power.py
@@ -231,6 +231,8 @@ def run_query_stream(input_prefix,
                      time_log_output_path,
                      extra_time_log_output_path,
                      sub_queries,
+                     warmup_iterations,
+                     iterations,
                      input_format="parquet",
                      use_decimal=True,
                      output_path=None,
@@ -306,7 +308,9 @@ def run_query_stream(input_prefix,
         spark_session.sparkContext.setJobGroup(query_name, query_name)
         print("====== Run {} ======".format(query_name))
         q_report = PysparkBenchReport(spark_session, query_name)
-        summary = q_report.report_on(run_one_query,spark_session,
+        summary = q_report.report_on(run_one_query,warmup_iterations,
+                                     iterations,
+                                     spark_session,
                                      profiler,
                                      q_content,
                                      query_name,
@@ -314,7 +318,8 @@ def run_query_stream(input_prefix,
                                      output_format)
         print(f"Time taken: {summary['queryTimes']} millis for {query_name}")
         query_times = summary['queryTimes']
-        execution_time_list.append((spark_app_id, query_name, query_times[0]))
+        for query_time in query_times:
+            execution_time_list.append((spark_app_id, query_name, query_time))
         queries_reports.append(q_report)
         if json_summary_folder:
             # property_file e.g.: "property/aqe-on.properties" or just "aqe-off.properties"
@@ -445,6 +450,14 @@ def load_properties(filename):
                         help='Executable that is called just before/after a query executes.'
                         + 'The executable is called like this '
                         + './hook {start|stop} output_root query_name.')
+    parser.add_argument('--warmup_iterations',
+                        type=int,
+                        help='Number of warmup iterations for each query.',
+                        default=0)
+    parser.add_argument('--iterations',
+                        type=int,
+                        help='Number of iterations for each query.',
+                        default=1)
     args = parser.parse_args()
     query_dict = gen_sql_from_stream(args.query_stream_file)
     run_query_stream(args.input_prefix,
@@ -453,6 +466,8 @@ def load_properties(filename):
                      query_dict,
                      args.time_log,
                      args.extra_time_log,
                      args.sub_queries,
+                     args.warmup_iterations,
+                     args.iterations,
                      args.input_format,
                      not args.floats,
                      args.output_prefix,
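
What the PysparkBenchReport.py hunk does: the single timed fn(*args) call becomes a warmup phase whose runs (and failures) are kept out of the recorded timings, followed by a measured loop that appends one status and one duration per iteration. Below is a minimal standalone sketch of that pattern; the helper name timed_runs and the toy workload are illustrative only, and the real report_on additionally registers a task-failure listener and records the Spark environment.

    import time
    import traceback
    from typing import Any, Callable, Dict

    def timed_runs(fn: Callable, warmup_iterations: int = 0,
                   iterations: int = 1, *args: Any) -> Dict[str, Any]:
        # Hypothetical standalone mirror of the patched report_on loop.
        summary: Dict[str, Any] = {'queryTimes': [], 'queryStatus': [], 'exceptions': []}
        # Warmup phase: results are discarded; a failure here is only
        # logged, so the measured runs still execute.
        try:
            for _ in range(warmup_iterations):
                fn(*args)
        except Exception as e:
            traceback.print_tb(e.__traceback__)
        summary['startTime'] = int(time.time() * 1000)
        # Measured phase: one queryStatus/queryTimes entry per iteration.
        for _ in range(iterations):
            start_time = int(time.time() * 1000)
            try:
                fn(*args)
                summary['queryStatus'].append('Completed')
            except Exception as e:
                summary['queryStatus'].append('Failed')
                summary['exceptions'].append(str(e))
            finally:
                summary['queryTimes'].append(int(time.time() * 1000) - start_time)
        return summary

    if __name__ == '__main__':
        # Toy stand-in for a Spark query; expect one warmup run plus three timings.
        print(timed_runs(lambda: sum(range(1_000_000)), 1, 3))

Because warmup_iterations and iterations sit before *args in the new signature, call sites must pass both counts positionally ahead of fn's own arguments; that is exactly how the nds_power.py and nds_h_power.py call sites are updated.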
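
Since summary['queryTimes'] can now hold several durations, nds_power.py stops reading query_times[0] and instead emits one (spark_app_id, query_name, query_time) row per measured run. A toy illustration of the changed aggregation, with made-up values:

    # Illustrative values only; in the real script these come from the
    # Spark application and the per-query summary.
    spark_app_id, query_name = 'app-20250101000000-0001', 'query1'
    query_times = [5120, 4980, 5003]  # millis, one entry per iteration
    execution_time_list = [(spark_app_id, query_name, t) for t in query_times]
    # -> three rows for query1 in the time log instead of one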
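
With the two flags wired through argparse in both drivers, a power run can request untimed warmups before the timed iterations. A hedged invocation sketch, assuming the scripts keep their existing positional arguments (angle-bracket placeholders; spark-submit options elided):

    spark-submit ... nds_power.py \
        <input_prefix> <query_stream_file> <time_log> \
        --warmup_iterations 1 \
        --iterations 3

Omitting both flags preserves the previous single-run behavior, since --warmup_iterations defaults to 0 and --iterations to 1.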