     PoolingParamsUpdate,
     PoolingType,
 )
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
@@ -62,19 +62,6 @@ def forward(
         return embeddings
 
 
-class ModernBertRotaryEmbedding(RotaryEmbedding):
-    def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: float):
-        super().__init__(
-            head_size=head_size,
-            rotary_dim=dim,
-            max_position_embeddings=config.max_position_embeddings,
-            base=base,
-            is_neox_style=True,
-            dtype=torch.float16,
-        )
-        self.config = config
-
-
 class ModernBertAttention(nn.Module):
     def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
         super().__init__()
@@ -95,19 +82,33 @@ def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
             bias=config.attention_bias,
         )
 
-        sliding_window = None
-        if layer_id % config.global_attn_every_n_layers != 0:
-            sliding_window = config.local_attention // 2
-            rope_theta = (
-                config.local_rope_theta
-                if config.local_rope_theta is not None
-                else config.global_rope_theta
-            )
+        if layer_types := getattr(config, "layer_types", None):
+            # Transformers v5
+            layer_type = layer_types[layer_id]
+            rope_parameters = config.rope_parameters[layer_type]
+            sliding_window: int | None = None
+            if layer_type == "sliding_attention":
+                sliding_window = config.local_attention // 2
         else:
-            rope_theta = config.global_rope_theta
-
-        self.rotary_emb = ModernBertRotaryEmbedding(
-            config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta
+            # Transformers v4
+            sliding_window = None
+            if layer_id % config.global_attn_every_n_layers != 0:
+                sliding_window = config.local_attention // 2
+                rope_theta = (
+                    config.local_rope_theta
+                    if config.local_rope_theta is not None
+                    else config.global_rope_theta
+                )
+            else:
+                rope_theta = config.global_rope_theta
+            rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
+
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            rope_parameters=rope_parameters,
+            dtype=torch.float16,
         )
         self.attn = EncoderOnlyAttention(
             self.num_heads,
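
For context, here is a minimal, self-contained sketch of the rope-parameter selection this hunk introduces. The SimpleConfig stand-in and resolve_rope_parameters helper are hypothetical names used for illustration only (the real code reads the ModernBertConfig directly and passes the result straight to get_rope); the sketch simply mirrors the Transformers v5 layer_types path and the v4 global_attn_every_n_layers fallback shown above.

# Hypothetical, minimal stand-in for the relevant ModernBertConfig fields;
# illustration only, not part of this change.
from dataclasses import dataclass

@dataclass
class SimpleConfig:
    global_attn_every_n_layers: int = 3
    local_attention: int = 128
    global_rope_theta: float = 160_000.0
    local_rope_theta: float | None = 10_000.0
    # Transformers v5 configs also carry layer_types / rope_parameters;
    # None here means a v4-style config.
    layer_types: list[str] | None = None
    rope_parameters: dict | None = None

def resolve_rope_parameters(config: SimpleConfig, layer_id: int) -> tuple[dict, int | None]:
    """Mirror the per-layer selection logic added in ModernBertAttention.__init__."""
    if layer_types := getattr(config, "layer_types", None):
        # Transformers v5: the config maps each layer type to its rope parameters.
        layer_type = layer_types[layer_id]
        rope_parameters = config.rope_parameters[layer_type]
        sliding_window = (
            config.local_attention // 2 if layer_type == "sliding_attention" else None
        )
    else:
        # Transformers v4: derive rope_theta from the global/local settings.
        sliding_window = None
        if layer_id % config.global_attn_every_n_layers != 0:
            sliding_window = config.local_attention // 2
            rope_theta = (
                config.local_rope_theta
                if config.local_rope_theta is not None
                else config.global_rope_theta
            )
        else:
            rope_theta = config.global_rope_theta
        rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
    return rope_parameters, sliding_window

if __name__ == "__main__":
    cfg = SimpleConfig()
    # Layer 0 is a global-attention layer; layer 1 uses local (sliding-window) attention.
    print(resolve_rope_parameters(cfg, 0))  # ({'rope_type': 'default', 'rope_theta': 160000.0}, None)
    print(resolve_rope_parameters(cfg, 1))  # ({'rope_type': 'default', 'rope_theta': 10000.0}, 64)

With rope_parameters built this way, the attention layer can hand everything to get_rope, so the per-model ModernBertRotaryEmbedding subclass removed earlier in this diff is no longer needed.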