@@ -118,6 +118,7 @@ static bool IsSupportedDataType(ONNXTensorElementDataType data_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: // kINT64 - 64-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: // kFP8 - 8-bit floating point
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: // kDOUBLE - 64-bit floating point
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1: // kFP4 - 4-bit floating point
return true;
default:
return false;
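
FLOAT4E2M1 is a packed sub-byte type, which is why the binding cases below fall back to uint8_t storage. As a minimal sketch (an assumption about the standard FP4 packing, not code from this PR; the helper name is hypothetical), the byte size of a packed FP4 buffer works out to:

#include <cstddef>

// Two E2M1 (4-bit) elements share one byte; an odd trailing element
// still occupies a full byte.
inline size_t Fp4PackedByteSize(size_t element_count) {
  return (element_count + 1) / 2;
}
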
@@ -692,6 +693,7 @@ Status BindContextOutput(Ort::KernelContext& ctx,
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1, uint8_t)
default: {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported.");
@@ -756,6 +758,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1, uint8_t)
default: {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported.");
onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc
@@ -41,10 +41,15 @@
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) {
static std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list;
static std::unique_ptr<OrtCustomOpDomain> native_custom_op_domain = std::make_unique<OrtCustomOpDomain>();

[cpplint] nv_execution_provider_custom_ops.cc:44: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
static std::vector<std::unique_ptr<TensorRTCustomOp>> native_custom_op_list;

[cpplint] nv_execution_provider_custom_ops.cc:45: Add #include <vector> for vector<> [build/include_what_you_use] [4]
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) {
domain_list.push_back(custom_op_domain.get());
if (native_custom_op_domain->domain_ != "" && native_custom_op_domain->custom_ops_.size() > 0) {
domain_list.push_back(native_custom_op_domain.get());
}
return Status::OK();
}
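
The cpplint annotations above boil down to two missing headers; a minimal fix sketch for the top of nv_execution_provider_custom_ops.cc (not part of this diff) would be:

#include <memory>  // std::unique_ptr, std::make_unique
#include <vector>  // std::vector
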

@@ -136,6 +141,19 @@
} catch (const std::exception&) {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins";
}

// Register native custom ops (register these independent of TRT plugin library availability)
const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"};
int num_native_custom_ops = std::size(native_custom_ops_names);

for (int i = 0; i < num_native_custom_ops; i++) {
native_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr));
native_custom_op_list.back()->SetName(native_custom_ops_names[i]);
native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get());
}

native_custom_op_domain->domain_ = "trt";
domain_list.push_back(native_custom_op_domain.get());
return Status::OK();
}
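
For reference, a hypothetical sketch of how a graph node would refer to one of the natively registered ops in the "trt" domain, using the ONNX protobuf C++ API; the input and output names are made up for illustration and are not taken from this PR:

#include "onnx/onnx_pb.h"

onnx::NodeProto MakeFp8QuantizeNode() {
  onnx::NodeProto node;
  node.set_op_type("TRT_FP8QuantizeLinear");  // one of native_custom_ops_names above
  node.set_domain("trt");                     // matches native_custom_op_domain->domain_
  node.add_input("X");                        // hypothetical input names
  node.add_input("scale");
  node.add_output("X_quantized");
  return node;
}
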

4 changes: 3 additions & 1 deletion onnxruntime/core/session/custom_ops.cc
@@ -49,7 +49,9 @@ static constexpr uint32_t min_ort_version_with_compute_v2_support = 16;
static constexpr uint32_t min_ort_version_with_shape_inference = 17;
#endif

#if !defined(DISABLE_FLOAT8_TYPES)
#if !defined(DISABLE_FLOAT8_TYPES) && !defined(DISABLE_FLOAT4_TYPES)
#define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv11()
#elif !defined(DISABLE_FLOAT8_TYPES)
#define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv10()
#else
#define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv4()
8 changes: 7 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -447,11 +447,17 @@ int OnnxRuntimeTensorToNumpyType(const DataTypeImpl* tensor_type) {
{DataTypeImpl::GetType<int64_t>(), NPY_LONGLONG},
{DataTypeImpl::GetType<uint64_t>(), NPY_ULONGLONG},
{DataTypeImpl::GetType<std::string>(), NPY_OBJECT},
#if !defined(DISABLE_FLOAT4_TYPES)
{DataTypeImpl::GetType<Float4E2M1x2>(), NPY_UINT8},
#endif
#if !defined(DISABLE_FLOAT8_TYPES)
{DataTypeImpl::GetType<Float8E4M3FN>(), NPY_UINT8},
#endif
};

const auto it = type_map.find(tensor_type);
if (it == type_map.end()) {
throw std::runtime_error("No corresponding Numpy type for Tensor Type.");
throw std::runtime_error("No corresponding Numpy type for Tensor Type. " + std::string(DataTypeImpl::ToString(tensor_type)));
} else {
return it->second;
}
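
Because Float4E2M1x2 surfaces to numpy as NPY_UINT8, a Python caller receives raw packed bytes and has to unpack the two nibbles itself. A small decode sketch (an assumption based on the standard E2M1 value set and low-nibble-first packing, not ORT code):

#include <cstdint>

// Decode one E2M1 nibble: bit 3 = sign, bits 2-1 = exponent, bit 0 = mantissa.
inline float DecodeE2M1(uint8_t nibble) {
  static const float kValues[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  float magnitude = kValues[nibble & 0x7];
  return (nibble & 0x8) ? -magnitude : magnitude;
}

inline void DecodePackedByte(uint8_t byte, float& first, float& second) {
  first = DecodeE2M1(byte & 0x0F);          // first element: low nibble (assumed)
  second = DecodeE2M1((byte >> 4) & 0x0F);  // second element: high nibble
}
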
135 changes: 135 additions & 0 deletions onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
@@ -4,6 +4,7 @@
#include "core/session/inference_session.h"
#include "test/providers/provider_test_utils.h"
#include "test/unittest_util/framework_test_utils.h"
#include "test/util/include/default_providers.h"

#include "test/util/include/scoped_env_vars.h"
#include "test/common/trt_op_test_utils.h"
@@ -440,6 +441,140 @@ TEST(NvExecutionProviderTest, DataTransfer) {
device_tensor = Ort::Value();
}

TEST(NvExecutionProviderTest, FP8CustomOpModel) {
PathString model_name = ORT_TSTR("nv_execution_provider_fp8_quantize_dequantize_test.onnx");
clearFileIfExists(model_name);
std::string graph_name = "nv_execution_provider_fp8_quantize_dequantize_graph";

// Create a model with TRT_FP8QuantizeLinear -> TRT_FP8DequantizeLinear (FP16 -> FP8 -> FP16, per-tensor quantization)
CreateFP8CustomOpModel(model_name, graph_name);

// Verify the model file was created
ASSERT_TRUE(std::filesystem::exists(model_name));

// Create session and register execution provider explicitly
// This ensures custom ops are registered before model is loaded
SessionOptions so;
so.session_logid = "NvExecutionProviderTest.FP8CustomOpModel";
InferenceSession session_object{so, GetEnvironment()};

// Register TRTRTX EP - this will register custom ops
std::unique_ptr<IExecutionProvider> execution_provider = DefaultNvTensorRTRTXExecutionProvider();
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(execution_provider)));

// Load and initialize model
ASSERT_STATUS_OK(session_object.Load(model_name));
ASSERT_STATUS_OK(session_object.Initialize());

// Create input data (FP16, shape [4, 64])
std::vector<MLFloat16> input_data(4 * 64);
for (size_t i = 0; i < input_data.size(); ++i) {
input_data[i] = MLFloat16(static_cast<float>(i % 100) / 100.0f);
}

// Create input tensor
std::vector<int64_t> input_shape = {4, 64};
OrtValue input_tensor;
Tensor::InitOrtValue(DataTypeImpl::GetType<MLFloat16>(), TensorShape(input_shape),
input_data.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator),
input_tensor);

// Prepare feeds
NameMLValMap feeds;
feeds.insert(std::make_pair("X", input_tensor));

// Prepare outputs
std::vector<std::string> output_names;
output_names.push_back("Y");

// Run inference
std::vector<OrtValue> fetches;
RunOptions run_options;
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));

// Verify output tensor is valid
ASSERT_EQ(fetches.size(), 1u);
ASSERT_TRUE(fetches[0].IsTensor());

const auto& output_tensor = fetches[0].Get<Tensor>();
auto output_shape = output_tensor.Shape();
ASSERT_EQ(output_shape.NumDimensions(), 2u);
ASSERT_EQ(output_shape[0], 4);
ASSERT_EQ(output_shape[1], 64);

// Verify output is FLOAT16
ASSERT_TRUE(output_tensor.IsDataType<MLFloat16>());

LOGS_DEFAULT(INFO) << "[NvExecutionProviderTest] TRT FP8 custom ops model run completed successfully";
}
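
To make the FP16 -> FP8 -> FP16 intent of this test graph concrete, here is a rough per-tensor quantize/dequantize round trip. This is a sketch under assumptions, not the EP's implementation: it uses a crude 4-bit significand rounding in place of true E4M3FN encoding and ignores subnormals.

#include <algorithm>
#include <cmath>

// Crude stand-in for rounding to an E4M3 significand (1 implicit + 3 mantissa bits).
inline float RoundSignificandTo4Bits(float x) {
  if (x == 0.0f) return 0.0f;
  int e;
  float m = std::frexp(x, &e);        // x = m * 2^e, |m| in [0.5, 1)
  m = std::round(m * 16.0f) / 16.0f;  // keep 4 significant bits
  return std::ldexp(m, e);
}

inline float QuantDequantE4M3(float x, float scale) {
  const float kE4M3FnMax = 448.0f;    // largest finite E4M3FN magnitude
  float q = std::clamp(x / scale, -kE4M3FnMax, kE4M3FnMax);
  return RoundSignificandTo4Bits(q) * scale;
}
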

TEST(NvExecutionProviderTest, FP4CustomOpModel) {
PathString model_name = ORT_TSTR("nv_execution_provider_fp4_dynamic_quantize_test.onnx");
clearFileIfExists(model_name);
std::string graph_name = "nv_execution_provider_fp4_dynamic_quantize_graph";

// Create a model with TRT_FP4DynamicQuantize node
CreateFP4CustomOpModel(model_name, graph_name);

// Verify the model file was created
ASSERT_TRUE(std::filesystem::exists(model_name));

// Create session and register execution provider explicitly
// This ensures custom ops are registered before model is loaded
SessionOptions so;
so.session_logid = "NvExecutionProviderTest.FP4CustomOpModel";
InferenceSession session_object{so, GetEnvironment()};

// Register TRTRTX EP - this will register custom ops
std::unique_ptr<IExecutionProvider> execution_provider = DefaultNvTensorRTRTXExecutionProvider();
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(execution_provider)));

// Load and initialize model
ASSERT_STATUS_OK(session_object.Load(model_name));
ASSERT_STATUS_OK(session_object.Initialize());

// Create input data (FP16, shape [64, 64])
std::vector<MLFloat16> input_data(64 * 64);
for (size_t i = 0; i < input_data.size(); ++i) {
input_data[i] = MLFloat16(static_cast<float>(i % 100) / 100.0f);
}

// Create input tensor
std::vector<int64_t> input_shape = {64, 64};
OrtValue input_tensor;
Tensor::InitOrtValue(DataTypeImpl::GetType<MLFloat16>(), TensorShape(input_shape),
input_data.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator),
input_tensor);

// Prepare feeds
NameMLValMap feeds;
feeds.insert(std::make_pair("X", input_tensor));

// Prepare outputs
std::vector<std::string> output_names;
output_names.push_back("X_dequantized");

// Run inference
std::vector<OrtValue> fetches;
RunOptions run_options;
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));

// Verify output tensor is valid
ASSERT_EQ(fetches.size(), 1u);
ASSERT_TRUE(fetches[0].IsTensor());

const auto& output_tensor = fetches[0].Get<Tensor>();
auto output_shape = output_tensor.Shape();
ASSERT_EQ(output_shape.NumDimensions(), 2u);
ASSERT_EQ(output_shape[0], 64);
ASSERT_EQ(output_shape[1], 64);

// Verify output is FLOAT16
ASSERT_TRUE(output_tensor.IsDataType<MLFloat16>());

LOGS_DEFAULT(INFO) << "[NvExecutionProviderTest] TRT FP4 dynamic quantize model run completed successfully";
}

#endif

} // namespace test