Skip to content

Commit 6275116

Browse files
authored
Merge branch 'master' into py3.8
2 parents 633e9f2 + 5182cb4 commit 6275116

File tree

10 files changed

+309
-5
lines changed

10 files changed

+309
-5
lines changed

.github/workflows/cpp_gapi-demos.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
rm -rf cache/opencv/.git/ # Minimize cache
4141
mkdir cache/opencv/build
4242
cd cache/opencv/build
43-
cmake -DCMAKE_BUILD_TYPE=Release -DWITH_INF_ENGINE=y -DOpenVINO_DIR=$GITHUB_WORKSPACE/ov/runtime/cmake/ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_LINKER_LAUNCHER=ccache -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_C_LINKER_LAUNCHER=ccache -DBUILD_TESTS=y -DVIDEOIO_ENABLE_PLUGINS=y -DBUILD_PERF_TESTS=n -DBUILD_EXAMPLES=n -DBUILD_opencv_apps=y -DWITH_OPENCL=n -DWITH_OPENCLAMDBLAS=n -DWITH_GSTREAMER=n -DWITH_V4L=ON -DWITH_LIBV4L=ON -DWITH_OPENCLAMDFFT=n -DWITH_VA=n -DWITH_VA_INTEL=n -DWITH_PROTOBUF=n -DBUILD_PROTOBUF=n -DBUILD_JAVA=n -DBUILD_opencv_java_bindings_generator=n -DBUILD_opencv_python2=n -DBUILD_opencv_python3=n -DWITH_IMGCODEC_HDR=y -DWITH_IMGCODEC_SUNRASTER=y -DWITH_IMGCODEC_PXM=y -DWITH_IMGCODEC_PFM=y -DWITH_PNG=y -DWITH_TIFF=n -DWITH_WEBP=n -DWITH_OPENJPEG=n -DWITH_JASPER=n -DWITH_OPENEXR=n -DBUILD_opencv_dnn=n -DBUILD_opencv_features2d=n -DBUILD_opencv_flann=n -DWITH_TBB=n -DBUILD_INFO_SKIP_EXTRA_MODULES=n -DBUILD_JASPER=n -DBUILD_PNG=n -DBUILD_OPENEXR=n -DBUILD_WEBP=n -DBUILD_ZLIB=n -DWITH_CUDA=n -DWITH_EIGEN=n -DWITH_GPHOTO2=n -DOPENCV_GAPI_GSTREAMER=n -DWITH_LAPACK=n -DWITH_MATLAB=n -DWITH_MFX=n -DWITH_QUIRC=n -DWITH_VTK=n -DINSTALL_PDB=n -DINSTALL_TESTS=n -DINSTALL_C_EXAMPLES=n -DINSTALL_PYTHON_EXAMPLES=n -DOPENCV_GENERATE_SETUPVARS=n -DWITH_1394=n -DWITH_FFMPEG=y -DWITH_GTK_2_X=y -DBUILD_JPEG=y -DWITH_IPP=y -DENABLE_CONFIG_VERIFICATION=y -DBUILD_LIST=core,gapi,highgui,imgcodecs,imgproc,videoio,video ..
43+
cmake -DCMAKE_BUILD_TYPE=Release -DWITH_INF_ENGINE=y -DOpenVINO_DIR=$GITHUB_WORKSPACE/ov/runtime/cmake/ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_LINKER_LAUNCHER=ccache -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_C_LINKER_LAUNCHER=ccache -DBUILD_TESTS=y -DVIDEOIO_ENABLE_PLUGINS=y -DBUILD_PERF_TESTS=n -DBUILD_EXAMPLES=n -DBUILD_opencv_apps=y -DWITH_OPENCL=n -DWITH_OPENCLAMDBLAS=n -DWITH_GSTREAMER=n -DWITH_V4L=ON -DWITH_LIBV4L=ON -DWITH_OPENCLAMDFFT=n -DWITH_VA=n -DWITH_VA_INTEL=n -DWITH_PROTOBUF=n -DBUILD_PROTOBUF=n -DBUILD_JAVA=n -DBUILD_opencv_java_bindings_generator=n -DBUILD_opencv_python2=n -DBUILD_opencv_python3=n -DWITH_IMGCODEC_HDR=y -DWITH_IMGCODEC_SUNRASTER=y -DWITH_IMGCODEC_PXM=y -DWITH_IMGCODEC_PFM=y -DWITH_PNG=y -DWITH_TIFF=n -DWITH_WEBP=n -DWITH_OPENJPEG=n -DWITH_JASPER=n -DWITH_OPENEXR=n -DBUILD_opencv_dnn=n -DBUILD_opencv_features2d=n -DBUILD_opencv_flann=n -DWITH_TBB=n -DBUILD_INFO_SKIP_EXTRA_MODULES=n -DBUILD_JASPER=n -DBUILD_PNG=n -DBUILD_OPENEXR=n -DBUILD_WEBP=n -DBUILD_ZLIB=n -DWITH_CUDA=n -DWITH_EIGEN=n -DWITH_GPHOTO2=n -DOPENCV_GAPI_GSTREAMER=n -DWITH_LAPACK=n -DWITH_MATLAB=n -DWITH_MFX=n -DWITH_QUIRC=n -DWITH_VTK=n -DINSTALL_PDB=n -DINSTALL_TESTS=n -DINSTALL_C_EXAMPLES=n -DINSTALL_PYTHON_EXAMPLES=n -DOPENCV_GENERATE_SETUPVARS=n -DWITH_1394=n -DWITH_FFMPEG=y -DWITH_GTK_2_X=y -DBUILD_JPEG=y -DWITH_IPP=y -DWITH_AVIF=n -DENABLE_CONFIG_VERIFICATION=y -DBUILD_LIST=core,gapi,highgui,imgcodecs,imgproc,videoio,video ..
4444
cmake --build . -j $((`nproc`*2+2))
4545
- name: build_demos.sh
4646
run: |

