@@ -337,24 +337,6 @@ def _atomic_lowering_rule(
337337triton_lowering_rules [primitives .atomic_rmw_p ] = _atomic_lowering_rule
338338
339339
340- def _max_contiguous_lowering_rule (ctx : TritonLoweringRuleContext , x , * , values ):
341- values = [tl .constexpr (v ) for v in values ]
342- return tl .max_contiguous (x , values , _builder = ctx .builder )
343-
344-
345- triton_lowering_rules [primitives .max_contiguous_p ] = (
346- _max_contiguous_lowering_rule
347- )
348-
349-
350- def _multiple_of_lowering_rule (ctx : TritonLoweringRuleContext , x , * , values ):
351- values = [tl .constexpr (v ) for v in values ]
352- return tl .multiple_of (x , values , _builder = ctx .builder )
353-
354-
355- triton_lowering_rules [primitives .multiple_of_p ] = _multiple_of_lowering_rule
356-
357-
358340_TRITON_FN_MAPPING = {
359341 # Unary ops.
360342 lax .neg_p : tl .semantic .minus ,
@@ -407,12 +389,18 @@ def _multiple_of_lowering_rule(ctx: TritonLoweringRuleContext, x, *, values):
407389 ad_util .add_any_p : tl .semantic .add ,
408390 # Other ops.
409391 primitives .atomic_cas_p : tl .atomic_cas ,
392+ primitives .max_contiguous_p : tl .max_contiguous ,
393+ primitives .multiple_of_p : tl .multiple_of ,
410394}
411395
412396
413397for primitive , fn in _TRITON_FN_MAPPING .items ():
414398 if tl .core .is_builtin (fn ):
415- rule = lambda ctx , * args , fn = fn : fn (* args , _builder = ctx .builder )
399+
400+ def rule (ctx , * args , fn = fn , ** kwargs ):
401+ kwargs = tree_util .tree_map (tl .constexpr , kwargs )
402+ return fn (* args , ** kwargs , _builder = ctx .builder )
403+
416404 else :
417405 rule = lambda ctx , * args , fn = fn : fn (* args , ctx .builder )
418406
0 commit comments