Skip to content

Commit ba12a89

Browse files
committed
Formatting fix
1 parent c57b290 commit ba12a89

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

python/tvm/relax/frontend/nn/llm/position_embedding.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,9 +569,11 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
569569
)
570570
# long factors is the first half, short factors is the second half
571571
long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data)
572-
short_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data, elem_offset=(rotary_dim // 2)) # type: ignore
572+
short_factors = T.Buffer(
573+
(rotary_dim // 2,), "float32", data=ext_factors.data, elem_offset=(rotary_dim // 2)
574+
)
573575

574-
if(seq_len > original_max_position_embeddings):
576+
if seq_len > original_max_position_embeddings:
575577
for iters in T.grid(seq_len, fused_heads, head_dim):
576578
with T.block("llama_fused_rope"):
577579
s, h, d = T.axis.remap("SSS", iters)

0 commit comments

Comments
 (0)