@@ -1450,10 +1450,10 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
   let summary = "Generated op for `aten::flex_attention`";
   let description = [{
     FlexAttention operation with flexible block-sparse attention patterns.
-
+
     Args:
       query: Query tensor [B, H, M, K]
-      key: Key tensor [B, H, N, K]
+      key: Key tensor [B, H, N, K]
       value: Value tensor [B, H, N, Ev]
       scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim))
       enable_gqa: Boolean for grouped query attention support
@@ -1462,14 +1462,14 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
     Attributes:
       score_mod_fn: Optional function symbol reference for score modification
       mask_mod_fn: Optional function symbol reference for mask modification
-
+
     TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.)

     Returns:
       output: Result tensor [B, H, M, Ev]
       logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True)
   }];
-
+
   let arguments = (ins
     AnyTorchTensorType:$query,
     AnyTorchTensorType:$key,
@@ -1480,12 +1480,12 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
     OptionalAttr<FlatSymbolRefAttr>:$score_mod_fn,
     OptionalAttr<FlatSymbolRefAttr>:$mask_mod_fn
   );
-
+
   let results = (outs
     AnyTorchTensorType:$output,
     AnyTorchOptionalTensorType:$logsumexp
   );
-
+
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
     ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
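
For reference, the semantics this op captures mirror PyTorch's `torch.nn.attention.flex_attention` frontend, where `score_mod_fn` corresponds to a user-written score_mod callable. A minimal sketch of that frontend usage is below, assuming the public flex_attention signature in recent PyTorch releases; the `causal_score_mod` helper and the concrete shapes are illustrative and not part of this patch.

# Minimal sketch, assuming the public flex_attention API in recent PyTorch
# (torch.nn.attention.flex_attention); illustrative only, not part of this patch.
import torch
from torch.nn.attention.flex_attention import flex_attention

def causal_score_mod(score, b, h, q_idx, kv_idx):
    # Plays the role of score_mod_fn: keep scores where q_idx >= kv_idx,
    # otherwise mask the position out with -inf before the softmax.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

B, H, M, N, K, Ev = 2, 4, 128, 128, 64, 64   # shapes as in the description above
q = torch.randn(B, H, M, K)
k = torch.randn(B, H, N, K)
v = torch.randn(B, H, N, Ev)

# scale=None defaults to 1/sqrt(head_dim); return_lse=True also returns the
# logsumexp tensor, matching the op's optional second result.
out, lse = flex_attention(q, k, v, score_mod=causal_score_mod, return_lse=True)
print(out.shape, lse.shape)  # expected: [2, 4, 128, 64] and [2, 4, 128]

The `mask_mod_fn` attribute plays the analogous role for the mask_mod callables that the PyTorch frontend turns into a BlockMask (via create_block_mask) and passes through the same entry point.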