
TypeError when using string as callable in Triton kernel with tl.constexpr #5070

@CaesarKingW


Describe the bug
Passing a string as operation_function (declared as tl.constexpr) to the Triton kernel raises a TypeError, because the kernel calls the string as if it were a function.

https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py#L458

            # do fusion operation if need
            if operation_function is not None:
                operation_dense_val = tl.load(
                    operation_dense + block_offset, mask=dense_mask, other=0.0
                )
                jagged_val = operation_function(operation_dense_val, jagged_val)

            # load value into empty dense
            tl.store(output_dense_ptr + block_offset, jagged_val, mask=dense_mask)

Error Message

jagged_val = operation_function(operation_dense_val, jagged_val)
^
TypeError: 'str' object is not callable

To Reproduce
Steps to reproduce the behavior:

  1. Call the Triton kernel (e.g., jagged_to_dense...) with a string argument (like "add") for the operation_function parameter, which is declared as tl.constexpr
  2. The kernel attempts to execute operation_function(operation_dense_val, jagged_val)
  3. The TypeError occurs because a string object is not callable
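The failure mode in step 3 can be seen without Triton or a GPU. The snippet below is a plain-Python illustration only: it uses floats in place of tensors and is not the FBGEMM kernel, but it makes the same call shape on a string and captures the resulting message.

```python
# Plain-Python stand-in for the kernel argument; no Triton or GPU involved.
operation_function = "add"
error_message = None

try:
    # Same call shape as the failing kernel line, with floats instead of tensors.
    operation_function(1.0, 2.0)
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

The printed message matches the one reported in the Error Message section above.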

Code Example

import torch
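import fbgemm_gpu
# Assumed import path, derived from the file linked above
from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import jagged_to_dense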

jagged_values = torch.randn(10, 5).to('cuda')
jagged_offsets = [torch.tensor([0, 3, 6, 10], dtype=torch.long).to('cuda')]
jagged_max_lengths = [4]
padding_value = 2.0
operation_function = "add"
operation_dense = torch.randn(4, 4, 5).to('cuda')
jagged_to_dense(jagged_values, jagged_offsets, jagged_max_lengths, padding_value)

Expected behavior
The expected behavior depends on the intended design:

If operation_function is meant to be a string representing an operation (like "add", "mul"), the code should internally map this string to the corresponding Triton function (e.g., tl.add, tl.mul) before calling it.

If operation_function is intended to be directly callable, the function object itself (e.g., tl.add) should be passed via tl.constexpr, not a string.

Possible solutions

                if operation_function == "add":
                    jagged_val = tensor_elementwise_add(operation_dense_val, jagged_val)
                else:
                    jagged_val = tensor_elementwise_mul(operation_dense_val, jagged_val)
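One thing to note about the branch above is that any value other than "add" falls through to the multiply path. A stricter variant of the same name-to-function mapping is sketched below in plain Python. It is only an illustration: `resolve_operation` and `_OPERATIONS` are hypothetical names, and `operator.add` / `operator.mul` stand in for the jitted helpers, so this is not the FBGEMM implementation.

```python
import operator

# Hypothetical stand-ins for the jitted helpers named in the snippet above.
_OPERATIONS = {
    "add": operator.add,
    "mul": operator.mul,
}


def resolve_operation(name):
    """Return the callable for an operation name, or None when no fusion is requested."""
    if name is None:
        return None
    try:
        return _OPERATIONS[name]
    except KeyError:
        # Fail loudly instead of silently treating unknown names as "mul".
        raise ValueError(f"unsupported operation_function: {name!r}") from None


fused = resolve_operation("add")
print(fused(2.0, 3.0))
```

Inside a kernel, the equivalent would presumably be an explicit second comparison against "mul" rather than a bare else, so that a misspelled name is rejected instead of being multiplied.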

Environment

FBGEMM: fbgemm-gpu version: 1.2.0+cu126

PyTorch version: 2.7

Triton version: 3.3.1

Python version: 3.11.13

Additional context
The error occurs because in Python, a string object cannot be called like a function. The code currently treats the string operation_function as if it were a callable function, leading to the TypeError. The solution involves ensuring that a callable object is used in the function call, either by converting the string to the intended function internally or by passing the function directly.
