Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3730,8 +3730,24 @@ Status Graph::ConvertInitializersIntoOrtValues() {
auto& graph_proto = *graph.graph_proto_;
for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) {
auto& tensor_proto = *graph_proto.mutable_initializer(i);

if (utils::HasExternalData(tensor_proto)) {
continue; // ignore data on disk, that will be loaded either by EP or at session_state finalize
if (utils::HasExternalDataInMemory(tensor_proto)) {
// This can happen when the model is created with ModelEditor.
// We want to guard against malicious models with arbitrary in-memory references.
if (OrtValue v; GetOrtValueInitializer(tensor_proto.name(), v)) {
ORT_RETURN_IF_NOT(graph_utils::CheckInMemoryDataMatch(tensor_proto, v.Get<Tensor>()),
"In-memory data mismatch for initializer: ", tensor_proto.name(),
" this is an invalid model");
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"The model contains initializers with arbitrary in-memory references.",
"This is an invalid model.");
}
}
// ignore data on disk, that will be loaded either by EP or at session_state finalize
// ignore valid in-memory references
continue;
}

size_t size_in_bytes = 0;
Expand Down Expand Up @@ -3830,6 +3846,10 @@ Status Graph::AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor_p
const auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto);
ORT_RETURN_IF_NOT(proto_shape == tensor.Shape(), "Shape mismatch with ortvalue_initializer");

const bool data_ptr_match = graph_utils::CheckInMemoryDataMatch(tensor_proto, tensor);
ORT_RETURN_IF_NOT(data_ptr_match,
"In-memory data pointer mismatch between tensor proto and ortvalue_initializer");

ortvalue_initializers_.insert_or_assign(tensor_proto.name(), ortvalue_initializer);
} else {
ORT_ENFORCE(ortvalue_initializers_.count(tensor_proto.name()) == 0,
Expand Down
33 changes: 33 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4761,3 +4761,36 @@ TEST(CApiTest, custom_cast) {
inputs, "output", expected_dims_y, expected_values_y, 0,
custom_op_domain, nullptr);
}

TEST(CApiTest, ModelWithMaliciousExternalDataShouldFailToLoad) {
// Attempt to create an ORT session with the malicious model
// This should fail due to the invalid external data reference
constexpr const ORTCHAR_T* model_path = TSTR("testdata/test_evil_weights.onnx");

Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
Ort::SessionOptions session_options;

bool exception_thrown = false;
std::string exception_message;

try {
// This should throw an exception due to malicious external data
Ort::Session session(env, model_path, session_options);
} catch (const Ort::Exception& e) {
exception_thrown = true;
exception_message = e.what();
} catch (const std::exception& e) {
exception_thrown = true;
exception_message = e.what();
}

// Verify that loading the model failed
EXPECT_TRUE(exception_thrown) << "Expected model loading to fail due to malicious external data";

// Verify that the exception message indicates security or external data issues
EXPECT_TRUE(exception_message.find("in-memory") != std::string::npos ||
exception_message.find("references") != std::string::npos ||
exception_message.find("invalid") != std::string::npos ||
exception_message.find("model") != std::string::npos)
<< "Exception message should indicate external data or security issue. Got: " << exception_message;
}
Binary file added onnxruntime/test/testdata/test_evil_weights.onnx
Binary file not shown.
49 changes: 49 additions & 0 deletions onnxruntime/test/testdata/test_evil_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import onnx


def create_exp_model():
inputs = []
nodes = []
tensors = []
outputs = []

input_ = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.INT64, [None, None])
inputs.append(input_)

evil_tensor = onnx.helper.make_tensor(
name="evil_weights",
data_type=onnx.TensorProto.INT64,
# dims=[100, 100],
dims=[10],
vals=[],
)
tensors.append(evil_tensor)
evil_tensor.data_location = onnx.TensorProto.EXTERNAL
entry1 = evil_tensor.external_data.add()
entry1.key = "location"
entry1.value = "*/_ORT_MEM_ADDR_/*"
entry2 = evil_tensor.external_data.add()
entry2.key = "offset"
entry2.value = "12230656"
entry3 = evil_tensor.external_data.add()
entry3.key = "length"
entry3.value = "80"

tensors.append(onnx.helper.make_tensor(name="0x1", data_type=onnx.TensorProto.INT64, dims=[1], vals=[0x1]))
nodes.append(onnx.helper.make_node(op_type="Add", inputs=["evil_weights", "0x1"], outputs=["output"]))

outputs.append(onnx.helper.make_tensor_value_info("output", onnx.TensorProto.INT64, [10]))

graph = onnx.helper.make_graph(nodes, "test", inputs, outputs, tensors)
model = onnx.helper.make_model(
graph,
opset_imports=[onnx.helper.make_opsetid("", 18), onnx.helper.make_opsetid("ai.onnx.ml", 3)],
ir_version=11,
)

return model


if __name__ == "__main__":
model = create_exp_model()
onnx.save(model, "test_evil_weights.onnx")
Loading