diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 15e7fad0d4a1a..5fd2f60a8dbfe 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3730,8 +3730,24 @@ Status Graph::ConvertInitializersIntoOrtValues() { auto& graph_proto = *graph.graph_proto_; for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) { auto& tensor_proto = *graph_proto.mutable_initializer(i); + if (utils::HasExternalData(tensor_proto)) { - continue; // ignore data on disk, that will be loaded either by EP or at session_state finalize + if (utils::HasExternalDataInMemory(tensor_proto)) { + // This can happen when the model is created with ModelEditor. + // We want to guard against malicious models with arbitrary in-memory references. + if (OrtValue v; GetOrtValueInitializer(tensor_proto.name(), v)) { + ORT_RETURN_IF_NOT(graph_utils::CheckInMemoryDataMatch(tensor_proto, v.Get()), + "In-memory data mismatch for initializer: ", tensor_proto.name(), + " this is an invalid model"); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "The model contains initializers with arbitrary in-memory references.", + "This is an invalid model."); + } + } + // ignore data on disk, that will be loaded either by EP or at session_state finalize + // ignore valid in-memory references + continue; } size_t size_in_bytes = 0; @@ -3830,6 +3846,10 @@ Status Graph::AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor_p const auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); ORT_RETURN_IF_NOT(proto_shape == tensor.Shape(), "Shape mismatch with ortvalue_initializer"); + const bool data_ptr_match = graph_utils::CheckInMemoryDataMatch(tensor_proto, tensor); + ORT_RETURN_IF_NOT(data_ptr_match, + "In-memory data pointer mismatch between tensor proto and ortvalue_initializer"); + ortvalue_initializers_.insert_or_assign(tensor_proto.name(), ortvalue_initializer); } else { ORT_ENFORCE(ortvalue_initializers_.count(tensor_proto.name()) == 0, diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index a39b9c89e5898..7230a22246318 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -4761,3 +4761,36 @@ TEST(CApiTest, custom_cast) { inputs, "output", expected_dims_y, expected_values_y, 0, custom_op_domain, nullptr); } + +TEST(CApiTest, ModelWithMaliciousExternalDataShouldFailToLoad) { + // Attempt to create an ORT session with the malicious model + // This should fail due to the invalid external data reference + constexpr const ORTCHAR_T* model_path = TSTR("testdata/test_evil_weights.onnx"); + + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); + Ort::SessionOptions session_options; + + bool exception_thrown = false; + std::string exception_message; + + try { + // This should throw an exception due to malicious external data + Ort::Session session(env, model_path, session_options); + } catch (const Ort::Exception& e) { + exception_thrown = true; + exception_message = e.what(); + } catch (const std::exception& e) { + exception_thrown = true; + exception_message = e.what(); + } + + // Verify that loading the model failed + EXPECT_TRUE(exception_thrown) << "Expected model loading to fail due to malicious external data"; + + // Verify that the exception message indicates security or external data issues + EXPECT_TRUE(exception_message.find("in-memory") != std::string::npos || + exception_message.find("references") != std::string::npos || + exception_message.find("invalid") != std::string::npos || + exception_message.find("model") != std::string::npos) + << "Exception message should indicate external data or security issue. Got: " << exception_message; +} diff --git a/onnxruntime/test/testdata/test_evil_weights.onnx b/onnxruntime/test/testdata/test_evil_weights.onnx new file mode 100644 index 0000000000000..7f538fc1df1aa Binary files /dev/null and b/onnxruntime/test/testdata/test_evil_weights.onnx differ diff --git a/onnxruntime/test/testdata/test_evil_weights.py b/onnxruntime/test/testdata/test_evil_weights.py new file mode 100644 index 0000000000000..61be312474612 --- /dev/null +++ b/onnxruntime/test/testdata/test_evil_weights.py @@ -0,0 +1,50 @@ +import onnx + + +def create_exp_model(): + inputs = [] + nodes = [] + tensors = [] + outputs = [] + + input_ = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.INT64, [None, None]) + inputs.append(input_) + + evil_tensor = onnx.helper.make_tensor( + name="evil_weights", + data_type=onnx.TensorProto.INT64, + # dims=[100, 100], + dims=[10], + vals=[], + ) + tensors.append(evil_tensor) + evil_tensor.data_location = onnx.TensorProto.EXTERNAL + entry1 = evil_tensor.external_data.add() + entry1.key = "location" + entry1.value = "*/_ORT_MEM_ADDR_/*" + entry2 = evil_tensor.external_data.add() + entry2.key = "offset" + entry2.value = "4194304" + entry2.value = "12230656" + entry3 = evil_tensor.external_data.add() + entry3.key = "length" + entry3.value = "80" + + tensors.append(onnx.helper.make_tensor(name="0x1", data_type=onnx.TensorProto.INT64, dims=[1], vals=[0x1])) + nodes.append(onnx.helper.make_node(op_type="Add", inputs=["evil_weights", "0x1"], outputs=["output"])) + + outputs.append(onnx.helper.make_tensor_value_info("output", onnx.TensorProto.INT64, [])) + + graph = onnx.helper.make_graph(nodes, "test", inputs, outputs, tensors) + model = onnx.helper.make_model( + graph, + opset_imports=[onnx.helper.make_opsetid("", 18), onnx.helper.make_opsetid("ai.onnx.ml", 3)], + ir_version=11, + ) + + return model + + +if __name__ == "__main__": + model = create_exp_model() + onnx.save(model, "test_evil_weights.onnx")