Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class ONNX_FRONTEND_API GraphIterator : ::ov::RuntimeAttribute {
/// If there are no domain found returns -1
virtual int64_t get_opset_version(const std::string& domain) const = 0;

/// \brief Retrieves metadata associated with the graph.
virtual std::map<std::string, std::string> get_metadata() const = 0;

/// \brief Destructor
virtual ~GraphIterator();
};
Expand Down
14 changes: 14 additions & 0 deletions src/frontends/onnx/frontend/src/core/graph_iterator_proto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,20 @@ std::int64_t GraphIteratorProto::get_opset_version(const std::string& domain) co
return -1;
}

std::map<std::string, std::string> GraphIteratorProto::get_metadata() const {
std::map<std::string, std::string> metadata;

if (!m_model) {
return metadata;
}

const auto& model_metadata = m_model->metadata_props();
for (const auto& prop : model_metadata) {
metadata.emplace(prop.key(), prop.value());
}
return metadata;
}

namespace detail {
namespace {
enum Field {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ class GraphIteratorProto : public ov::frontend::onnx::GraphIterator {

std::int64_t get_opset_version(const std::string& domain) const override;

std::map<std::string, std::string> get_metadata() const override;

std::string get_model_dir() const {
return *m_model_dir;
}
Expand Down
11 changes: 11 additions & 0 deletions src/frontends/onnx/frontend/src/input_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,10 @@ class InputModel::InputModelONNXImpl {
void extract_subgraph(const std::vector<ov::frontend::Place::Ptr>& inputs,
const std::vector<ov::frontend::Place::Ptr>& outputs);

std::map<std::string, std::string> get_metadata() const {
return m_metadata;
}

std::shared_ptr<TelemetryExtension> get_telemetry_extension() const {
return m_telemetry;
}
Expand Down Expand Up @@ -633,6 +637,7 @@ class InputModel::InputModelONNXImpl {
std::shared_ptr<GraphIterator> m_graph_iterator;
const ov::frontend::InputModel& m_input_model;
std::vector<std::shared_ptr<ov::frontend::onnx::unify::InputModel>> m_subgraphs;
std::map<std::string, std::string> m_metadata;
std::shared_ptr<TelemetryExtension> m_telemetry;
bool m_enable_mmap;

Expand Down Expand Up @@ -753,6 +758,8 @@ void InputModel::InputModelONNXImpl::load_model() {
m_telemetry->send_event("op_count", "onnx_" + op.first, static_cast<int>(op.second));
}
}

m_metadata = m_graph_iterator->get_metadata();
}

InputModel::InputModelONNXImpl::InputModelONNXImpl(const GraphIterator::Ptr& graph_iterator,
Expand Down Expand Up @@ -947,6 +954,10 @@ void InputModel::extract_subgraph(const std::vector<ov::frontend::Place::Ptr>& i
_impl->extract_subgraph(inputs, outputs);
}

std::map<std::string, std::string> InputModel::get_metadata() const {
return _impl->get_metadata();
}

std::shared_ptr<TelemetryExtension> InputModel::get_telemetry_extension() {
return _impl->get_telemetry_extension();
}
Expand Down
1 change: 1 addition & 0 deletions src/frontends/onnx/frontend/src/input_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class InputModel : public ov::frontend::InputModel {
void extract_subgraph(const std::vector<ov::frontend::Place::Ptr>& inputs,
const std::vector<ov::frontend::Place::Ptr>& outputs) override;

std::map<std::string, std::string> get_metadata() const;
std::shared_ptr<TelemetryExtension> get_telemetry_extension();

bool is_enabled_mmap() const;
Expand Down
6 changes: 6 additions & 0 deletions src/frontends/onnx/frontend/src/translate_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,10 @@ void TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& inpu

auto model_name = "onnx_Frontend_IR";
ov_model = std::make_shared<ov::Model>(results, m_parameters, model_name);

const auto& metadata = model_onnx->get_metadata();
const std::string framework_section = "framework";
for (const auto& pair : metadata) {
ov_model->set_rt_info(pair.second, framework_section, pair.first);
}
}
4 changes: 4 additions & 0 deletions src/frontends/onnx/tests/graph_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ TEST_P(FrontEndLoadFromTest, testLoadUsingSimpleGraphIterator) {
return 1;
}

std::map<std::string, std::string> get_metadata() const override {
return {};
}

~SimpleIterator() override {};
};

Expand Down
Loading