
Commit a9d1cba

lint changes and cleanup of status and output type checks
1 parent e514007 commit a9d1cba

3 files changed: +84 -90 lines

onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc

Lines changed: 2 additions & 2 deletions

@@ -503,7 +503,7 @@ TEST(NvExecutionProviderTest, FP8CustomOpModel) {
   ASSERT_EQ(output_shape[1], 64);

   // Verify output is FLOAT16
-  ASSERT_EQ(output_tensor.DataType(), DataTypeImpl::GetType<MLFloat16>());
+  ASSERT_TRUE(output_tensor.IsDataType<MLFloat16>());

   LOGS_DEFAULT(INFO) << "[NvExecutionProviderTest] TRT FP8 custom ops model run completed successfully";
 }
@@ -570,7 +570,7 @@ TEST(NvExecutionProviderTest, FP4CustomOpModel) {
   ASSERT_EQ(output_shape[1], 64);

   // Verify output is FLOAT16
-  ASSERT_EQ(output_tensor.DataType(), DataTypeImpl::GetType<MLFloat16>());
+  ASSERT_TRUE(output_tensor.IsDataType<MLFloat16>());

   LOGS_DEFAULT(INFO) << "[NvExecutionProviderTest] TRT FP4 dynamic quantize model run completed successfully";
 }
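
Both hunks swap an explicit MLDataType comparison for Tensor's own type predicate. A minimal sketch of the two spellings, assuming an onnxruntime::Tensor named output_tensor that holds FP16 data as in the tests above:

// Sketch only: two equivalent ways to assert a tensor's element type in a gtest body.
#include "core/framework/data_types.h"
#include "core/framework/tensor.h"
#include "gtest/gtest.h"

void CheckOutputIsFp16(const onnxruntime::Tensor& output_tensor) {
  // Old style: compare the MLDataType pointer against the registered FP16 type.
  ASSERT_EQ(output_tensor.DataType(),
            onnxruntime::DataTypeImpl::GetType<onnxruntime::MLFloat16>());
  // New style used by this commit: ask the tensor directly.
  ASSERT_TRUE(output_tensor.IsDataType<onnxruntime::MLFloat16>());
}

The IsDataType<T>() form reads as a single boolean check and keeps DataTypeImpl out of the call site.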

onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc

Lines changed: 80 additions & 86 deletions

@@ -475,60 +475,60 @@ static std::vector<ONNX_NAMESPACE::OpSchema> CreateTRTFP8Schemas() {
   schemas.emplace_back();
   ONNX_NAMESPACE::OpSchema& fp8_quant_schema = schemas.back();
   fp8_quant_schema
-      .SetName("TRT_FP8QuantizeLinear")
-      .SetDomain("trt")
-      .SinceVersion(1)
-      .SetDoc("TensorRT FP8 Quantization - quantizes FP16 input to FP8")
-      .Input(0, "X", "Input tensor in FP16", "T1")
-      .Input(1, "scale", "Scale for quantization in FP16", "T1")
-      .Output(0, "Y", "Quantized output tensor in FP8", "T2")
-      .TypeConstraint("T1", {"tensor(float16)"}, "Input and scale must be float16")
-      .TypeConstraint("T2", {"tensor(float8e4m3fn)"}, "Output must be float8e4m3fn")
-      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-        // Output has same shape as input but FP8 type
-        auto input_type = ctx.getInputType(0);
-        if (input_type != nullptr && input_type->has_tensor_type()) {
-          auto output_type = ctx.getOutputType(0);
-          output_type->mutable_tensor_type()->set_elem_type(
-              ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN);
-          if (input_type->tensor_type().has_shape()) {
-            *output_type->mutable_tensor_type()->mutable_shape() =
-                input_type->tensor_type().shape();
+      .SetName("TRT_FP8QuantizeLinear")
+      .SetDomain("trt")
+      .SinceVersion(1)
+      .SetDoc("TensorRT FP8 Quantization - quantizes FP16 input to FP8")
+      .Input(0, "X", "Input tensor in FP16", "T1")
+      .Input(1, "scale", "Scale for quantization in FP16", "T1")
+      .Output(0, "Y", "Quantized output tensor in FP8", "T2")
+      .TypeConstraint("T1", {"tensor(float16)"}, "Input and scale must be float16")
+      .TypeConstraint("T2", {"tensor(float8e4m3fn)"}, "Output must be float8e4m3fn")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        // Output has same shape as input but FP8 type
+        auto input_type = ctx.getInputType(0);
+        if (input_type != nullptr && input_type->has_tensor_type()) {
+          auto output_type = ctx.getOutputType(0);
+          output_type->mutable_tensor_type()->set_elem_type(
+              ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN);
+          if (input_type->tensor_type().has_shape()) {
+            *output_type->mutable_tensor_type()->mutable_shape() =
+                input_type->tensor_type().shape();
+          }
         }
-        }
-      });
+      });

   // TRT_FP8DequantizeLinear schema
   schemas.emplace_back();
   ONNX_NAMESPACE::OpSchema& fp8_dequant_schema = schemas.back();
   fp8_dequant_schema
-      .SetName("TRT_FP8DequantizeLinear")
-      .SetDomain("trt")
-      .SinceVersion(1)
-      .SetDoc("TensorRT FP8 Dequantization - dequantizes FP8 input to FP16")
-      .Input(0, "X", "Quantized input tensor in FP8", "T1")
-      .Input(1, "scale", "Scale for dequantization in FP16", "T2")
-      .Output(0, "Y", "Dequantized output tensor in FP16", "T2")
-      .TypeConstraint("T1", {"tensor(float8e4m3fn)"}, "Input must be float8e4m3fn")
-      .TypeConstraint("T2", {"tensor(float16)"}, "Scale and output must be float16")
-      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-        // Output has same shape as input but FP16 type
-        auto input_type = ctx.getInputType(0);
-        if (input_type != nullptr && input_type->has_tensor_type()) {
-          auto output_type = ctx.getOutputType(0);
-          output_type->mutable_tensor_type()->set_elem_type(
-              ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
-          if (input_type->tensor_type().has_shape()) {
-            *output_type->mutable_tensor_type()->mutable_shape() =
-                input_type->tensor_type().shape();
+      .SetName("TRT_FP8DequantizeLinear")
+      .SetDomain("trt")
+      .SinceVersion(1)
+      .SetDoc("TensorRT FP8 Dequantization - dequantizes FP8 input to FP16")
+      .Input(0, "X", "Quantized input tensor in FP8", "T1")
+      .Input(1, "scale", "Scale for dequantization in FP16", "T2")
+      .Output(0, "Y", "Dequantized output tensor in FP16", "T2")
+      .TypeConstraint("T1", {"tensor(float8e4m3fn)"}, "Input must be float8e4m3fn")
+      .TypeConstraint("T2", {"tensor(float16)"}, "Scale and output must be float16")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        // Output has same shape as input but FP16 type
+        auto input_type = ctx.getInputType(0);
+        if (input_type != nullptr && input_type->has_tensor_type()) {
+          auto output_type = ctx.getOutputType(0);
+          output_type->mutable_tensor_type()->set_elem_type(
+              ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+          if (input_type->tensor_type().has_shape()) {
+            *output_type->mutable_tensor_type()->mutable_shape() =
+                input_type->tensor_type().shape();
+          }
         }
-        }
-      });
+      });

   return schemas;
 }

-void CreateFP8CustomOpModel(const PathString& model_name, std::string graph_name) {
+void CreateFP8CustomOpModel(const PathString& model_name, const std::string& graph_name) {
   // Create custom schema registry for TRT operators
   auto custom_schema_registry = std::make_shared<onnxruntime::OnnxRuntimeOpSchemaRegistry>();

@@ -609,11 +609,8 @@ void CreateFP8CustomOpModel(const PathString& model_name, std::string graph_name
   graph.SetInputs({&input_arg});
   graph.SetOutputs({&output_arg});

-  status = graph.Resolve();
-  ASSERT_TRUE(status.IsOK()) << "Graph Resolve failed: " << status.ErrorMessage();
-
-  status = Model::Save(model, model_name);
-  ASSERT_TRUE(status.IsOK()) << "Model Save failed: " << status.ErrorMessage();
+  ASSERT_STATUS_OK(graph.Resolve());
+  ASSERT_STATUS_OK(Model::Save(model, model_name));
 }
 #endif // !defined(DISABLE_FLOAT8_TYPES)

@@ -626,44 +623,44 @@ static std::vector<ONNX_NAMESPACE::OpSchema> CreateTRTFP4Schemas() {
   schemas.emplace_back();
   ONNX_NAMESPACE::OpSchema& fp4_quant_schema = schemas.back();
   fp4_quant_schema
-      .SetName("TRT_FP4DynamicQuantize")
-      .SetDomain("trt")
-      .SinceVersion(1)
-      .SetDoc("TensorRT FP4 Dynamic Quantization - quantizes FP16 input to FP4 with block-wise quantization")
-      .Attr("axis", "Axis along which to quantize", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(-1))
-      .Attr("block_size", "Block size for quantization", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(16))
-      .Attr("scale_type", "Scale data type", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(17))
-      .Input(0, "X", "Input tensor in FP16", "T1")
-      .Input(1, "scale", "Scale for quantization in FP16", "T1")
-      .Output(0, "Y_quantized", "Quantized output tensor in FP4", "T2")
-      .Output(1, "Y_scale", "Computed scales in FP8", "T3")
-      .TypeConstraint("T1", {"tensor(float16)"}, "Input and scale must be float16")
-      .TypeConstraint("T2", {"tensor(float4e2m1)"}, "Quantized output must be float4e2m1")
-      .TypeConstraint("T3", {"tensor(float8e4m3fn)"}, "Scale output must be float8e4m3fn")
-      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-        // Output 0 (Y_quantized) has same shape as input but FP4 type
-        auto input_type = ctx.getInputType(0);
-        if (input_type != nullptr && input_type->has_tensor_type()) {
-          auto output_type_0 = ctx.getOutputType(0);
-          output_type_0->mutable_tensor_type()->set_elem_type(
-              ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1);
-          if (input_type->tensor_type().has_shape()) {
-            *output_type_0->mutable_tensor_type()->mutable_shape() =
-                input_type->tensor_type().shape();
+      .SetName("TRT_FP4DynamicQuantize")
+      .SetDomain("trt")
+      .SinceVersion(1)
+      .SetDoc("TensorRT FP4 Dynamic Quantization - quantizes FP16 input to FP4 with block-wise quantization")
+      .Attr("axis", "Axis along which to quantize", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(-1))
+      .Attr("block_size", "Block size for quantization", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(16))
+      .Attr("scale_type", "Scale data type", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(17))
+      .Input(0, "X", "Input tensor in FP16", "T1")
+      .Input(1, "scale", "Scale for quantization in FP16", "T1")
+      .Output(0, "Y_quantized", "Quantized output tensor in FP4", "T2")
+      .Output(1, "Y_scale", "Computed scales in FP8", "T3")
+      .TypeConstraint("T1", {"tensor(float16)"}, "Input and scale must be float16")
+      .TypeConstraint("T2", {"tensor(float4e2m1)"}, "Quantized output must be float4e2m1")
+      .TypeConstraint("T3", {"tensor(float8e4m3fn)"}, "Scale output must be float8e4m3fn")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        // Output 0 (Y_quantized) has same shape as input but FP4 type
+        auto input_type = ctx.getInputType(0);
+        if (input_type != nullptr && input_type->has_tensor_type()) {
+          auto output_type_0 = ctx.getOutputType(0);
+          output_type_0->mutable_tensor_type()->set_elem_type(
+              ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1);
+          if (input_type->tensor_type().has_shape()) {
+            *output_type_0->mutable_tensor_type()->mutable_shape() =
+                input_type->tensor_type().shape();
+          }
+
+          // Output 1 (Y_scale) shape depends on block_size and axis
+          // For simplicity, we'll just set the type and let runtime handle shape
+          auto output_type_1 = ctx.getOutputType(1);
+          output_type_1->mutable_tensor_type()->set_elem_type(
+              ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN);
         }
-
-        // Output 1 (Y_scale) shape depends on block_size and axis
-        // For simplicity, we'll just set the type and let runtime handle shape
-        auto output_type_1 = ctx.getOutputType(1);
-        output_type_1->mutable_tensor_type()->set_elem_type(
-            ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN);
-        }
-      });
+      });

   return schemas;
 }

-void CreateFP4CustomOpModel(const PathString& model_name, std::string graph_name) {
+void CreateFP4CustomOpModel(const PathString& model_name, const std::string& graph_name) {
   // Create custom schema registry for TRT operators
   auto custom_schema_registry = std::make_shared<onnxruntime::OnnxRuntimeOpSchemaRegistry>();

@@ -819,11 +816,8 @@ void CreateFP4CustomOpModel(const PathString& model_name, std::string graph_name
   graph.SetInputs({&input_arg});
   graph.SetOutputs({&output_final});

-  status = graph.Resolve();
-  ASSERT_TRUE(status.IsOK()) << "Graph Resolve failed: " << status.ErrorMessage();
-
-  status = Model::Save(model, model_name);
-  ASSERT_TRUE(status.IsOK()) << "Model Save failed: " << status.ErrorMessage();
+  ASSERT_STATUS_OK(graph.Resolve());
+  ASSERT_STATUS_OK(Model::Save(model, model_name));
 }
 #endif // !defined(DISABLE_FLOAT4_TYPES) && !defined(DISABLE_FLOAT8_TYPES)
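
In both model builders the hand-rolled status plumbing (a local status variable plus ASSERT_TRUE(status.IsOK()) with a manually composed message) collapses into the ASSERT_STATUS_OK test helper. As a rough sketch of what such a helper boils down to (not ORT's exact definition, which lives in the shared test utilities; the macro name below is a hypothetical stand-in):

// Illustrative stand-in for an ASSERT_STATUS_OK-style gtest helper: it
// evaluates the expression once, then fails the test with the status's
// error message if the call did not succeed.
#include "core/common/status.h"
#include "gtest/gtest.h"

#define SKETCH_ASSERT_STATUS_OK(expr)                               \
  do {                                                              \
    const onnxruntime::common::Status _status = (expr);             \
    ASSERT_TRUE(_status.IsOK()) << #expr << " failed: "             \
                                << _status.ErrorMessage();          \
  } while (false)

With that shape, ASSERT_STATUS_OK(graph.Resolve()) still surfaces Resolve's error message on failure, so the removed lines collapse to two calls without losing diagnostics.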

onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h

Lines changed: 2 additions & 2 deletions

@@ -158,7 +158,7 @@ Ort::IoBinding generate_io_binding(
  * (FLOAT16)
  */
 void CreateFP8CustomOpModel(const PathString& model_name,
-                            std::string graph_name);
+                            const std::string& graph_name);
 #endif // !defined(DISABLE_FLOAT8_TYPES)

 #if !defined(DISABLE_FLOAT4_TYPES) && !defined(DISABLE_FLOAT8_TYPES)
@@ -207,7 +207,7 @@ void CreateFP8CustomOpModel(const PathString& model_name,
  * (FLOAT16)
  */
 void CreateFP4CustomOpModel(const PathString& model_name,
-                            std::string graph_name);
+                            const std::string& graph_name);
 #endif // !defined(DISABLE_FLOAT4_TYPES) && !defined(DISABLE_FLOAT8_TYPES)

 } // namespace test
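
Both declarations (and their definitions in test_nv_trt_rtx_ep_util.cc above) now take graph_name as const std::string& rather than by value, so callers that already hold a std::string avoid a copy; callers passing a string literal behave as before, since a temporary std::string is materialized either way. A hypothetical call site, with the model path and graph name made up purely for illustration:

// Hypothetical usage; file name and graph name are illustrative only.
CreateFP8CustomOpModel(ORT_TSTR("fp8_custom_op_test.onnx"), "FP8CustomOpGraph");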
