Skip to content

Commit a67c9f3

Browse files
authored
Create test_whisper_evaluator.py
1 parent 74a9882 commit a67c9f3

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""
2+
Copyright (c) 2024 Intel Corporation
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
from pathlib import Path
17+
from unittest.mock import MagicMock
18+
19+
import pytest
20+
from accuracy_checker.evaluators.custom_evaluators.whisper_evaluator import (
21+
GenAIWhisperPipeline, OptimumWhisperPipeline, HFWhisperPipeline,
22+
WhisperEvaluator, normalize_transcription)
23+
from datasets import load_dataset
24+
from optimum.exporters.openvino.convert import export_tokenizer
25+
from optimum.intel.openvino import OVModelForSpeechSeq2Seq
26+
from transformers import AutoTokenizer,AutoProcessor
27+
28+
29+
def export_model(model_id, output_dir):
    """Load ``model_id`` and write the model and its companions to ``output_dir``.

    The tokenizer, the processor and an ``OVModelForSpeechSeq2Seq`` instance
    are each loaded with ``from_pretrained(model_id)`` and then stored with
    ``save_pretrained(output_dir)``. Finally ``export_tokenizer`` is called
    with the tokenizer and the same output directory.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id)

    # Save order is kept as model -> tokenizer -> processor.
    for artifact in (ov_model, tokenizer, processor):
        artifact.save_pretrained(output_dir)

    export_tokenizer(tokenizer, output_dir)
38+
39+
# Model identifier under test and the directory that receives its export.
model_name = "openai/whisper-tiny"
model_dir = Path("/tmp/whisper-tiny")

# NOTE: this runs when the module is imported, so the export happens once
# and the resulting directory is shared by the tests in this file.
export_model(model_name, model_dir)

# Stream the validation split and keep only its first sample.
dataset = load_dataset(
    "openslr/librispeech_asr",
    "clean",
    split="validation",
    streaming=True,
    trust_remote_code=True,
)
sample = next(iter(dataset))
_audio = sample["audio"]

ground_truth = sample["text"]
# Each input is wrapped in a one-element list before being passed on.
input_data = [_audio["array"]]
input_meta = [{"sample_rate": _audio["sampling_rate"]}]
identifiers = [sample["id"]]
53+
54+
class TestWhisperEvaluator:
    """Smoke tests: each pipeline flavour must return a ``str`` prediction.

    Every test builds one pipeline, hands it to ``WhisperEvaluator`` and
    calls ``_get_predictions`` on the evaluator's ``pipe`` with the
    module-level sample (``input_data``, ``identifiers``, ``input_meta``).
    """

    @staticmethod
    def _predict(pipeline):
        """Wrap ``pipeline`` in a ``WhisperEvaluator`` and return its prediction.

        The first and third constructor arguments are passed as ``None``.
        """
        evaluator = WhisperEvaluator(None, pipeline, None)
        return evaluator.pipe._get_predictions(input_data, identifiers, input_meta)

    def test_hf_whisper_pipeline(self):
        """``HFWhisperPipeline`` configured by ``model_id`` yields a string."""
        transcription = self._predict(HFWhisperPipeline({"model_id": model_name}))
        assert isinstance(transcription, str)

    def test_genai_whisper_pipeline(self):
        """``GenAIWhisperPipeline`` on the exported model dir yields a string."""
        settings = {"_models": [model_dir], "_device": "CPU"}
        transcription = self._predict(GenAIWhisperPipeline(settings))
        assert isinstance(transcription, str)

    def test_optimum_whisper_pipeline(self):
        """``OptimumWhisperPipeline`` on the exported model dir yields a string."""
        settings = {"_models": [model_dir], "_device": "CPU"}
        transcription = self._predict(OptimumWhisperPipeline(settings))
        assert isinstance(transcription, str)
81+
82+
83+
def test_normalize_transcription():
    """``normalize_transcription`` turns "This is a test 1" into upper-case words.

    The engine is a ``MagicMock`` whose ``number_to_words`` maps ``"1"`` to
    ``"one"`` and returns any other argument unchanged.
    """

    def fake_number_to_words(token):
        # Only the digit used by this test is spelled out.
        if token == "1":
            return "one"
        return token

    mock_engine = MagicMock()
    mock_engine.number_to_words.side_effect = fake_number_to_words

    normalized = normalize_transcription(mock_engine, "This is a test 1")

    assert normalized == "THIS IS A TEST ONE"

0 commit comments

Comments
 (0)