 
 
 class ModelWorker:
-    def __init__(self, checkpoint, low_bit, torch_dtype=torch.float16):
+    def __init__(self, checkpoint, low_bit, model_type="normal", torch_dtype=torch.float16):
         self.dtype = torch_dtype
         start = time.perf_counter()
-        model = self.load_model(checkpoint, low_bit)
-        from ipex_llm.utils import BenchmarkWrapper
-        self.model = BenchmarkWrapper(model, do_print=True)
+        if model_type == "audio":
+            self.model = self.load_model(checkpoint, low_bit, "audio")
+        else:
+            model = self.load_model(checkpoint, low_bit)
+            from ipex_llm.utils import BenchmarkWrapper
+            self.model = BenchmarkWrapper(model, do_print=True)
         end = time.perf_counter()
         logger.info(f"Time to load weights: {end - start:.2f}s")
         self.waiting_requests = asyncio.Queue()
         self.streamer = {}
         self.model_name = checkpoint
 
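The new model_type argument is how callers opt into the audio path; audio checkpoints skip the BenchmarkWrapper and are loaded straight through load_model. A minimal construction sketch with placeholder checkpoint ids (not taken from this patch):

    # Placeholder checkpoint ids, shown only to illustrate the new keyword argument.
    text_worker = ModelWorker("meta-llama/Llama-2-7b-chat-hf", low_bit="sym_int4")
    asr_worker = ModelWorker("openai/whisper-medium", low_bit="sym_int4", model_type="audio")
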
-    def load_model(self, model_path, low_bit='sym_int4'):
-        from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
-        try:
-            model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                         load_in_low_bit=low_bit,
-                                                         torch_dtype=self.dtype,
-                                                         optimize_model=True,
-                                                         trust_remote_code=True,
-                                                         use_cache=True,)
-        except:
-            model = AutoModel.from_pretrained(model_path,
-                                              load_in_low_bit=low_bit,
-                                              torch_dtype=self.dtype,
-                                              optimize_model=True,
-                                              trust_remote_code=True,
-                                              use_cache=True,)
+    def load_model(self, model_path, low_bit='sym_int4', model_type="normal"):
+        if model_type == "audio":
+            from ipex_llm.transformers import AutoModelForSpeechSeq2Seq
+            model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
+                                                              load_in_low_bit=low_bit,
+                                                              torch_dtype=self.dtype,
+                                                              optimize_model=True,
+                                                              trust_remote_code=True,
+                                                              use_cache=True)
+        else:
+            from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
+            try:
+                model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                             load_in_low_bit=low_bit,
+                                                             torch_dtype=self.dtype,
+                                                             optimize_model=True,
+                                                             trust_remote_code=True,
+                                                             use_cache=True,)
+            except:
+                model = AutoModel.from_pretrained(model_path,
+                                                  load_in_low_bit=low_bit,
+                                                  torch_dtype=self.dtype,
+                                                  optimize_model=True,
+                                                  trust_remote_code=True,
+                                                  use_cache=True,)
         model = model.eval().to("xpu")
         return model
 
+    async def add_asr_request(self, processor):
+        if self.waiting_requests.empty():
+            return
+        tmp_result = await self.waiting_requests.get()
+        request_id, request = tmp_result
+        transcription_request = request.transcription_request
+        forced_decoder_ids = processor.get_decoder_prompt_ids(
+            language=transcription_request.language, task="transcribe")
+        audio_path = transcription_request.file
+        import librosa
+        raw_speech, sampling_rate = librosa.load(audio_path,
+                                                 sr=processor.feature_extractor.sampling_rate)
+        input_features = processor(
+            raw_speech,
+            sampling_rate=sampling_rate,
+            return_tensors="pt",
+            return_attention_mask=True,
+        ).input_features.to('xpu')
+        return input_features, forced_decoder_ids, request_id
+
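Taken together with load_model above, the ASR path boils down to: load a low-bit Whisper model on XPU, resample the uploaded audio to the feature extractor's rate, build input_features, and generate with forced decoder ids. A minimal end-to-end sketch of that flow outside the server, assuming a Whisper checkpoint id and a local audio file (both placeholders, not from this patch):

    # End-to-end ASR sketch; checkpoint id and audio path are illustrative placeholders.
    import librosa
    from transformers import WhisperProcessor
    from ipex_llm.transformers import AutoModelForSpeechSeq2Seq

    processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
    model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-medium",
                                                      load_in_low_bit="sym_int4",
                                                      optimize_model=True,
                                                      trust_remote_code=True,
                                                      use_cache=True)
    model = model.eval().to("xpu")

    # Resample to the rate the feature extractor expects (16 kHz for Whisper).
    raw_speech, sr = librosa.load("sample.wav", sr=processor.feature_extractor.sampling_rate)
    input_features = processor(raw_speech, sampling_rate=sr,
                               return_tensors="pt").input_features.to("xpu")
    forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")

    predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
    print(processor.batch_decode(predicted_ids.cpu(), skip_special_tokens=True)[0])
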
     async def add_request(self, tokenizer):
         if self.waiting_requests.empty():
             return
@@ -91,33 +123,41 @@ async def add_request(self, tokenizer):
         return input_ids, parameters, request_id, inputs_embeds
 
     @torch.no_grad()
-    async def process_step(self, tokenizer, result_dict):
+    async def process_step(self, tokenizer, result_dict, processor=None):
         if not self.waiting_requests.empty():
-            input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
-            self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
+            if processor is not None and "whisper" in self.model_name.lower():
+                input_features, decoder_ids, request_id = await self.add_asr_request(processor)
+                self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
-            def model_generate():
-                generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
-                if "codegeex" in self.model_name.lower():
-                    eos_token_id = [tokenizer.eos_token_id,
-                                    tokenizer.convert_tokens_to_ids("<|user|>"),
-                                    tokenizer.convert_tokens_to_ids("<|observation|>")]
-                    generate_kwargs["eos_token_id"] = eos_token_id
-                elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
-                    eos_token_id = [
-                        tokenizer.eos_token_id,
-                        tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
-                    ]
-                    generate_kwargs["eos_token_id"] = eos_token_id
-                if input_ids is not None:
-                    self.model.generate(input_ids,
-                                        streamer=self.streamer[request_id], **generate_kwargs)
-                elif inputs_embeds is not None:
-                    self.model.generate(inputs_embeds=inputs_embeds,
-                                        streamer=self.streamer[request_id], **generate_kwargs)
-                torch.xpu.empty_cache()
-                torch.xpu.synchronize()
+                def model_generate():
+                    self.model.generate(input_features,
+                                        streamer=self.streamer[request_id],
+                                        forced_decoder_ids=decoder_ids)
+            else:
+                input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
+                self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
+                def model_generate():
+                    generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
+                    if "codegeex" in self.model_name.lower():
+                        eos_token_id = [tokenizer.eos_token_id,
+                                        tokenizer.convert_tokens_to_ids("<|user|>"),
+                                        tokenizer.convert_tokens_to_ids("<|observation|>")]
+                        generate_kwargs["eos_token_id"] = eos_token_id
+                    elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
+                        eos_token_id = [
+                            tokenizer.eos_token_id,
+                            tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
+                        ]
+                        generate_kwargs["eos_token_id"] = eos_token_id
+                    if input_ids is not None:
+                        self.model.generate(input_ids,
+                                            streamer=self.streamer[request_id], **generate_kwargs)
+                    elif inputs_embeds is not None:
+                        self.model.generate(inputs_embeds=inputs_embeds,
+                                            streamer=self.streamer[request_id], **generate_kwargs)
+                    torch.xpu.empty_cache()
+                    torch.xpu.synchronize()
             from threading import Thread
             t1 = Thread(target=model_generate)
             t1.start()
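In both branches the generation thread streams tokens into self.streamer[request_id], which the response path (outside this hunk) presumably drains to return partial text. A minimal standalone sketch of the same producer/consumer pattern with TextIteratorStreamer, using an illustrative small text model (the model id and prompt are placeholders):

    # Placeholder model id; any causal LM works the same way with TextIteratorStreamer.
    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    inputs = tokenizer("Hello, world", return_tensors="pt")

    # generate() runs on a worker thread while the caller consumes text chunks as they arrive.
    t = Thread(target=model.generate,
               kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32))
    t.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)
    t.join()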