@@ -475,60 +475,60 @@ static std::vector<ONNX_NAMESPACE::OpSchema> CreateTRTFP8Schemas() {
   schemas.emplace_back();
   ONNX_NAMESPACE::OpSchema& fp8_quant_schema = schemas.back();
   fp8_quant_schema
-    .SetName("TRT_FP8QuantizeLinear")
-    .SetDomain("trt")
-    .SinceVersion(1)
-    .SetDoc("TensorRT FP8 Quantization - quantizes FP16 input to FP8")
-    .Input(0, "X", "Input tensor in FP16", "T1")
-    .Input(1, "scale", "Scale for quantization in FP16", "T1")
-    .Output(0, "Y", "Quantized output tensor in FP8", "T2")
-    .TypeConstraint("T1", {"tensor(float16)"}, "Input and scale must be float16")
-    .TypeConstraint("T2", {"tensor(float8e4m3fn)"}, "Output must be float8e4m3fn")
-    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-      // Output has same shape as input but FP8 type
-      auto input_type = ctx.getInputType(0);
-      if (input_type != nullptr && input_type->has_tensor_type()) {
-        auto output_type = ctx.getOutputType(0);
-        output_type->mutable_tensor_type()->set_elem_type(
-            ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN);
-        if (input_type->tensor_type().has_shape()) {
-          *output_type->mutable_tensor_type()->mutable_shape() =
-              input_type->tensor_type().shape();
+      .SetName("TRT_FP8QuantizeLinear")
+      .SetDomain("trt")
+      .SinceVersion(1)
+      .SetDoc("TensorRT FP8 Quantization - quantizes FP16 input to FP8")
+      .Input(0, "X", "Input tensor in FP16", "T1")
+      .Input(1, "scale", "Scale for quantization in FP16", "T1")
+      .Output(0, "Y", "Quantized output tensor in FP8", "T2")
+      .TypeConstraint("T1", {"tensor(float16)"}, "Input and scale must be float16")
+      .TypeConstraint("T2", {"tensor(float8e4m3fn)"}, "Output must be float8e4m3fn")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        // Output has same shape as input but FP8 type
+        auto input_type = ctx.getInputType(0);
+        if (input_type != nullptr && input_type->has_tensor_type()) {
+          auto output_type = ctx.getOutputType(0);
+          output_type->mutable_tensor_type()->set_elem_type(
+              ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN);
+          if (input_type->tensor_type().has_shape()) {
+            *output_type->mutable_tensor_type()->mutable_shape() =
+                input_type->tensor_type().shape();
+          }
         }
-      }
-    });
+      });

   // TRT_FP8DequantizeLinear schema
   schemas.emplace_back();
   ONNX_NAMESPACE::OpSchema& fp8_dequant_schema = schemas.back();
   fp8_dequant_schema
-    .SetName("TRT_FP8DequantizeLinear")
-    .SetDomain("trt")
-    .SinceVersion(1)
-    .SetDoc("TensorRT FP8 Dequantization - dequantizes FP8 input to FP16")
-    .Input(0, "X", "Quantized input tensor in FP8", "T1")
-    .Input(1, "scale", "Scale for dequantization in FP16", "T2")
-    .Output(0, "Y", "Dequantized output tensor in FP16", "T2")
-    .TypeConstraint("T1", {"tensor(float8e4m3fn)"}, "Input must be float8e4m3fn")
-    .TypeConstraint("T2", {"tensor(float16)"}, "Scale and output must be float16")
-    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-      // Output has same shape as input but FP16 type
-      auto input_type = ctx.getInputType(0);
-      if (input_type != nullptr && input_type->has_tensor_type()) {
-        auto output_type = ctx.getOutputType(0);
-        output_type->mutable_tensor_type()->set_elem_type(
-            ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
-        if (input_type->tensor_type().has_shape()) {
-          *output_type->mutable_tensor_type()->mutable_shape() =
-              input_type->tensor_type().shape();
+      .SetName("TRT_FP8DequantizeLinear")
+      .SetDomain("trt")
+      .SinceVersion(1)
+      .SetDoc("TensorRT FP8 Dequantization - dequantizes FP8 input to FP16")
+      .Input(0, "X", "Quantized input tensor in FP8", "T1")
+      .Input(1, "scale", "Scale for dequantization in FP16", "T2")
+      .Output(0, "Y", "Dequantized output tensor in FP16", "T2")
+      .TypeConstraint("T1", {"tensor(float8e4m3fn)"}, "Input must be float8e4m3fn")
+      .TypeConstraint("T2", {"tensor(float16)"}, "Scale and output must be float16")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        // Output has same shape as input but FP16 type
+        auto input_type = ctx.getInputType(0);
+        if (input_type != nullptr && input_type->has_tensor_type()) {
+          auto output_type = ctx.getOutputType(0);
+          output_type->mutable_tensor_type()->set_elem_type(
+              ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+          if (input_type->tensor_type().has_shape()) {
+            *output_type->mutable_tensor_type()->mutable_shape() =
+                input_type->tensor_type().shape();
+          }
         }
-      }
-    });
+      });

   return schemas;
 }

