@@ -117,7 +117,7 @@ def merge_state_in_place(
117117 s_other : torch.Tensor
118118 The other logsumexp value to be merged, expected to be a float32 tensor,
119119 shape: ``(seq_len, num_heads)``.
120-
120+
121121 Example
122122 -------
123123 >>> import torch
@@ -135,7 +135,7 @@ def merge_state_in_place(
135135
136136
 137137 def merge_states(v: torch.Tensor, s: torch.Tensor):
138- r"""Merge multiple attention states (v, s).
138+ r"""Merge multiple attention states (v, s).
139139
140140 Parameters
141141 ----------
@@ -154,7 +154,7 @@ def merge_states(v: torch.Tensor, s: torch.Tensor):
154154 S : torch.Tensor
155155 The logsumexp value from the merged KV-segments, shape:
156156 ``[seq_len, num_heads]``.
157-
157+
158158 Example
159159 -------
160160 >>> import torch
@@ -229,7 +229,7 @@ def batch_decode_with_shared_prefix_padded_kv_cache(
229229 -------
230230 V : torch.Tensor
231231 The attention output, shape: ``[batch_size, num_heads, head_dim]``
232-
232+
233233 Example
234234 -------
235235 >>> import torch
@@ -312,7 +312,7 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
312312 ... )
313313 >>> batch_size = 7
314314 >>> shared_prefix_len = 8192
315- >>> unique_kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
315+ >>> unique_kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
316316 >>> unique_kv_page_indptr = torch.tensor(
317317 ... [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
318318 ... )
@@ -355,7 +355,7 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
355355 ... # compute batch decode attention, reuse auxiliary data structures for all layers
356356 ... o = wrapper.forward(q, k_shared, v_shared, unique_kv_data)
357357 ... outputs.append(o)
358- ...
358+ ...
359359 >>> # clear auxiliary data structures
360360 >>> wrapper.end_forward()
361361 >>> outputs[0].shape
@@ -547,7 +547,7 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
547547 >>> qo_indptr = torch.tensor(
548548 ... [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
549549 ... )
550- >>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
550+ >>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
551551 >>> paged_kv_indptr = torch.tensor(
552552 ... [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
553553 ... )
@@ -590,7 +590,7 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
590590 ... q, k_shared, v_shared, kv_data, causal=True
591591 ... )
592592 ... outputs.append(o)
593- ...
593+ ...
 594594 >>> # clear auxiliary data structures
595595 >>> prefill_wrapper.end_forward()
596596 >>> outputs[0].shape
0 commit comments