
Commit f8ab833

support and optimize janus pro (#12813)
1 parent bd815a4

File tree: 3 files changed (+54, -3 lines)


python/llm/src/ipex_llm/transformers/convert.py

Lines changed: 4 additions & 2 deletions
@@ -1066,7 +1066,7 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.baichuan_m1 import pre_register_inv_freq
         model.apply(pre_register_inv_freq)
     elif model.config.model_type == "multi_modality":
-        pass
+        _optimize_pre(model.language_model)
 
     return model
 
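Previously the multi_modality branch was a no-op; recursing into model.language_model lets the existing per-architecture branches optimize Janus Pro's wrapped LLM. A minimal sketch of that dispatch pattern (hypothetical names, not the ipex-llm source):

    # Sketch: _optimize_pre dispatches on config.model_type, so recursing on
    # the wrapped language model reuses the llama-style branch for Janus Pro.
    def optimize_pre_sketch(model):
        model_type = model.config.model_type
        if model_type == "multi_modality":
            # Janus Pro: no vision pre-pass needed here, but the wrapped
            # LLM should get the usual per-architecture treatment.
            optimize_pre_sketch(model.language_model)
        elif model_type == "llama":
            pass  # weight fusion / renaming ahead of quantization goes here
        return model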

@@ -2012,8 +2012,10 @@ def _optimize_post(model):
         # vision
         vpm_modeling_module_name = model.vision_model.vision_tower.__class__.__module__
         vpm_module = importlib.import_module(vpm_modeling_module_name)
-
         from ipex_llm.transformers.models.janus import vision_attention_forward
         convert_forward(model.vision_model, vpm_module.Attention, vision_attention_forward)
 
+        # llm
+        _optimize_post(model.language_model)
+
     return model
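convert_forward swaps the forward implementation of every module of a given class; here it installs vision_attention_forward on the vision tower's Attention blocks, and _optimize_post then recurses so the wrapped LLM gets the standard post-conversion optimizations. A minimal sketch of the forward-swap pattern, assuming the usual rebind-on-matching-submodules approach (sketch names are mine):

    import types

    import torch

    def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
        # Walk every submodule and rebind forward on instances of target_cls,
        # so new_forward receives the matched module as `self` when called.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = types.MethodType(new_forward, module)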
python/llm/src/ipex_llm/transformers/models/janus.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is adapted from
+# https://github.com/deepseek-ai/Janus/blob/main/janus/models/siglip_vit.py
+
+import torch
+
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+
+
+def vision_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
+    B, N, C = x.shape
+    qkv = (
+        self.qkv(x)
+        .reshape(B, N, 3, self.num_heads, self.head_dim)
+        .permute(2, 0, 3, 1, 4)
+    )
+    q, k, v = qkv.unbind(0)
+    q, k = self.q_norm(q), self.k_norm(k)
+
+    if self.fused_attn:
+        # ipex-llm opt: sdpa
+        x = scaled_dot_product_attention(
+            q, k.contiguous(), v.contiguous(), None, False
+        )
+    else:
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = attn @ v
+
+    x = x.transpose(1, 2).reshape(B, N, C)
+    x = self.proj(x)
+    x = self.proj_drop(x)
+    return x
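The fused branch routes through ipex-llm's scaled_dot_product_attention wrapper instead of materializing the N x N attention matrix; the two branches are mathematically equivalent. A quick self-check with stock PyTorch (illustration only; it uses torch.nn.functional, not the ipex-llm wrapper):

    import torch
    import torch.nn.functional as F

    B, H, N, D = 1, 4, 16, 32            # batch, heads, tokens, head_dim
    q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

    # Fused kernel: softmax(q @ k^T / sqrt(D)) @ v in one call.
    fused = F.scaled_dot_product_attention(q, k, v)

    # Manual reference path, mirroring the else-branch above.
    attn = (q @ k.transpose(-2, -1)) * D ** -0.5
    manual = attn.softmax(dim=-1) @ v

    print(torch.allclose(fused, manual, atol=1e-5))  # True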

python/llm/src/ipex_llm/transformers/models/utils.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor,
         return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
     elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
         return os.environ["IPEX_LLM_LOW_MEM"] == "1"
-    elif linear.qtype in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+    elif linear.weight.dtype != torch.uint8:  # unquantized
         return False
     else:
         device_name = get_xpu_device_name(x.device)
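The old test treated only fp16/bf16 as unquantized; checking the weight dtype generalizes it, since ipex-llm packs low-bit weights into uint8 buffers, so any other weight dtype (fp32 included) means the layer was never quantized. An illustrative check, not the library API:

    import torch

    def looks_quantized(weight: torch.Tensor) -> bool:
        # ipex-llm stores ggml-style low-bit weights as packed uint8 bytes,
        # so a non-uint8 dtype implies an unquantized linear layer.
        return weight.dtype == torch.uint8

    print(looks_quantized(torch.empty(2, 2, dtype=torch.uint8)))    # True
    print(looks_quantized(torch.empty(2, 2, dtype=torch.float16)))  # False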
