@@ -30,11 +30,14 @@ def generate_additional_params(
3030 for dtype , var in zip (
3131 additional_tensor_dtypes ,
3232 additional_tensor_names ,
33+ strict = True ,
3334 )
3435 ]
3536 + [
3637 f"{ dtype } { var } ;\n "
37- for dtype , var in zip (additional_scalar_dtypes , additional_scalar_names )
38+ for dtype , var in zip (
39+ additional_scalar_dtypes , additional_scalar_names , strict = True
40+ )
3841 ]
3942 )
4043 additional_func_params = "" .join (
@@ -48,7 +51,9 @@ def generate_additional_params(
4851 ]
4952 + [
5053 f", { dtype } { var } "
51- for dtype , var in zip (additional_scalar_dtypes , additional_scalar_names )
54+ for dtype , var in zip (
55+ additional_scalar_dtypes , additional_scalar_names , strict = True
56+ )
5257 ]
5358 )
5459 if is_sm90_template :
@@ -59,7 +64,9 @@ def generate_additional_params(
5964 if var .startswith ("maybe" )
6065 else f"params.additional_params.{ var } = static_cast<{ dtype } *>({ var } .data_ptr());"
6166 )
62- for dtype , var in zip (additional_tensor_dtypes , additional_tensor_names )
67+ for dtype , var in zip (
68+ additional_tensor_dtypes , additional_tensor_names , strict = True
69+ )
6370 ]
6471 + [
6572 f"params.additional_params.{ var } = { var } ;"
@@ -74,7 +81,9 @@ def generate_additional_params(
7481 if var .startswith ("maybe" )
7582 else f"params.{ var } = static_cast<{ dtype } *>({ var } .data_ptr());"
7683 )
77- for dtype , var in zip (additional_tensor_dtypes , additional_tensor_names )
84+ for dtype , var in zip (
85+ additional_tensor_dtypes , additional_tensor_names , strict = True
86+ )
7887 ]
7988 + [f"params.{ var } = { var } ;" for var in additional_scalar_names ]
8089 )
0 commit comments