diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 62210d65848d1..e2a8005aba1da 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -85,6 +85,25 @@ struct ShutdownProtobuf { namespace onnxruntime { +// Helper function to check if a data type is supported by input output nodes ofNvTensorRTRTX EP +static bool IsSupportedInputOutputDataType(ONNXTensorElementDataType data_type) { + switch (data_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: // kFLOAT - 32-bit floating point + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: // kHALF - IEEE 16-bit floating-point + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // kBF16 - Brain float 16 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: // kBOOL - 8-bit boolean + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: // kINT4 - 4-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: // kINT8 - 8-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // kUINT8 - 8-bit unsigned integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: // kINT32 - 32-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: // kINT64 - 64-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: // kFP8 - 8-bit floating point + return true; + default: + return false; + } +} + // Helper function to check if a data type is supported by NvTensorRTRTX EP static bool IsSupportedDataType(ONNXTensorElementDataType data_type) { switch (data_type) { @@ -98,6 +117,7 @@ static bool IsSupportedDataType(ONNXTensorElementDataType data_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: // kINT32 - 32-bit signed integer case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: // kINT64 - 64-bit signed integer case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: // kFP8 - 8-bit floating point + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: // kDOUBLE - 64-bit floating point return true; default: return false; @@ -1939,6 +1959,28 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, #endif model_path_[sizeof(model_path_) - 1] = '\0'; + // Early return if the model has unsupported input/output data types + for (const auto* input : graph.GetInputs()) { + const auto* tp = input->TypeAsProto(); + if (tp && tp->has_tensor_type()) { + auto data_type = static_cast(tp->tensor_type().elem_type()); + if (!IsSupportedInputOutputDataType(data_type)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unsupported data type " << GetDataTypeName(data_type) << " for input node: " << input->Name(); + return result; + } + } + } + for (const auto* output : graph.GetOutputs()) { + const auto* tp = output->TypeAsProto(); + if (tp && tp->has_tensor_type()) { + auto data_type = static_cast(tp->tensor_type().elem_type()); + if (!IsSupportedInputOutputDataType(data_type)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unsupported data type " << GetDataTypeName(data_type) << " for output node: " << output->Name(); + return result; + } + } + } + const int number_of_ort_nodes = graph.NumberOfNodes(); const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); diff --git a/onnxruntime/test/perftest/main.cc b/onnxruntime/test/perftest/main.cc index aa0d7d90635e4..6806d851958ea 100644 --- a/onnxruntime/test/perftest/main.cc +++ b/onnxruntime/test/perftest/main.cc @@ -143,6 +143,27 @@ Ort::Status CompileEpContextModel(const Ort::Env& env, const perftest::Performan std::unordered_map provider_options; session_options.AppendExecutionProvider(provider_name, provider_options); + // free dim override + if (!test_config.run_config.free_dim_name_overrides.empty()) { + for (auto const& dim_override : test_config.run_config.free_dim_name_overrides) { + if (g_ort->AddFreeDimensionOverrideByName(session_options, ToUTF8String(dim_override.first).c_str(), dim_override.second) != nullptr) { + fprintf(stderr, "AddFreeDimensionOverrideByName failed for named dimension: %s\n", ToUTF8String(dim_override.first).c_str()); + } else { + fprintf(stdout, "Overriding dimension with name, %s, to %d\n", ToUTF8String(dim_override.first).c_str(), (int)dim_override.second); + } + } + } + + if (!test_config.run_config.free_dim_denotation_overrides.empty()) { + for (auto const& dim_override : test_config.run_config.free_dim_denotation_overrides) { + if (g_ort->AddFreeDimensionOverride(session_options, ToUTF8String(dim_override.first).c_str(), dim_override.second) != nullptr) { + fprintf(stderr, "AddFreeDimensionOverride failed for dimension denotation: %s\n", ToUTF8String(dim_override.first).c_str()); + } else { + fprintf(stdout, "Overriding dimension with denotation, %s, to %d\n", ToUTF8String(dim_override.first).c_str(), (int)dim_override.second); + } + } + } + Ort::ModelCompilationOptions model_compile_options(env, session_options); model_compile_options.SetEpContextEmbedMode(test_config.run_config.compile_binary_embed); model_compile_options.SetInputModelPath(test_config.model_info.model_file_path.c_str());