Jenkinsfile

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!groovy

// Jenkins job definition: declares the build parameters of this job and then
// delegates the actual pipeline to code that is not defined in this file.
properties([
    parameters([
        // failFast (default: false): stop the remaining parallel stages as soon as one fails.
        booleanParam(defaultValue: false,
                     description: 'Cancel the rest of parallel stages if one of them fails and return status immediately',
                     name: 'failFast'),
        // propagateStatus (default: true): report the commit status back to GitHub.
        booleanParam(defaultValue: true,
                     description: 'Whether to propagate commit status to GitHub',
                     name: 'propagateStatus'),
        // library_version (default: empty): shared library branch/tag/commit; empty means auto-detect.
        string(defaultValue: '',
               description: 'Pipeline shared library version (branch/tag/commit). Determined automatically if empty',
               name: 'library_version'),
        // docker_tag (default: empty): tag of the Docker images to use; empty means auto-detect.
        string(defaultValue: '',
               description: 'Docker tag to take images with. Determined automatically if empty',
               name: 'docker_tag')
    ])
])

// NOTE(review): loadOpenVinoLibrary and entrypoint are not defined here; presumably
// they come from a Jenkins shared library configured on the controller - confirm.
loadOpenVinoLibrary {
    entrypoint(this)
}

demos/common/cpp/utils/src/config_factory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ ModelConfig ConfigFactory::getUserConfig(const std::string& flags_d,
4848
if (flags_nthreads != 0)
4949
config.compiledModelConfig.emplace(ov::inference_num_threads.name(), flags_nthreads);
5050

51-
config.compiledModelConfig.emplace(ov::affinity.name(), ov::Affinity::NONE);
51+
config.compiledModelConfig.emplace(ov::hint::enable_cpu_pinning.name(), false);
5252

5353
ov::streams::Num nstreams =
5454
deviceNstreams.count(device) > 0 ? ov::streams::Num(deviceNstreams[device]) : ov::streams::AUTO;

demos/multi_channel_common/cpp/graph.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ static inline size_t roundUp(size_t enumerator, size_t denominator) {
2929

3030
static inline std::queue<ov::InferRequest> compile(std::shared_ptr<ov::Model>&& model, const std::string& modelPath,
3131
const std::string& device, size_t performanceHintNumRequests, ov::Core& core) {
32-
core.set_property("CPU", ov::affinity(ov::Affinity::NONE));
32+
core.set_property("CPU", ov::hint::enable_cpu_pinning(false));
3333
ov::CompiledModel compiled = core.compile_model(model, device, {
3434
{ov::hint::performance_mode(ov::hint::PerformanceMode::THROUGHPUT)},
3535
{ov::hint::num_requests(performanceHintNumRequests)}});

demos/security_barrier_camera_demo/cpp/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ int main(int argc, char* argv[]) {
718718
if (FLAGS_nthreads != 0) {
719719
core.set_property("CPU", ov::inference_num_threads(FLAGS_nthreads));
720720
}
721-
core.set_property("CPU", ov::affinity(ov::Affinity::NONE));
721+
core.set_property("CPU", ov::hint::enable_cpu_pinning(false));
722722
core.set_property("CPU", ov::streams::num((device_nstreams.count("CPU") > 0 ? ov::streams::Num(device_nstreams["CPU"]) : ov::streams::AUTO)));
723723

724724
device_nstreams["CPU"] = core.get_property("CPU", ov::streams::num);

demos/social_distance_demo/cpp/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ int main(int argc, char* argv[]) {
695695
if (FLAGS_nthreads != 0) {
696696
core.set_property("CPU", ov::inference_num_threads(FLAGS_nthreads));
697697
}
698-
core.set_property("CPU", ov::affinity(ov::Affinity::NONE));
698+
core.set_property("CPU", ov::hint::enable_cpu_pinning(false));
699699
core.set_property("CPU", ov::streams::num((deviceNStreams.count("CPU") > 0 ? ov::streams::Num(deviceNStreams["CPU"]) : ov::streams::AUTO)));
700700
deviceNStreams["CPU"] = core.get_property("CPU", ov::streams::num);
701701
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""
2+
Copyright (c) 2024 Intel Corporation
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
import re
17+
18+
from ...representation import CharacterRecognitionPrediction
19+
from ...utils import UnsupportedPackage, extract_image_representations
20+
from .base_custom_evaluator import BaseCustomEvaluator
21+
22+
try:
23+
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
24+
except ImportError as import_err:
25+
AutoModelForSpeechSeq2Seq = UnsupportedPackage("transformers", import_err.msg)
26+
AutoProcessor = UnsupportedPackage("transformers", import_err.msg)
27+
28+
try:
29+
from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
30+
except ImportError as import_err:
31+
AutomaticSpeechRecognitionPipeline = UnsupportedPackage("transformers", import_err.msg)
32+
33+
try:
34+
import inflect
35+
except ImportError as import_err:
36+
inflect = UnsupportedPackage("inflect", import_err.msg)
37+
38+
39+
class WhisperEvaluator(BaseCustomEvaluator):
    """Custom evaluator that drives one of the Whisper pipeline wrappers over a dataset.

    The pipeline object (``pipe``) must expose ``predict(identifiers, inputs, meta)``
    returning a ``(raw_predictions, predictions)`` pair.
    """

    # Names of the pipeline classes that ``from_configs`` is allowed to instantiate.
    VALID_PIPELINE_CLASSES = [
        "GenAIWhisperPipeline",
        "HFWhisperPipeline",
        "OptimumWhisperPipeline"
    ]

    def __init__(self, dataset_config, pipe, orig_config):
        """Store the pipeline and forward the configs to the base evaluator.

        Args:
            dataset_config: dataset section of the evaluation config.
            pipe: pipeline wrapper used to produce predictions.
            orig_config: original (unprocessed) configuration, passed through to the base class.
        """
        # The second positional argument of the base initializer is deliberately None here.
        super().__init__(dataset_config, None, orig_config)
        self.pipe = pipe
        if hasattr(self.pipe, "adapter"):
            self.adapter_type = self.pipe.adapter.__provider__

    @classmethod
    def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
        """Build an evaluator from a configuration dictionary.

        Args:
            config: dictionary with at least ``datasets``, ``pipeline_class`` and ``launchers``.
            delayed_model_loading: accepted for interface compatibility; not used here.
            orig_config: original configuration forwarded to the constructor.

        Raises:
            ValueError: if ``pipeline_class`` is not listed in ``VALID_PIPELINE_CLASSES``.
        """
        dataset_config = config["datasets"]
        pipeline_class_name = config["pipeline_class"]
        # Expose the first launcher's device to the pipeline wrappers under "_device".
        if 'device' in config['launchers'][0]:
            config["_device"] = config['launchers'][0]['device']

        if pipeline_class_name not in cls.VALID_PIPELINE_CLASSES:
            raise ValueError(f"Invalid pipeline class name: {pipeline_class_name}. "
                             f"Must be one of {cls.VALID_PIPELINE_CLASSES}")

        # The name was validated against the allow-list above, so the module-level lookup is safe.
        pipeline_class = globals()[pipeline_class_name]
        pipe = pipeline_class(config)
        return cls(dataset_config, pipe, orig_config)

    def _process(self, output_callback, calculate_metrics, progress_reporter, metric_config, csv_file):
        """Run the pipeline over every batch of ``self.dataset`` and collect metric results."""
        for batch_id, (batch_input_ids, batch_annotation, batch_inputs, batch_identifiers) in enumerate(self.dataset):
            batch_inputs = self.preprocessor.process(batch_inputs, batch_annotation)
            batch_inputs_extr, batch_meta = extract_image_representations(batch_inputs)

            batch_raw_prediction, batch_prediction = self.pipe.predict(
                batch_identifiers, batch_inputs_extr, batch_meta
            )
            metrics_result = self._get_metrics_result(batch_input_ids, batch_annotation, batch_prediction,
                                                      calculate_metrics)
            if output_callback:
                # A pipeline may return no raw predictions (an empty list); pass None
                # to the callback in that case instead of failing with IndexError.
                raw_prediction = batch_raw_prediction[0] if batch_raw_prediction else None
                output_callback(raw_prediction, metrics_result=metrics_result,
                                element_identifiers=batch_identifiers, dataset_indices=batch_input_ids)
            self._update_progress(progress_reporter, metric_config, batch_id, len(batch_prediction), csv_file)

    def release(self):
        """Nothing to release for this evaluator."""
        pass
84+
85+
86+
def normalize_transcription(engine, text):
    """Normalize a transcription for text comparison.

    Args:
        engine: object providing ``number_to_words(token)`` (e.g. an ``inflect`` engine).
        text: raw transcription string.

    Returns:
        Upper-cased text in which all-digit tokens are spelled out, punctuation
        is removed and whitespace is collapsed to single spaces.
    """
    # Spell out every whitespace-separated token that consists only of digits.
    words = []
    for word in text.split():
        words.append(engine.number_to_words(word) if word.isdigit() else word)

    # Drop every character that is neither a word character nor whitespace.
    # NOTE(review): the first alternative also deletes apostrophes inside words
    # ("don't" -> "dont"), and hyphens produced by number spelling are removed too.
    without_punctuation = re.sub(r"\b'\b|[^\w\s]", "", " ".join(words))

    # Upper-case, then collapse leading, trailing and repeated whitespace.
    return " ".join(without_punctuation.upper().split())
93+
94+
95+
class WhisperPipeline:
    """Base class of the Whisper pipeline wrappers.

    Subclasses implement ``_initialize_pipeline`` (build the backend pipeline)
    and ``_get_predictions`` (transcribe one input sample).
    """

    def __init__(self, config):
        """Create the number-to-words engine and the backend pipeline from ``config``."""
        self.engine = inflect.engine()
        self.pipeline = self._initialize_pipeline(config)

    def _initialize_pipeline(self, config):
        """Build and return the backend pipeline; must be overridden."""
        raise NotImplementedError

    def _get_predictions(self, data, identifiers, input_meta):
        """Return the transcription string for one sample; must be overridden."""
        raise NotImplementedError

    def predict(self, identifiers, input_data, input_meta, encoder_callback=None):
        """Transcribe every sample of a batch.

        Args:
            identifiers: per-sample identifiers, parallel to ``input_data``.
            input_data: per-sample inputs.
            input_meta: per-sample metadata, parallel to ``input_data``.
            encoder_callback: accepted for interface compatibility; not used.

        Returns:
            ``([], predictions)`` where ``predictions`` holds one
            ``CharacterRecognitionPrediction`` per input sample. No raw
            predictions are produced, so the first element is always empty.
        """
        outputs = []
        for index, data in enumerate(input_data):
            # Hand the hook one-element views so that implementations reading
            # element 0 of identifiers/meta see the data of the current sample.
            transcription = self._get_predictions(
                data, identifiers[index:index + 1], input_meta[index:index + 1]
            )
            prediction_text = normalize_transcription(self.engine, transcription)
            # Pair each transcription with its own identifier rather than always the first one.
            outputs.append(CharacterRecognitionPrediction(identifiers[index], prediction_text))
        return [], outputs
115+
116+
117+
class GenAIWhisperPipeline(WhisperPipeline):
    """Whisper wrapper backed by ``openvino_genai.WhisperPipeline``."""

    def _initialize_pipeline(self, config):
        """Create the GenAI pipeline from ``config["_models"][0]`` on ``config["_device"]`` (default "CPU")."""
        try:
            import openvino_genai as ov_genai  # pylint: disable=C0415
        except ImportError as import_error:
            # Report the missing optional dependency on behalf of this class.
            UnsupportedPackage("openvino_genai", import_error.msg).raise_error(self.__class__.__name__)

        target_device = config.get("_device", "CPU")
        model_paths = config.get("_models", [None])
        return ov_genai.WhisperPipeline(str(model_paths[0]), device=target_device)

    def _get_predictions(self, data, identifiers, input_meta):
        """Return the first generated text for the first element of ``data``."""
        generation = self.pipeline.generate(data[0], return_timestamps=True)
        return generation.texts[0]
131+
132+
133+
class HFWhisperPipeline(WhisperPipeline):
    """Whisper wrapper backed by the Hugging Face ``transformers`` speech-recognition pipeline."""

    def _initialize_pipeline(self, config):
        """Load the model named by ``config["model_id"]`` and wrap it in an ASR pipeline on the CPU."""
        try:
            import torch  # pylint: disable=C0415
        except ImportError as import_error:
            # Report the missing optional dependency on behalf of this class.
            UnsupportedPackage("torch", import_error.msg).raise_error(self.__class__.__name__)

        model_id = config.get("model_id")
        device = "cpu"
        # The model is pinned to the CPU, so the dtype must not depend on CUDA
        # availability: half precision on the CPU is slow or unsupported for
        # some operations, so always use float32 here.
        torch_dtype = torch.float32
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
        ).to(device)

        processor = AutoProcessor.from_pretrained(model_id)

        pipeline = AutomaticSpeechRecognitionPipeline(
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
        )
        return pipeline

    def _get_predictions(self, data, identifiers, input_meta):
        """Transcribe ``data[0]`` using the sample rate stored in ``input_meta[0]``."""
        sampling_rate = input_meta[0].get("sample_rate")
        sample = {"path": identifiers[0], "array": data[0], "sampling_rate": sampling_rate}
        return self.pipeline(sample, return_timestamps=True)["text"]
162+
163+
164+
class OptimumWhisperPipeline(WhisperPipeline):
    """Whisper wrapper backed by ``optimum.intel`` OpenVINO models."""

    def _initialize_pipeline(self, config):
        """Load the model and processor from ``config["_models"][0]`` and build the ASR pipeline."""
        try:
            from optimum.intel.openvino import OVModelForSpeechSeq2Seq  # pylint: disable=C0415
        except ImportError as import_error:
            # Report the missing optional dependency on behalf of this class.
            UnsupportedPackage("optimum.intel.openvino", import_error.msg).raise_error(self.__class__.__name__)

        target_device = config.get("_device", "CPU")
        model_path = str(config.get("_models", [None])[0])

        speech_model = OVModelForSpeechSeq2Seq.from_pretrained(model_path).to(target_device)
        speech_processor = AutoProcessor.from_pretrained(model_path)

        return AutomaticSpeechRecognitionPipeline(
            model=speech_model,
            tokenizer=speech_processor.tokenizer,
            feature_extractor=speech_processor.feature_extractor
        )

    def _get_predictions(self, data, identifiers, input_meta):
        """Transcribe ``data[0]`` using the sample rate stored in ``input_meta[0]``."""
        request = {
            "path": identifiers[0],
            "array": data[0],
            "sampling_rate": input_meta[0].get("sample_rate"),
        }
        result = self.pipeline(request, return_timestamps=True)
        return result["text"]

tools/accuracy_checker/requirements-extra.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,6 @@ lmdb>=1.2.1
4848

4949
# pandas datasets support
5050
pandas>=1.1.5,<2.1
51+
52+
# word-based representations of numbers
53+
inflect>=7.4.0

tools/accuracy_checker/requirements-test.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ pytest-mock~=2.0
77
# will not include atomicwrites and thus will not work on Windows.
88
# So as a workaround, make the atomicwrites dependency unconditional.
99
atomicwrites
10+
datasets
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""
2+
Copyright (c) 2024-2025 Intel Corporation
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
import os
17+
from pathlib import Path
18+
from unittest.mock import MagicMock, patch
19+
20+
import pytest
21+
from accuracy_checker.evaluators.custom_evaluators.whisper_evaluator import (
22+
GenAIWhisperPipeline, HFWhisperPipeline, OptimumWhisperPipeline,
23+
WhisperEvaluator)
24+
from datasets import load_dataset
25+
26+
# Optional dependencies: the whole module is skipped when any of them is missing.
_transformers = pytest.importorskip("transformers", reason="transformers is not available")
AutoProcessor = _transformers.AutoProcessor
AutoTokenizer = _transformers.AutoTokenizer
export_tokenizer = pytest.importorskip(
    "optimum.exporters.openvino.convert",
    reason="optimum.exporters.openvino.convert is not available",
).export_tokenizer
OVModelForSpeechSeq2Seq = pytest.importorskip(
    "optimum.intel.openvino",
    reason="optimum.intel.openvino is not available",
).OVModelForSpeechSeq2Seq


# Model under test and the directory that receives its converted files.
model_id = "openai/whisper-tiny"
model_dir = Path("/tmp/whisper-tiny")
34+
35+
def setup_module(module):
    """Fetch one LibriSpeech validation sample and publish it through module globals."""
    global input_data, input_meta, identifiers

    # Stream the dataset so that only the first sample has to be fetched.
    stream = load_dataset("openslr/librispeech_asr", "clean", split="validation", streaming=True, trust_remote_code=True)
    first_sample = next(iter(stream))
    audio = first_sample["audio"]

    identifiers = [first_sample["id"]]
    input_meta = [{"sample_rate": audio["sampling_rate"]}]
    input_data = [audio["array"]]
45+
46+
def teardown_module(module):
    """Delete the files stored directly in ``model_dir`` and then the directory itself."""
    if not model_dir.exists():
        return
    for entry in model_dir.iterdir():
        if entry.is_file():
            entry.unlink()
    # rmdir only succeeds on an empty directory, i.e. when no sub-directories were created.
    model_dir.rmdir()
53+
54+
# Register this test with pytest-dependency: tests that declare
# depends=["test_optimum_convert_model_to_ir"] are skipped unless the
# dependency itself carries the marker and has passed.
@pytest.mark.dependency()
def test_optimum_convert_model_to_ir():
    """Export the model, tokenizer and processor for ``model_id`` into ``model_dir``."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    base_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id)

    model_dir.mkdir(parents=True, exist_ok=True)
    base_model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
    processor.save_pretrained(model_dir)
    export_tokenizer(tokenizer, model_dir)

    assert base_model.__class__.__module__.startswith('optimum.intel.openvino')
66+
67+
class TestWhisperEvaluator:
    """Smoke tests: every pipeline wrapper must return a string transcription."""

    @staticmethod
    def _transcribe(pipeline):
        """Wrap ``pipeline`` in an evaluator and transcribe the shared module-level sample."""
        evaluator = WhisperEvaluator(None, pipeline, None)
        return evaluator.pipe._get_predictions(input_data, identifiers, input_meta)

    def test_hf_whisper_pipeline(self):
        transcription = self._transcribe(HFWhisperPipeline({"model_id": model_id}))
        assert isinstance(transcription, str)

    @pytest.mark.dependency(depends=["test_optimum_convert_model_to_ir"])
    def test_genai_whisper_pipeline(self):
        settings = {"_models": [model_dir], "_device": "CPU"}
        transcription = self._transcribe(GenAIWhisperPipeline(settings))
        assert isinstance(transcription, str)

    @pytest.mark.dependency(depends=["test_optimum_convert_model_to_ir"])
    def test_optimum_whisper_pipeline(self):
        settings = {"_models": [model_dir], "_device": "CPU"}
        transcription = self._transcribe(OptimumWhisperPipeline(settings))
        assert isinstance(transcription, str)

0 commit comments

Comments
 (0)