#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----
#
# Certain portions of the contents of this file are derived from TPC-H version 3.0.1
# (retrieved from www.tpc.org/tpc_documents_current_versions/current_specifications5.asp).
# Such portions are subject to copyrights held by Transaction Processing Performance Council (“TPC”)
# and licensed under the TPC EULA (a copy of which accompanies this file as “TPC EULA” and is also
# available at http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp) (the “TPC EULA”).
#
# You may not use this file except in compliance with the TPC EULA.
# DISCLAIMER: Portions of this file is derived from the TPC-H Benchmark and as such any results
# obtained using this file are not comparable to published TPC-H Benchmark results, as the results
# obtained from using this file do not comply with the TPC-H Benchmark.
#

import argparse
import glob
import json
import math
import os
import time
from decimal import Decimal

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import DoubleType, FloatType
from pyspark.sql.functions import col

from nds_h_power import gen_sql_from_stream, get_query_subset

SKIP_QUERIES = [
    'query15_part1',  # create view query
    'query15_part3',  # drop view query
]
SKIP_COLUMNS = {
    'query18': ['o_orderkey'],  # non-deterministic output: https://github.com/NVIDIA/spark-rapids-benchmarks/pull/198#issuecomment-2403837688
}


def compare_results(spark_session: SparkSession,
                    input1: str,
                    input2: str,
                    input1_format: str,
                    input2_format: str,
                    ignore_ordering: bool,
                    query_name: str,
                    use_iterator=False,
                    max_errors=10,
                    epsilon=0.00001) -> bool:
    """Given two paths of query output data, compare them row by row and value by value to check
    whether the results match.

    Args:
        spark_session (SparkSession): Spark session used for the comparison
        input1 (str): path of the first input data
        input2 (str): path of the second input data
        input1_format (str): data source format for input1, e.g. parquet, orc
        input2_format (str): data source format for input2, e.g. parquet, orc
        ignore_ordering (bool): whether to ignore the ordering of the input data.
            If True, rows are sorted before the comparison.
        query_name (str): query name.
        use_iterator (bool, optional): When set to True, use `toLocalIterator` to load one partition
            at a time into driver memory, reducing memory usage at the cost of performance because
            processing will be single-threaded. Defaults to False.
        max_errors (int, optional): maximum number of differences to report. Defaults to 10.
        epsilon (float, optional): allowed relative difference when comparing floating point
            values. Defaults to 0.00001.

    Returns:
        bool: True if the results match, otherwise False
    """
    if query_name in SKIP_QUERIES:
        return True

    df1 = spark_session.read.format(input1_format).load(input1)
    df2 = spark_session.read.format(input2_format).load(input2)
    count1 = df1.count()
    count2 = df2.count()

    if count1 == count2:
        # TODO: need partitioned collect for NDS? there's no partitioned output currently
        result1 = collect_results(df1, query_name, ignore_ordering, use_iterator)
        result2 = collect_results(df2, query_name, ignore_ordering, use_iterator)

        errors = 0
        i = 0
        while i < count1 and errors < max_errors:
            lhs = next(result1)
            rhs = next(result2)
            if not rowEqual(list(lhs), list(rhs), epsilon):
                print(f"Row {i}: \n{list(lhs)}\n{list(rhs)}\n")
                errors += 1
            i += 1
        print(f"Processed {i} rows")

        if errors == max_errors:
            print(f"Aborting comparison after reaching maximum of {max_errors} errors")
            return False
        elif errors == 0:
            print("Results match")
            return True
        else:
            print(f"There were {errors} errors")
            return False
    else:
        print(f"DataFrame row counts do not match: {count1} != {count2}")
        return False

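# Illustrative sketch (not called anywhere in this script): how compare_results
# can be invoked directly against two query output folders from an existing
# SparkSession. The paths below are placeholders, not outputs produced by this
# file.
def _example_compare_single_query(spark_session: SparkSession) -> bool:
    return compare_results(spark_session,
                           '/tmp/power_run_cpu/query1',  # placeholder path
                           '/tmp/power_run_gpu/query1',  # placeholder path
                           'parquet',
                           'parquet',
                           ignore_ordering=True,
                           query_name='query1')
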
def collect_results(df: DataFrame,
                    query_name: str,
                    ignore_ordering: bool,
                    use_iterator: bool):
    # drop columns that are known to be non-deterministic for specific queries
    if query_name in SKIP_COLUMNS:
        df = df.drop(*SKIP_COLUMNS[query_name])

    # apply sorting if specified; sort on non-float columns first so that rows whose
    # float values differ only within the tolerance still line up for comparison
    non_float_cols = [col(field.name) for field in df.schema.fields
                      if field.dataType.typeName() not in (FloatType.typeName(), DoubleType.typeName())]
    float_cols = [col(field.name) for field in df.schema.fields
                  if field.dataType.typeName() in (FloatType.typeName(), DoubleType.typeName())]
    if ignore_ordering:
        df = df.sort(non_float_cols + float_cols)

    # TODO: do we still need this for NDS? Query outputs are usually 1 - 100 rows,
    # there shouldn't be memory pressure.
    if use_iterator:
        it = df.toLocalIterator()
    else:
        print("Collecting rows from DataFrame")
        t1 = time.time()
        rows = df.collect()
        t2 = time.time()
        print(f"Collected {len(rows)} rows in {t2 - t1} seconds")
        it = iter(rows)
    return it


def rowEqual(row1, row2, epsilon):
    # only simple types in a row for NDS results
    return all(compare(lhs, rhs, epsilon) for lhs, rhs in zip(row1, row2))


def compare(expected, actual, epsilon=0.00001):
    # TODO 1: this could be simplified with structural pattern matching (match-case) on Python 3.10+
    # TODO 2: complex data types such as nested types could be supported if needed in the future;
    #         for now NDS only contains simple data types.
    if isinstance(expected, float) and isinstance(actual, float):
        # Double is converted to float in pyspark...
        if math.isnan(expected) and math.isnan(actual):
            return True
        return math.isclose(expected, actual, rel_tol=epsilon)

    if isinstance(expected, Decimal) and isinstance(actual, Decimal):
        return math.isclose(expected, actual, rel_tol=epsilon)

    return expected == actual

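# Illustrative examples of the comparison semantics above (note that epsilon is a
# relative tolerance passed to math.isclose, not an absolute difference). This
# helper is not part of the validation flow.
def _compare_examples():
    assert compare(100.00001, 100.00002)        # relative diff ~1e-7 is within tolerance
    assert compare(float('nan'), float('nan'))  # NaN is treated as equal to NaN
    assert not compare(1e-06, 0.0)              # rel_tol has no effect against an exact 0.0
    assert compare('abc', 'abc')                # non-float values use plain equality
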
def iterate_queries(spark_session: SparkSession,
                    input1: str,
                    input2: str,
                    input1_format: str,
                    input2_format: str,
                    ignore_ordering: bool,
                    query_dict: dict,
                    use_iterator=False,
                    max_errors=10,
                    epsilon=0.00001):
    # Iterate over each query folder of a Power Run output.
    # A query dictionary is passed in instead of hard-coding all queries so that
    # arbitrary subsets of queries can be compared.
    unmatch_queries = []
    for query_name in query_dict.keys():
        sub_input1 = input1 + '/' + query_name
        sub_input2 = input2 + '/' + query_name
        print(f"=== Comparing Query: {query_name} ===")
        result_equal = compare_results(spark_session,
                                       sub_input1,
                                       sub_input2,
                                       input1_format,
                                       input2_format,
                                       ignore_ordering,
                                       query_name,
                                       use_iterator=use_iterator,
                                       max_errors=max_errors,
                                       epsilon=epsilon)
        if not result_equal:
            unmatch_queries.append(query_name)
    if unmatch_queries:
        print(f"=== Unmatched Queries: {unmatch_queries} ===")
    return unmatch_queries

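# Expected layout of each Power Run output directory passed to this script (one
# sub-directory per query, written in the configured data source format, e.g.
# parquet; the file names below are placeholders):
#
#   <input>/query1/part-*.parquet
#   <input>/query2/part-*.parquet
#   <input>/query15_part2/part-*.parquet
#   ...
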
def update_summary(prefix, unmatch_queries, query_dict):
    """Set the queryValidationStatus field in each JSON summary file.
    If the queryStatus is 'Completed' or 'CompletedWithTaskFailures' but validation failed,
    set it to 'Fail'.
    If the queryStatus is 'Completed' or 'CompletedWithTaskFailures' and validation passed,
    set it to 'Pass'.
    If the queryStatus is 'Failed',
    set it to 'NotAttempted'.

    Args:
        prefix (str): folder containing the JSON summary files
        unmatch_queries ([str]): list of queries that failed validation
        query_dict (dict): dictionary keyed by query name, as produced by gen_sql_from_stream
    """
    if not os.path.exists(prefix):
        raise Exception("The json summary folder doesn't exist.")
    print(f"Updating queryValidationStatus in folder {prefix}.")
    for query_name in query_dict.keys():
        summary_wildcard = prefix + f'/*{query_name}-*.json'
        file_glob = glob.glob(summary_wildcard)

        # Expect exactly one summary file for each query
        if len(file_glob) > 1:
            raise Exception(f"More than one summary file found for query {query_name} in folder {prefix}.")
        if len(file_glob) == 0:
            raise Exception(f"No summary file found for query {query_name} in folder {prefix}.")

        filename = file_glob[0]
        with open(filename, 'r') as f:
            summary = json.load(f)
        if query_name in unmatch_queries:
            if 'Completed' in summary['queryStatus'] or 'CompletedWithTaskFailures' in summary['queryStatus']:
                summary['queryValidationStatus'] = ['Fail']
            else:
                summary['queryValidationStatus'] = ['NotAttempted']
        else:
            summary['queryValidationStatus'] = ['Pass']
        with open(filename, 'w') as f:
            json.dump(summary, f, indent=2)

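# Illustrative fragment of a per-query JSON summary after validation. Only the two
# fields read or written above are shown; the real summary files produced by the
# benchmark runner contain additional fields, and queryStatus is shown as a list
# here only to mirror how queryValidationStatus is written:
#
#   {
#       "queryStatus": ["Completed"],
#       "queryValidationStatus": ["Pass"]
#   }
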
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('input1',
                        help='path of the first input data.')
    parser.add_argument('input2',
                        help='path of the second input data.')
    parser.add_argument('query_stream_file',
                        help='query stream file that contains NDS queries in a specific order.')
    parser.add_argument('--input1_format',
                        default='parquet',
                        help='data source type for the first input data. e.g. parquet, orc. Default is: parquet.')
    parser.add_argument('--input2_format',
                        default='parquet',
                        help='data source type for the second input data. e.g. parquet, orc. Default is: parquet.')
    parser.add_argument('--max_errors',
                        help='Maximum number of differences to report.',
                        type=int,
                        default=10)
    parser.add_argument('--epsilon',
                        type=float,
                        default=0.00001,
                        help='Allowed relative difference when comparing floating point values,' +
                        ' passed to math.isclose as rel_tol. Two values whose relative difference' +
                        ' is within this tolerance are treated as equal and are not reported as a' +
                        ' mismatch.')
    parser.add_argument('--ignore_ordering',
                        action='store_true',
                        help='Sort the data collected from the DataFrames before comparing them.')
    parser.add_argument('--use_iterator',
                        action='store_true',
                        help='When set, use `toLocalIterator` to load one partition at a' +
                        ' time into driver memory, reducing memory usage at the cost of performance' +
                        ' because processing will be single-threaded.')
    parser.add_argument('--json_summary_folder',
                        help='path of a folder that contains a json summary file for each query.')
    parser.add_argument('--sub_queries',
                        type=lambda s: [x.strip() for x in s.split(',')],
                        help='comma separated list of queries to compare. If not specified, all queries ' +
                        'in the stream file will be compared. e.g. "query1,query2,query3". Note, use ' +
                        'the "_part1" and "_part2" suffixes, e.g. query15_part1, query15_part2.')
    args = parser.parse_args()
    query_dict = gen_sql_from_stream(args.query_stream_file)
    # if sub_queries is set, only compare the specified queries
    if args.sub_queries:
        query_dict = get_query_subset(query_dict, args.sub_queries)
    spark_session = SparkSession.builder.appName("Validate Query Output").getOrCreate()
    unmatch_queries = iterate_queries(spark_session,
                                      args.input1,
                                      args.input2,
                                      args.input1_format,
                                      args.input2_format,
                                      args.ignore_ordering,
                                      query_dict,
                                      use_iterator=args.use_iterator,
                                      max_errors=args.max_errors,
                                      epsilon=args.epsilon)
    if args.json_summary_folder:
        update_summary(args.json_summary_folder, unmatch_queries, query_dict)
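
# Example invocation (illustrative; the script file name, output paths and query
# stream file name are placeholders for this sketch):
#
#   spark-submit nds_h_validate.py \
#       ./power_run_cpu_output \
#       ./power_run_gpu_output \
#       ./query_streams/stream.sql \
#       --ignore_ordering \
#       --json_summary_folder ./power_run_gpu_summary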