Commit 0318be2

[TRTRTX EP] Support custom-ops used in NVFP4 recipe (#26555)
### Description
- This change enables the FP4 data type and the NVFP4 recipe's custom ops in the TRTRTX EP.

### Motivation and Context
- NVIDIA's NVFP4 quantization recipe currently uses custom ops for operations such as FP4 dynamic and double quantization, and FP8 Q/DQ in MHA. These custom ops are natively supported by TensorRT RTX (i.e. no plugin is required).
- An NVFP4 model (e.g. an NVFP4 Flux or SD model) can be run through a CLI tool such as tensorrt_rtx, but it fails when run through onnxruntime's TRTRTX EP because of the unrecognized custom ops and the FP4 data type.
- To enable running NVFP4 models through onnxruntime's TRTRTX EP, this change adds support for the FP4 data type and the NVFP4-related custom ops in the TRTRTX EP.
- Validated with the following setup: SD3.5-medium (with an FP4 transformer) + optimum-onnxruntime SD pipeline + Windows 11 22621 + RTX 5090 + text-to-image modality. The inference run produced an image for the text input and no errors were thrown.
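For orientation, the sketch below shows how an application might run such an NVFP4 model through onnxruntime's TRTRTX EP using the C++ API. The provider name string "NvTensorRtRtx" and the model filename are illustrative assumptions, not something this commit defines.

```cpp
// Minimal usage sketch. Assumptions: the EP is appended under the name
// "NvTensorRtRtx", and the model path is hypothetical.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "nvfp4_demo");
  Ort::SessionOptions session_options;

  // Appending the EP also brings in its custom-op domain registration, so the
  // NVFP4 custom ops (TRT_FP4DynamicQuantize, TRT_FP8QuantizeLinear,
  // TRT_FP8DequantizeLinear) and the FP4 tensor type are recognized at load time.
  session_options.AppendExecutionProvider("NvTensorRtRtx");

  // Hypothetical NVFP4 model produced by the NVFP4 quantization recipe.
  Ort::Session session(env, ORT_TSTR("sd35_medium_fp4_transformer.onnx"), session_options);
  return 0;
}
```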
1 parent d1abad0 commit 0318be2

File tree

7 files changed: 609 additions, 2 deletions


onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 3 additions & 0 deletions
@@ -118,6 +118,7 @@ static bool IsSupportedDataType(ONNXTensorElementDataType data_type) {
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:         // kINT64 - 64-bit signed integer
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:  // kFP8 - 8-bit floating point
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:        // kDOUBLE - 64-bit floating point
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1:    // kFP4 - 4-bit floating point
       return true;
     default:
       return false;
@@ -692,6 +693,7 @@ Status BindContextOutput(Ort::KernelContext& ctx,
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
+    CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1, uint8_t)
     default: {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported.");
@@ -756,6 +758,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1, uint8_t)
     default: {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported.");

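Worth noting for the three additions above: FLOAT4E2M1 tensors are bound through uint8_t (as FP8 already is) because each element is only 4 bits wide and ONNX packs two elements per byte. Below is a small illustrative sketch of that packing convention, assuming the first element occupies the low nibble; it is not code from this commit.

```cpp
// Illustrative sketch of the assumed FLOAT4E2M1 packing: two 4-bit codes share
// one byte, with the first element stored in the low nibble.
#include <cstdint>
#include <utility>

inline uint8_t PackFp4Pair(uint8_t first_code, uint8_t second_code) {
  return static_cast<uint8_t>((first_code & 0x0F) | ((second_code & 0x0F) << 4));
}

inline std::pair<uint8_t, uint8_t> UnpackFp4Pair(uint8_t packed) {
  return {static_cast<uint8_t>(packed & 0x0F), static_cast<uint8_t>(packed >> 4)};
}
```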
onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc

Lines changed: 18 additions & 0 deletions
@@ -41,10 +41,15 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose);
 common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) {
   static std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
   static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list;
+  static std::unique_ptr<OrtCustomOpDomain> native_custom_op_domain = std::make_unique<OrtCustomOpDomain>();
+  static std::vector<std::unique_ptr<TensorRTCustomOp>> native_custom_op_list;
   static std::mutex mutex;
   std::lock_guard<std::mutex> lock(mutex);
   if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) {
     domain_list.push_back(custom_op_domain.get());
+    if (native_custom_op_domain->domain_ != "" && native_custom_op_domain->custom_ops_.size() > 0) {
+      domain_list.push_back(native_custom_op_domain.get());
+    }
     return Status::OK();
   }

@@ -136,6 +141,19 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
   } catch (const std::exception&) {
     LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins";
   }
+
+  // Register native custom ops (register these independent of TRT plugin library availability)
+  const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"};
+  int num_native_custom_ops = std::size(native_custom_ops_names);
+
+  for (int i = 0; i < num_native_custom_ops; i++) {
+    native_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr));
+    native_custom_op_list.back()->SetName(native_custom_ops_names[i]);
+    native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get());
+  }
+
+  native_custom_op_domain->domain_ = "trt";
+  domain_list.push_back(native_custom_op_domain.get());
   return Status::OK();
 }
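For readers skimming the diff, the sketch below condenses the registration pattern this function uses: the custom-op domains are built once behind a mutex and cached in statics, so the cached early-return path also has to hand back the new native "trt" domain. The names and the control flow are simplified stand-ins for illustration, not the actual EP types.

```cpp
// Simplified stand-in (not the real OrtCustomOpDomain / TensorRTCustomOp types)
// for the pattern above: build the domains once under a lock, cache them in
// statics, and make sure the cached path also returns the native "trt" domain.
#include <mutex>
#include <string>
#include <vector>

struct Domain {                       // stand-in for OrtCustomOpDomain
  std::string name;
  std::vector<std::string> op_names;  // stand-in for custom_ops_
};

void CollectCustomOpDomains(std::vector<Domain*>& domain_list) {
  static Domain plugin_domain;  // populated from the TRT plugin registry, may stay empty
  static Domain native_domain;  // the natively supported NVFP4/FP8 ops
  static std::mutex mutex;
  std::lock_guard<std::mutex> lock(mutex);

  if (!native_domain.op_names.empty()) {  // already initialized on an earlier call
    if (!plugin_domain.op_names.empty()) domain_list.push_back(&plugin_domain);
    domain_list.push_back(&native_domain);
    return;
  }

  native_domain.name = "trt";
  native_domain.op_names = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"};
  domain_list.push_back(&native_domain);
  // ... plugin_domain discovery and push_back omitted ...
}
```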

onnxruntime/core/session/custom_ops.cc

Lines changed: 3 additions & 1 deletion
@@ -49,7 +49,9 @@ static constexpr uint32_t min_ort_version_with_compute_v2_support = 16;
 static constexpr uint32_t min_ort_version_with_shape_inference = 17;
 #endif

-#if !defined(DISABLE_FLOAT8_TYPES)
+#if !defined(DISABLE_FLOAT8_TYPES) && !defined(DISABLE_FLOAT4_TYPES)
+#define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv11()
+#elif !defined(DISABLE_FLOAT8_TYPES)
 #define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv10()
 #else
 #define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv4()

onnxruntime/python/onnxruntime_pybind_mlvalue.cc

Lines changed: 7 additions & 1 deletion
@@ -447,11 +447,17 @@ int OnnxRuntimeTensorToNumpyType(const DataTypeImpl* tensor_type) {
       {DataTypeImpl::GetType<int64_t>(), NPY_LONGLONG},
       {DataTypeImpl::GetType<uint64_t>(), NPY_ULONGLONG},
       {DataTypeImpl::GetType<std::string>(), NPY_OBJECT},
+#if !defined(DISABLE_FLOAT4_TYPES)
+      {DataTypeImpl::GetType<Float4E2M1x2>(), NPY_UINT8},
+#endif
+#if !defined(DISABLE_FLOAT8_TYPES)
+      {DataTypeImpl::GetType<Float8E4M3FN>(), NPY_UINT8},
+#endif
   };

   const auto it = type_map.find(tensor_type);
   if (it == type_map.end()) {
-    throw std::runtime_error("No corresponding Numpy type for Tensor Type.");
+    throw std::runtime_error("No corresponding Numpy type for Tensor Type. " + std::string(DataTypeImpl::ToString(tensor_type)));
   } else {
     return it->second;
   }

onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc

Lines changed: 135 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include "core/session/inference_session.h"
 #include "test/providers/provider_test_utils.h"
 #include "test/unittest_util/framework_test_utils.h"
+#include "test/util/include/default_providers.h"

 #include "test/util/include/scoped_env_vars.h"
 #include "test/common/trt_op_test_utils.h"
@@ -440,6 +441,140 @@ TEST(NvExecutionProviderTest, DataTransfer) {
   device_tensor = Ort::Value();
 }

+TEST(NvExecutionProviderTest, FP8CustomOpModel) {
+  PathString model_name = ORT_TSTR("nv_execution_provider_fp8_quantize_dequantize_test.onnx");
+  clearFileIfExists(model_name);
+  std::string graph_name = "nv_execution_provider_fp8_quantize_dequantize_graph";
+
+  // Create a model with TRT_FP8QuantizeLinear -> TRT_FP8DequantizeLinear (FP16 -> FP8 -> FP16, per-tensor quantization)
+  CreateFP8CustomOpModel(model_name, graph_name);
+
+  // Verify the model file was created
+  ASSERT_TRUE(std::filesystem::exists(model_name));
+
+  // Create session and register execution provider explicitly
+  // This ensures custom ops are registered before model is loaded
+  SessionOptions so;
+  so.session_logid = "NvExecutionProviderTest.FP8CustomOpModel";
+  InferenceSession session_object{so, GetEnvironment()};
+
+  // Register TRTRTX EP - this will register custom ops
+  std::unique_ptr<IExecutionProvider> execution_provider = DefaultNvTensorRTRTXExecutionProvider();
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(execution_provider)));
+
+  // Load and initialize model
+  ASSERT_STATUS_OK(session_object.Load(model_name));
+  ASSERT_STATUS_OK(session_object.Initialize());
+
+  // Create input data (FP16, shape [4, 64])
+  std::vector<MLFloat16> input_data(4 * 64);
+  for (size_t i = 0; i < input_data.size(); ++i) {
+    input_data[i] = MLFloat16(static_cast<float>(i % 100) / 100.0f);
+  }
+
+  // Create input tensor
+  std::vector<int64_t> input_shape = {4, 64};
+  OrtValue input_tensor;
+  Tensor::InitOrtValue(DataTypeImpl::GetType<MLFloat16>(), TensorShape(input_shape),
+                       input_data.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator),
+                       input_tensor);
+
+  // Prepare feeds
+  NameMLValMap feeds;
+  feeds.insert(std::make_pair("X", input_tensor));
+
+  // Prepare outputs
+  std::vector<std::string> output_names;
+  output_names.push_back("Y");
+
+  // Run inference
+  std::vector<OrtValue> fetches;
+  RunOptions run_options;
+  ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
+
+  // Verify output tensor is valid
+  ASSERT_EQ(fetches.size(), 1u);
+  ASSERT_TRUE(fetches[0].IsTensor());
+
+  const auto& output_tensor = fetches[0].Get<Tensor>();
+  auto output_shape = output_tensor.Shape();
+  ASSERT_EQ(output_shape.NumDimensions(), 2u);
+  ASSERT_EQ(output_shape[0], 4);
+  ASSERT_EQ(output_shape[1], 64);
+
+  // Verify output is FLOAT16
+  ASSERT_TRUE(output_tensor.IsDataType<MLFloat16>());
+
+  LOGS_DEFAULT(INFO) << "[NvExecutionProviderTest] TRT FP8 custom ops model run completed successfully";
+}
+
+TEST(NvExecutionProviderTest, FP4CustomOpModel) {
+  PathString model_name = ORT_TSTR("nv_execution_provider_fp4_dynamic_quantize_test.onnx");
+  clearFileIfExists(model_name);
+  std::string graph_name = "nv_execution_provider_fp4_dynamic_quantize_graph";
+
+  // Create a model with TRT_FP4DynamicQuantize node
+  CreateFP4CustomOpModel(model_name, graph_name);
+
+  // Verify the model file was created
+  ASSERT_TRUE(std::filesystem::exists(model_name));
+
+  // Create session and register execution provider explicitly
+  // This ensures custom ops are registered before model is loaded
+  SessionOptions so;
+  so.session_logid = "NvExecutionProviderTest.FP4CustomOpModel";
+  InferenceSession session_object{so, GetEnvironment()};
+
+  // Register TRTRTX EP - this will register custom ops
+  std::unique_ptr<IExecutionProvider> execution_provider = DefaultNvTensorRTRTXExecutionProvider();
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(execution_provider)));
+
+  // Load and initialize model
+  ASSERT_STATUS_OK(session_object.Load(model_name));
+  ASSERT_STATUS_OK(session_object.Initialize());
+
+  // Create input data (FP16, shape [64, 64])
+  std::vector<MLFloat16> input_data(64 * 64);
+  for (size_t i = 0; i < input_data.size(); ++i) {
+    input_data[i] = MLFloat16(static_cast<float>(i % 100) / 100.0f);
+  }
+
+  // Create input tensor
+  std::vector<int64_t> input_shape = {64, 64};
+  OrtValue input_tensor;
+  Tensor::InitOrtValue(DataTypeImpl::GetType<MLFloat16>(), TensorShape(input_shape),
+                       input_data.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator),
+                       input_tensor);
+
+  // Prepare feeds
+  NameMLValMap feeds;
+  feeds.insert(std::make_pair("X", input_tensor));
+
+  // Prepare outputs
+  std::vector<std::string> output_names;
+  output_names.push_back("X_dequantized");
+
+  // Run inference
+  std::vector<OrtValue> fetches;
+  RunOptions run_options;
+  ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
+
+  // Verify output tensor is valid
+  ASSERT_EQ(fetches.size(), 1u);
+  ASSERT_TRUE(fetches[0].IsTensor());
+
+  const auto& output_tensor = fetches[0].Get<Tensor>();
+  auto output_shape = output_tensor.Shape();
+  ASSERT_EQ(output_shape.NumDimensions(), 2u);
+  ASSERT_EQ(output_shape[0], 64);
+  ASSERT_EQ(output_shape[1], 64);
+
+  // Verify output is FLOAT16
+  ASSERT_TRUE(output_tensor.IsDataType<MLFloat16>());
+
+  LOGS_DEFAULT(INFO) << "[NvExecutionProviderTest] TRT FP4 dynamic quantize model run completed successfully";
+}
+
 #endif

 }  // namespace test