-void CreateFP8CustomOpModel(const PathString& model_name, std::string graph_name) {
+void CreateFP8CustomOpModel(const PathString& model_name, const std::string& graph_name) {
   // Create custom schema registry for TRT operators
   auto custom_schema_registry = std::make_shared<onnxruntime::OnnxRuntimeOpSchemaRegistry>();

@@ -609,11 +609,8 @@ void CreateFP8CustomOpModel(const PathString& model_name, std::string graph_name
   graph.SetInputs({&input_arg});
   graph.SetOutputs({&output_arg});

-  status = graph.Resolve();
-  ASSERT_TRUE(status.IsOK()) << "Graph Resolve failed: " << status.ErrorMessage();
-
-  status = Model::Save(model, model_name);
-  ASSERT_TRUE(status.IsOK()) << "Model Save failed: " << status.ErrorMessage();
+  ASSERT_STATUS_OK(graph.Resolve());
+  ASSERT_STATUS_OK(Model::Save(model, model_name));
 }
 #endif // !defined(DISABLE_FLOAT8_TYPES)

@@ -626,44 +623,44 @@ static std::vector<ONNX_NAMESPACE::OpSchema> CreateTRTFP4Schemas() {
   schemas.emplace_back();
   ONNX_NAMESPACE::OpSchema& fp4_quant_schema = schemas.back();
   fp4_quant_schema
-    .SetName("TRT_FP4DynamicQuantize")
-    .SetDomain("trt")
-    .SinceVersion(1)
-    .SetDoc("TensorRT FP4 Dynamic Quantization - quantizes FP16 input to FP4 with block-wise quantization")
-    .Attr("axis", "Axis along which to quantize", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(-1))
-    .Attr("block_size", "Block size for quantization", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(16))
-    .Attr("scale_type", "Scale data type", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(17))
-    .Input(0, "X", "Input tensor in FP16", "T1")
-    .Input(1, "scale", "Scale for quantization in FP16", "T1")
-    .Output(0, "Y_quantized", "Quantized output tensor in FP4", "T2")
-    .Output(1, "Y_scale", "Computed scales in FP8", "T3")
-    .TypeConstraint("T1", {"tensor(float16)"}, "Input and scale must be float16")
-    .TypeConstraint("T2", {"tensor(float4e2m1)"}, "Quantized output must be float4e2m1")
-    .TypeConstraint("T3", {"tensor(float8e4m3fn)"}, "Scale output must be float8e4m3fn")
-    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-      // Output 0 (Y_quantized) has same shape as input but FP4 type
-      auto input_type = ctx.getInputType(0);
-      if (input_type != nullptr && input_type->has_tensor_type()) {
-        auto output_type_0 = ctx.getOutputType(0);
-        output_type_0->mutable_tensor_type()->set_elem_type(
-            ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1);
-        if (input_type->tensor_type().has_shape()) {
-          *output_type_0->mutable_tensor_type()->mutable_shape() =
-              input_type->tensor_type().shape();
+      .SetName("TRT_FP4DynamicQuantize")
+      .SetDomain("trt")
+      .SinceVersion(1)
+      .SetDoc("TensorRT FP4 Dynamic Quantization - quantizes FP16 input to FP4 with block-wise quantization")
+      .Attr("axis", "Axis along which to quantize", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(-1))
+      .Attr("block_size", "Block size for quantization", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(16))
+      .Attr("scale_type", "Scale data type", ONNX_NAMESPACE::AttributeProto::INT, static_cast<int64_t>(17))
+      .Input(0, "X", "Input tensor in FP16", "T1")
+      .Input(1, "scale", "Scale for quantization in FP16", "T1")
+      .Output(0, "Y_quantized", "Quantized output tensor in FP4", "T2")
+      .Output(1, "Y_scale", "Computed scales in FP8", "T3")
+      .TypeConstraint("T1", {"tensor(float16)"}, "Input and scale must be float16")
+      .TypeConstraint("T2", {"tensor(float4e2m1)"}, "Quantized output must be float4e2m1")
+      .TypeConstraint("T3", {"tensor(float8e4m3fn)"}, "Scale output must be float8e4m3fn")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        // Output 0 (Y_quantized) has same shape as input but FP4 type
+        auto input_type = ctx.getInputType(0);
+        if (input_type != nullptr && input_type->has_tensor_type()) {
+          auto output_type_0 = ctx.getOutputType(0);
+          output_type_0->mutable_tensor_type()->set_elem_type(
+              ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1);
+          if (input_type->tensor_type().has_shape()) {
+            *output_type_0->mutable_tensor_type()->mutable_shape() =
+                input_type->tensor_type().shape();
+          }
+
+          // Output 1 (Y_scale) shape depends on block_size and axis
+          // For simplicity, we'll just set the type and let runtime handle shape
+          auto output_type_1 = ctx.getOutputType(1);
+          output_type_1->mutable_tensor_type()->set_elem_type(
+              ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN);
         }
-
-        // Output 1 (Y_scale) shape depends on block_size and axis
-        // For simplicity, we'll just set the type and let runtime handle shape
-        auto output_type_1 = ctx.getOutputType(1);
-        output_type_1->mutable_tensor_type()->set_elem_type(
-            ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN);
-      }
-    });
+      });

   return schemas;
 }

-void CreateFP4CustomOpModel(const PathString& model_name, std::string graph_name) {
+void CreateFP4CustomOpModel(const PathString& model_name, const std::string& graph_name) {
   // Create custom schema registry for TRT operators
   auto custom_schema_registry = std::make_shared<onnxruntime::OnnxRuntimeOpSchemaRegistry>();

@@ -819,11 +816,8 @@ void CreateFP4CustomOpModel(const PathString& model_name, std::string graph_name
   graph.SetInputs({&input_arg});
   graph.SetOutputs({&output_final});

-  status = graph.Resolve();
-  ASSERT_TRUE(status.IsOK()) << "Graph Resolve failed: " << status.ErrorMessage();
-
-  status = Model::Save(model, model_name);
-  ASSERT_TRUE(status.IsOK()) << "Model Save failed: " << status.ErrorMessage();
+  ASSERT_STATUS_OK(graph.Resolve());
+  ASSERT_STATUS_OK(Model::Save(model, model_name));
 }
 #endif // !defined(DISABLE_FLOAT4_TYPES) && !defined(DISABLE_FLOAT8_TYPES)
