Skip to content

Commit 40c76f2

Browse files
Formatting
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent b5c0063 commit 40c76f2

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

include/torch-mlir/Dialect/Torch/IR/TorchOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,10 +1450,10 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
14501450
let summary = "Generated op for `aten::flex_attention`";
14511451
let description = [{
14521452
FlexAttention operation with flexible block-sparse attention patterns.
1453-
1453+
14541454
Args:
14551455
query: Query tensor [B, H, M, K]
1456-
key: Key tensor [B, H, N, K]
1456+
key: Key tensor [B, H, N, K]
14571457
value: Value tensor [B, H, N, Ev]
14581458
scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim))
14591459
enable_gqa: Boolean for grouped query attention support
@@ -1462,14 +1462,14 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
14621462
Attributes:
14631463
score_mod_fn: Optional function symbol reference for score modification
14641464
mask_mod_fn: Optional function symbol reference for mask modification
1465-
1465+
14661466
TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.)
14671467

14681468
Returns:
14691469
output: Result tensor [B, H, M, Ev]
14701470
logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True)
14711471
}];
1472-
1472+
14731473
let arguments = (ins
14741474
AnyTorchTensorType:$query,
14751475
AnyTorchTensorType:$key,
@@ -1480,12 +1480,12 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
14801480
OptionalAttr<FlatSymbolRefAttr>:$score_mod_fn,
14811481
OptionalAttr<FlatSymbolRefAttr>:$mask_mod_fn
14821482
);
1483-
1483+
14841484
let results = (outs
14851485
AnyTorchTensorType:$output,
14861486
AnyTorchOptionalTensorType:$logsumexp
14871487
);
1488-
1488+
14891489
let hasCustomAssemblyFormat = 1;
14901490
let extraClassDefinition = [{
14911491
ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {

0 commit comments

Comments
 (0)