
Commit 5c4ed00

Add lightweight-serving whisper asr example (#11847)
* add asr init
* update for pp
* update style
* update readme
* update readme
1 parent a8e2573 commit 5c4ed00

File tree

6 files changed: +177 −54 lines changed


python/llm/example/GPU/Lightweight-Serving/README.md

Lines changed: 24 additions & 1 deletion

@@ -22,6 +22,10 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 # for internlm-xcomposer2-vl-7b
 pip install transformers==4.31.0
 pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
+
+# for whisper-large-v3
+pip install transformers==4.36.2
+pip install datasets soundfile librosa # required by audio processing
 ```
 
 #### 1.2 Installation on Windows
@@ -35,6 +39,14 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
 pip install fastapi uvicorn openai
 pip install gradio # for gradio web UI
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+
+# for internlm-xcomposer2-vl-7b
+pip install transformers==4.31.0
+pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
+
+# for whisper-large-v3
+pip install transformers==4.36.2
+pip install datasets soundfile librosa # required by audio processing
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
@@ -180,7 +192,7 @@ curl http://localhost:8000/v1/chat/completions \
 
 image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) now, and it must install transformers==4.31.0 to run.
 ```bash
-wget -O ./test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
+wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
 curl http://localhost:8000/v1/chat/completions \
 -H "Content-Type: application/json" \
 -d '{
@@ -219,6 +231,17 @@ curl http://localhost:8000/v1/completions \
 }'
 ```
 
+#### v1/audio/transcriptions
+
+ASR currently only supports [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3), and `whisper-large-v3` can only be used for audio transcription. The audio file type must be supported by `librosa.load`.
+```bash
+curl http://localhost:8000/v1/audio/transcriptions \
+-H "Content-Type: multipart/form-data" \
+-F file="@/llm/test.mp3" \
+-F model="whisper-large-v3" \
+-F language="zh"
+```
+
 ### 6. Benchmark with wrk
 
 Please refer to [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Pipeline-Parallel-Serving#4-benchmark-with-wrk) for more details
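For reference, the same transcription request can be issued from Python instead of curl. This is a minimal client-side sketch, not part of the commit; it assumes the server is running on localhost:8000 and that a local `test.mp3` exists, and uses the `requests` library to send the multipart form the endpoint expects:

```python
# Hypothetical client sketch mirroring the curl example above.
import requests

with open("test.mp3", "rb") as audio_file:  # assumed local audio file
    response = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": audio_file},
        data={"model": "whisper-large-v3", "language": "zh"},
    )

# The endpoint returns a TranscriptionResponse serialized as JSON with a "text" field.
print(response.json()["text"])
```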

python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py

Lines changed: 13 additions & 6 deletions

@@ -39,12 +39,19 @@ async def main():
     model_path = args.repo_id_or_model_path
     low_bit = args.low_bit
 
-    local_model = ModelWorker(model_path, low_bit)
-    # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    myapp = FastApp(local_model, tokenizer)
+    processor = None
+    if "whisper" not in model_path.lower():
+        local_model = ModelWorker(model_path, low_bit)
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+    else:
+        local_model = ModelWorker(model_path, low_bit, "audio", torch_dtype=torch.float32)
+        from transformers import WhisperProcessor
+        processor = WhisperProcessor.from_pretrained(model_path)
+        tokenizer = processor.tokenizer
+    myapp = FastApp(local_model, tokenizer, processor)
     config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
     server = uvicorn.Server(config)
     await server.serve()
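For context on the new branch: a `WhisperProcessor` bundles a feature extractor and a tokenizer, which is why `processor.tokenizer` can be handed to `FastApp` in place of an `AutoTokenizer`. A minimal sketch (not part of the commit; assumes `openai/whisper-large-v3` is available locally or from the Hub):

```python
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")

tokenizer = processor.tokenizer                   # used for decoding / streaming text
feature_extractor = processor.feature_extractor   # turns raw audio into log-mel features
print(feature_extractor.sampling_rate)            # 16000 for Whisper models
```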

python/llm/src/ipex_llm/serving/fastapi/api_server.py

Lines changed: 40 additions & 2 deletions

@@ -27,6 +27,8 @@
 from typing import List, Optional, Union, Dict
 from fastapi.middleware.cors import CORSMiddleware
 from .tgi_protocol import Parameters
+from typing_extensions import Literal
+from fastapi import File, UploadFile, Form
 from .openai_protocol import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
@@ -38,6 +40,8 @@
     CompletionResponse,
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
+    TranscriptionRequest,
+    TranscriptionResponse,
 )
 
 result_dict: Dict[str, str] = {}
@@ -50,6 +54,7 @@ class InputsRequest(BaseModel):
     image_list: Optional[list] = None
     stream: Optional[bool] = False
     req_type: str = 'completion'
+    transcription_request: Optional[TranscriptionRequest] = None
 
 
 class ChatCompletionRequest(BaseModel):
@@ -92,20 +97,27 @@ class CompletionRequest(BaseModel):
 
 global tokenizer
 global local_model
+global processor
 
 
 class FastApp():
-    def __init__(self, model, mytokenizer):
+    def __init__(self, model, mytokenizer, myprocessor=None):
         global tokenizer
         global local_model
+        global processor
         local_model = model
         tokenizer = mytokenizer
+        processor = myprocessor
         self.app = app
 
 
 def get_queue_next_token(delta_text_queue):
     timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
     delta_text = delta_text_queue.text_queue.get(timeout=timeout)
+    if "whisper" in local_model.model_name.lower():
+        if delta_text is not None and "<|" in delta_text and "|>" in delta_text:
+            import re
+            delta_text = re.sub(r'<\|.*?\|>', '', delta_text)
     if delta_text is None:
         remain = 0
     else:
@@ -385,6 +397,32 @@ async def create_completion(request: CompletionRequest):
     return result
 
 
+@app.post("/v1/audio/transcriptions")
+async def transcriptions(
+    file: UploadFile=File(...),
+    model: Optional[str]=Form("default_model"),
+    language: Optional[str]=Form("zh"),
+    prompt: Optional[str]=Form(None),
+    response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]]=Form(None),
+    temperature: Optional[float]=Form(None),
+    timestamp_granularities: Optional[List[Literal["word", "segment"]]]=Form(None)
+):
+    file_path = "./" + file.filename
+    if not os.path.exists(file_path):
+        with open(file_path, "wb") as f:
+            f.write(await file.read())
+    inputs_request = InputsRequest(
+        inputs="transcriptions",
+        parameters=None,
+        stream=False,
+        req_type="completion",
+        transcription_request=TranscriptionRequest(file=file_path, model=model, language=language)
+    )
+    request_id, result = await generate(inputs_request)
+    rsp = TranscriptionResponse(text=result)
+    return rsp
+
+
 @app.on_event("startup")
 async def startup_event():
     asyncio.create_task(process_requests(local_model, result_dict))
@@ -393,4 +431,4 @@ async def startup_event():
 async def process_requests(local_model, result_dict):
     while True:
         await asyncio.sleep(0)
-        await local_model.process_step(tokenizer, result_dict)
+        await local_model.process_step(tokenizer, result_dict, processor)
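The `get_queue_next_token` change strips Whisper's special tokens (language, task, and timestamp markers) from streamed text before it is returned to the client. A small standalone illustration of the regex, using a made-up decoded chunk:

```python
import re

# Illustrative raw decoded Whisper chunk containing special tokens.
delta_text = "<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>你好，世界"

if "<|" in delta_text and "|>" in delta_text:
    # Remove every <|...|> marker, keeping only the transcribed text.
    delta_text = re.sub(r'<\|.*?\|>', '', delta_text)

print(delta_text)  # -> 你好，世界
```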

python/llm/src/ipex_llm/serving/fastapi/model_worker.py

Lines changed: 84 additions & 44 deletions

@@ -23,37 +23,69 @@
 
 
 class ModelWorker:
-    def __init__(self, checkpoint, low_bit, torch_dtype=torch.float16):
+    def __init__(self, checkpoint, low_bit, model_type="normal", torch_dtype=torch.float16):
         self.dtype = torch_dtype
         start = time.perf_counter()
-        model = self.load_model(checkpoint, low_bit)
-        from ipex_llm.utils import BenchmarkWrapper
-        self.model = BenchmarkWrapper(model, do_print=True)
+        if model_type == "audio":
+            self.model = self.load_model(checkpoint, low_bit, "audio")
+        else:
+            model = self.load_model(checkpoint, low_bit)
+            from ipex_llm.utils import BenchmarkWrapper
+            self.model = BenchmarkWrapper(model, do_print=True)
         end = time.perf_counter()
         logger.info(f"Time to load weights: {end - start:.2f}s")
         self.waiting_requests = asyncio.Queue()
         self.streamer = {}
         self.model_name = checkpoint
 
-    def load_model(self, model_path, low_bit='sym_int4'):
-        from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
-        try:
-            model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                         load_in_low_bit=low_bit,
-                                                         torch_dtype=self.dtype,
-                                                         optimize_model=True,
-                                                         trust_remote_code=True,
-                                                         use_cache=True,)
-        except:
-            model = AutoModel.from_pretrained(model_path,
-                                              load_in_low_bit=low_bit,
-                                              torch_dtype=self.dtype,
-                                              optimize_model=True,
-                                              trust_remote_code=True,
-                                              use_cache=True,)
+    def load_model(self, model_path, low_bit='sym_int4', model_type="normal"):
+        if model_type == "audio":
+            from ipex_llm.transformers import AutoModelForSpeechSeq2Seq
+            model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
+                                                              load_in_low_bit=low_bit,
+                                                              torch_dtype=self.dtype,
+                                                              optimize_model=True,
+                                                              trust_remote_code=True,
+                                                              use_cache=True)
+        else:
+            from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
+            try:
+                model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                             load_in_low_bit=low_bit,
+                                                             torch_dtype=self.dtype,
+                                                             optimize_model=True,
+                                                             trust_remote_code=True,
+                                                             use_cache=True,)
+            except:
+                model = AutoModel.from_pretrained(model_path,
+                                                  load_in_low_bit=low_bit,
+                                                  torch_dtype=self.dtype,
+                                                  optimize_model=True,
+                                                  trust_remote_code=True,
+                                                  use_cache=True,)
         model = model.eval().to("xpu")
         return model
 
+    async def add_asr_request(self, processor):
+        if self.waiting_requests.empty():
+            return
+        tmp_result = await self.waiting_requests.get()
+        request_id, request = tmp_result
+        transcription_request = request.transcription_request
+        forced_decoder_ids = processor.get_decoder_prompt_ids(
+            language=transcription_request.language, task="transcribe")
+        audio_path = transcription_request.file
+        import librosa
+        raw_speech, sampling_rate = librosa.load(audio_path,
+                                                 sr=processor.feature_extractor.sampling_rate)
+        input_features = processor(
+            raw_speech,
+            sampling_rate=sampling_rate,
+            return_tensors="pt",
+            return_attention_mask=True,
+        ).input_features.to('xpu')
+        return input_features, forced_decoder_ids, request_id
+
     async def add_request(self, tokenizer):
         if self.waiting_requests.empty():
             return
@@ -91,33 +123,41 @@ async def add_request(self, tokenizer):
         return input_ids, parameters, request_id, inputs_embeds
 
     @torch.no_grad()
-    async def process_step(self, tokenizer, result_dict):
+    async def process_step(self, tokenizer, result_dict, processor=None):
         if not self.waiting_requests.empty():
-            input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
-            self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
+            if processor is not None and "whisper" in self.model_name.lower():
+                input_features, decoder_ids, request_id = await self.add_asr_request(processor)
+                self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
-            def model_generate():
-                generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
-                if "codegeex" in self.model_name.lower():
-                    eos_token_id = [tokenizer.eos_token_id,
-                                    tokenizer.convert_tokens_to_ids("<|user|>"),
-                                    tokenizer.convert_tokens_to_ids("<|observation|>")]
-                    generate_kwargs["eos_token_id"] = eos_token_id
-                elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
-                    eos_token_id = [
-                        tokenizer.eos_token_id,
-                        tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
-                    ]
-                    generate_kwargs["eos_token_id"] = eos_token_id
-                if input_ids is not None:
-                    self.model.generate(input_ids,
-                                        streamer=self.streamer[request_id], **generate_kwargs)
-                elif inputs_embeds is not None:
-                    self.model.generate(inputs_embeds=inputs_embeds,
-                                        streamer=self.streamer[request_id], **generate_kwargs)
-                torch.xpu.empty_cache()
-                torch.xpu.synchronize()
+                def model_generate():
+                    self.model.generate(input_features,
+                                        streamer=self.streamer[request_id],
+                                        forced_decoder_ids=decoder_ids)
+            else:
+                input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
+                self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
+                def model_generate():
+                    generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
+                    if "codegeex" in self.model_name.lower():
+                        eos_token_id = [tokenizer.eos_token_id,
+                                        tokenizer.convert_tokens_to_ids("<|user|>"),
+                                        tokenizer.convert_tokens_to_ids("<|observation|>")]
+                        generate_kwargs["eos_token_id"] = eos_token_id
+                    elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
+                        eos_token_id = [
+                            tokenizer.eos_token_id,
+                            tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
+                        ]
+                        generate_kwargs["eos_token_id"] = eos_token_id
+                    if input_ids is not None:
+                        self.model.generate(input_ids,
+                                            streamer=self.streamer[request_id], **generate_kwargs)
+                    elif inputs_embeds is not None:
+                        self.model.generate(inputs_embeds=inputs_embeds,
+                                            streamer=self.streamer[request_id], **generate_kwargs)
+                    torch.xpu.empty_cache()
+                    torch.xpu.synchronize()
             from threading import Thread
             t1 = Thread(target=model_generate)
             t1.start()
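Putting `load_model(..., "audio")` and `add_asr_request` together, the ASR path boils down to the following offline sketch. It is not part of the commit and assumes an Intel XPU device, a local `test.mp3`, and the packages listed in the README (`transformers==4.36.2`, `datasets`, `soundfile`, `librosa`):

```python
import librosa
import torch
from transformers import WhisperProcessor
from ipex_llm.transformers import AutoModelForSpeechSeq2Seq

model_path = "openai/whisper-large-v3"   # assumed model id or local path
audio_path = "test.mp3"                  # assumed local audio file

# Load the low-bit Whisper model the same way ModelWorker.load_model does for model_type="audio".
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
                                                  load_in_low_bit="sym_int4",
                                                  torch_dtype=torch.float32,
                                                  optimize_model=True,
                                                  trust_remote_code=True,
                                                  use_cache=True)
model = model.eval().to("xpu")

processor = WhisperProcessor.from_pretrained(model_path)

# Same preprocessing as add_asr_request: resample to the extractor's rate, build log-mel features.
raw_speech, sampling_rate = librosa.load(audio_path, sr=processor.feature_extractor.sampling_rate)
input_features = processor(raw_speech, sampling_rate=sampling_rate,
                           return_tensors="pt").input_features.to("xpu")
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")

predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```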

python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py

Lines changed: 15 additions & 0 deletions

@@ -24,13 +24,28 @@
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Annotated
 from ipex_llm.utils.common import invalidInputError
+from typing_extensions import Literal
 
 
 # from vllm.sampling_params import SamplingParams
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
+class TranscriptionRequest(BaseModel):
+    file: str = None
+    model: Optional[str] = "default_model"
+    language: Optional[str] = "zh"
+    prompt: Optional[str] = None
+    response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]] = None
+    temperature: Optional[float] = None
+    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None
+
+
+class TranscriptionResponse(BaseModel):
+    text: str
+
+
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields
     model_config = ConfigDict(extra="forbid")
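These are plain pydantic models, so the transcription endpoint can construct and serialize them directly. A tiny usage sketch, assuming the installed package exposes them at `ipex_llm.serving.fastapi.openai_protocol`:

```python
from ipex_llm.serving.fastapi.openai_protocol import TranscriptionRequest, TranscriptionResponse

# Request object as built by the /v1/audio/transcriptions handler.
req = TranscriptionRequest(file="/llm/test.mp3", model="whisper-large-v3", language="zh")
print(req.dict())  # use req.model_dump() on pydantic v2

# Response object returned to the client as JSON.
rsp = TranscriptionResponse(text="transcribed text goes here")
print(rsp.text)
```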

python/llm/src/ipex_llm/transformers/pipeline_parallel.py

Lines changed: 1 addition & 1 deletion

@@ -800,7 +800,7 @@ async def stream_output(self, cur_batch, tokenizer, next_ids):
             _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
         await asyncio.gather(*_stream_tasks)
 
-    async def process_step(self, tokenizer, result_dict):
+    async def process_step(self, tokenizer, result_dict, processor=None):
         cur_batch = None
         torch.xpu.synchronize(self.device)
         if self.rank == 0:
