@@ -25,7 +25,7 @@
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""

 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
-from functools import partial
+from functools import lru_cache, partial
 from itertools import islice
 from typing import Any

@@ -416,30 +416,41 @@ def dtype(self) -> torch.dtype: |
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device

+    @staticmethod
+    @lru_cache(maxsize=1024)
+    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
+        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
+        h_div = h // spatial_merge_size
+        w_div = w // spatial_merge_size
+        hpos_ids = hpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
+        hpos_ids = hpos_ids.flatten()
+
+        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
+        wpos_ids = wpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
+        wpos_ids = wpos_ids.flatten()
+
+        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
+
     def rot_pos_emb(self, grid_thw: list[list[int]]):
-        pos_ids = []
         max_grid_size = max(max(h, w) for _, h, w in grid_thw)
-        for t, h, w in grid_thw:
-            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
-
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = [
+            self.rot_pos_ids(h, w, self.spatial_merge_size)
+            if t == 1
+            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
+            for t, h, w in grid_thw
+        ]
         pos_ids = torch.cat(pos_ids, dim=0)
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
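
For reference, the cached helper can be exercised on its own. The sketch below is illustrative (the toy 4x4 grid and the standalone imports are assumptions; in the model file `np` and `torch` are already in scope): it shows that repeated `(h, w, spatial_merge_size)` shapes return the memoized tensor instead of recomputing, which is the point of the `lru_cache` change.

```python
# Minimal standalone sketch of the cached helper from the diff above.
from functools import lru_cache

import numpy as np
import torch


@lru_cache(maxsize=1024)
def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
    # (h, w) grid of row indices, regrouped into merge-size blocks so that
    # patches merged together stay adjacent after flattening.
    hpos = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
    hpos = hpos.reshape(
        h // spatial_merge_size, spatial_merge_size,
        w // spatial_merge_size, spatial_merge_size,
    ).transpose(0, 2, 1, 3).flatten()

    # Same regrouping for the column indices.
    wpos = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
    wpos = wpos.reshape(
        h // spatial_merge_size, spatial_merge_size,
        w // spatial_merge_size, spatial_merge_size,
    ).transpose(0, 2, 1, 3).flatten()

    return torch.from_numpy(np.stack([hpos, wpos], axis=-1))


ids = rot_pos_ids(4, 4, 2)          # computed once
assert ids.shape == (16, 2)         # one (h_idx, w_idx) row per patch
assert rot_pos_ids(4, 4, 2) is ids  # cache hit: the same tensor object
```

Sharing the cached tensor across calls is safe here because `torch.cat` in `rot_pos_emb` copies its inputs into a new tensor, and `.repeat` likewise allocates a fresh tensor.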
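
On the video path the per-frame ids are tiled `t` times; the comprehension only pays for `.repeat` when `t > 1`, so the common single-frame image case returns the cached tensor directly. A quick check of that equivalence, using the same toy helper as above:

```python
base = rot_pos_ids(4, 4, 2)
# t == 1: repeat(1, 1) would just be a no-op copy, so the branch skips it.
assert torch.equal(base.repeat(1, 1), base)
# t == 3: tile the per-frame (h*w, 2) ids along dim 0, one block per frame.
video = rot_pos_ids(4, 4, 2).repeat(3, 1)
assert video.shape == (3 * 16, 2)
```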