Skip to content

Commit 675cb15

Browse files
chr1sj0nesjax authors
authored andcommitted
[pallas:gpu] Directly use Triton built-in function lowering for max_contiguous and multiple_of.
PiperOrigin-RevId: 573871324
1 parent ac2c228 commit 675cb15

File tree

1 file changed

+7
-19
lines changed

1 file changed

+7
-19
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -337,24 +337,6 @@ def _atomic_lowering_rule(
337337
triton_lowering_rules[primitives.atomic_rmw_p] = _atomic_lowering_rule
338338

339339

340-
def _max_contiguous_lowering_rule(ctx: TritonLoweringRuleContext, x, *, values):
341-
values = [tl.constexpr(v) for v in values]
342-
return tl.max_contiguous(x, values, _builder=ctx.builder)
343-
344-
345-
triton_lowering_rules[primitives.max_contiguous_p] = (
346-
_max_contiguous_lowering_rule
347-
)
348-
349-
350-
def _multiple_of_lowering_rule(ctx: TritonLoweringRuleContext, x, *, values):
351-
values = [tl.constexpr(v) for v in values]
352-
return tl.multiple_of(x, values, _builder=ctx.builder)
353-
354-
355-
triton_lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule
356-
357-
358340
_TRITON_FN_MAPPING = {
359341
# Unary ops.
360342
lax.neg_p: tl.semantic.minus,
@@ -407,12 +389,18 @@ def _multiple_of_lowering_rule(ctx: TritonLoweringRuleContext, x, *, values):
407389
ad_util.add_any_p: tl.semantic.add,
408390
# Other ops.
409391
primitives.atomic_cas_p: tl.atomic_cas,
392+
primitives.max_contiguous_p: tl.max_contiguous,
393+
primitives.multiple_of_p: tl.multiple_of,
410394
}
411395

412396

413397
for primitive, fn in _TRITON_FN_MAPPING.items():
414398
if tl.core.is_builtin(fn):
415-
rule = lambda ctx, *args, fn=fn: fn(*args, _builder=ctx.builder)
399+
400+
def rule(ctx, *args, fn=fn, **kwargs):
401+
kwargs = tree_util.tree_map(tl.constexpr, kwargs)
402+
return fn(*args, **kwargs, _builder=ctx.builder)
403+
416404
else:
417405
rule = lambda ctx, *args, fn=fn: fn(*args, ctx.builder)
418406

0 commit comments

Comments
 (0)