
Conversation

@wrongtest-intellif
Contributor

A prototyping change to add ForNode::step

It tries to add minimal code to run the following naive test case.

import tvm
import numpy as np
from tvm.script import tir as T

@T.prim_func
def function(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32"), C: T.Buffer((1024,), "float32")):
    for i in range(0, 100, 3):
        C[i] = A[i] + B[i]
    
print(function)
lib = tvm.compile(function, target="c")
print(lib.mod.inspect_source())

lib2 = tvm.compile(function, target="llvm")

a = np.random.uniform(1, 100, [1024]).astype("float32")
b = np.random.uniform(1, 100, [1024]).astype("float32")
c = np.zeros([1024]).astype("float32")
lib(a, b, c)
print(c[:])
c[:] = 0
lib2(a, b, c)
print(c[:])

The aspects to check for a real roadmap may be:

  1. Roundtrip support for TIR tvmscript grammar
  2. Correctness of TIR lowering pipeline
    • For all transformations and analysis tools, either they must be adapted to non-consecutive loop iteration indices, or loop canonicalization is required.
    • Ensure the original ForNode::step is not dropped by mutations on ForNode.
  3. Correctness of TensorIR schedule and MetaSchedule
    • Since many primitives depend on affine bindings, loop canonicalization is required.
  4. CodeGen support
    • Check that mainstream targets can support the loop step.
  5. Compatibility issues
    • Try to argue that the change would not affect existing work, since ForNode is a central construct in TVM.
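The loop canonicalization mentioned in points 2 and 3 can be pictured with a small sketch (assuming the usual rewrite i = min + k * step with a positive step; `canonicalize_trips` is a hypothetical helper, not a TVM API):

```python
def canonicalize_trips(min_, extent, step):
    """Enumerate the values a stepped loop `for i in range(min_, min_ + extent, step)`
    visits, by iterating a unit-step loop over ceildiv(extent, step) trips."""
    trip_count = -(-extent // step)  # ceildiv(extent, step)
    return [min_ + k * step for k in range(trip_count)]

# The rewrite matches Python's native range semantics:
print(canonicalize_trips(0, 100, 3) == list(range(0, 100, 3)))
```

A pass could thus replace a stepped ForNode by a unit-step loop of extent ceildiv(extent, step) and substitute min + k * step for the loop variable in the body.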

@wrongtest-intellif wrongtest-intellif marked this pull request as draft November 5, 2025 06:56
@gemini-code-assist
Contributor

Summary of Changes

Hello @wrongtest-intellif, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a step attribute to the ForNode in TVM's Tensor IR (TIR), enabling the representation and generation of loops with non-unit iteration steps. This foundational change impacts various components, from the core IR definition and Python scripting interface to code generation backends and scheduling primitives, ensuring that loops with custom steps are correctly parsed, lowered, and compiled across TVM.

Highlights

  • New step attribute for ForNode: Introduced an optional step attribute to the ForNode in TVM's Tensor IR (TIR), allowing loops to iterate with non-unit steps (e.g., for i in range(0, 100, 3)).
  • Python API and Scripting Support: Updated the Python tvm.script.ir_builder.tir module to accept step arguments for serial, parallel, vectorized, and unroll loop types. The tvm.script.parser.tir now handles Python's range sugar with a step argument, mapping it to the new TIR step attribute.
  • Code Generation Updates: Modified various code generation backends (LLVM, C, WebGPU, SPIRV) to correctly interpret and generate code for loops that include a custom step value. This ensures that the new loop semantics are respected during compilation.
  • TIR Schedule and Transform Compatibility: Adjusted numerous TIR scheduling primitives and transformation passes to either propagate the new step attribute or utilize the newly introduced For::ForSimple constructor, which implicitly assumes a trivial (unit) step. This maintains compatibility and correctness across the TIR pipeline.
  • New Utility and Tests: Added a HasTrivialStep() method to ForNode for easier checking of unit steps. New test cases were introduced in test_tir_nodes.py and test_tvmscript_parser_tir.py to validate the functionality of loops with custom steps.
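The HasTrivialStep() utility described above can be illustrated with a minimal Python stand-in (the class below is purely illustrative, not the real C++ ForNode; only the optional `step` field follows the PR):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ForNodeSketch:
    min: int
    extent: int
    step: Optional[int] = None  # None models the absent optional step

    def has_trivial_step(self) -> bool:
        # A missing step and an explicit step of 1 both mean unit stride.
        return self.step is None or self.step == 1

print(ForNodeSketch(0, 1024).has_trivial_step())       # unit stride
print(ForNodeSketch(0, 100, step=3).has_trivial_step())  # non-unit stride
```

Keeping the step optional lets every existing construction site remain valid while passes branch on the trivial-step fast path.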


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR introduces a step attribute to ForNode, which is a significant and welcome enhancement to TIR. The changes are extensive, touching many parts of the codebase from IR definition, parsers, to various codegen backends and transformation passes. The use of For::ForSimple and CopyOnWrite patterns to manage the new field is well-executed and helps maintain backward compatibility and code clarity. The addition of tests for parsing and round-tripping is also appreciated.

I've found a few issues, mainly related to inconsistencies in codegen backends and a potential bug in the SPIR-V codegen. Addressing these will make this foundational PR more robust.

Comment on lines 681 to 687
if (op->HasTrivialStep()) {
step = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1)
: builder_->UIntImm(loop_var.stype, 1);
} else {
step = MakeValue(tir::cast(op->extent->dtype, *op->step));
}

critical

There seems to be a bug in the step calculation. The code builder_->IntImm(loop_var.stype, 1) uses loop_var, but the local loop_var of spirv::PhiValue type is defined later in the function. This means it's likely using the member variable this->loop_var from an outer loop, which is incorrect. The type of the current loop variable op->loop_var should be used instead.

Suggested change
- if (op->HasTrivialStep()) {
-   step = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1)
-                                        : builder_->UIntImm(loop_var.stype, 1);
- } else {
-   step = MakeValue(tir::cast(op->extent->dtype, *op->step));
- }
+ if (op->HasTrivialStep()) {
+   spirv::SType stype = builder_->GetSType(op->loop_var.dtype());
+   step = op->loop_var.dtype().is_int() ? builder_->IntImm(stype, 1)
+                                        : builder_->UIntImm(stype, 1);
+ } else {
+   step = MakeValue(tir::cast(op->extent->dtype, *op->step));
+ }

Comment on lines 1122 to 1139
  void CodeGenC::VisitStmt_(const ForNode* op) {
    std::string extent = PrintExpr(op->extent);
+   std::string step = op->step.has_value() ? PrintExpr(*op->step) : "";
    PrintIndent();
    std::string vid = AllocVarID(op->loop_var.get());
    ICHECK(is_zero(op->min));
    stream << "for (";
    PrintType(op->loop_var.dtype(), stream);
-   stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n";
+   stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ";
+   if (step.empty()) {
+     stream << "++" << vid;
+   } else {
+     stream << vid << " += " << step;
+   }
+   stream << ") {\n";
    int for_scope = BeginScope();
    PrintStmt(op->body);
    this->EndScope(for_scope);

high

The C codegen for ForNode still assumes that the loop starts from 0, as indicated by the ICHECK(is_zero(op->min)) and the hardcoded vid = 0. Other backends like LLVM and SPIR-V have been updated to support non-zero min. For consistency and to fully support generalized loops, this backend should also be updated to handle non-zero min. This would involve removing the ICHECK and using op->min for loop initialization and in the loop condition.

  std::string min_str = PrintExpr(op->min);
  std::string extent_str = PrintExpr(op->extent);
  std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
  PrintIndent();
  std::string vid = AllocVarID(op->loop_var.get());
  stream << "for (";
  PrintType(op->loop_var.dtype(), stream);
  stream << ' ' << vid << " = " << min_str << "; " << vid << " < (" << min_str << " + " << extent_str << "); ";
  if (step_str.empty()) {
    stream << "++" << vid;
  } else {
    stream << vid << " += " << step_str;
  }
  stream << ") {\n";
  int for_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(for_scope);
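To see what the suggested header logic would emit, here is a few-line Python sketch (a hypothetical `emit_for_header` helper, not a TVM function; the loop variable type is hardcoded to int32_t for brevity, whereas the real codegen prints the variable's dtype):

```python
def emit_for_header(vid, min_str, extent_str, step_str=""):
    """Build a C for-loop header mirroring the suggested CodeGenC logic:
    non-zero min in both init and bound, and `+=` for a non-trivial step."""
    update = f"++{vid}" if not step_str else f"{vid} += {step_str}"
    return (f"for (int32_t {vid} = {min_str}; "
            f"{vid} < ({min_str} + {extent_str}); {update}) {{")

print(emit_for_header("i", "0", "1024"))      # trivial step
print(emit_for_header("i", "4", "100", "3"))  # min = 4, step = 3
```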

Comment on lines 669 to 686
  void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
    std::string extent = PrintExpr(op->extent);
+   std::string step = op->step.has_value() ? PrintExpr(*op->step) : "";
    std::string vid = AllocVarID(op->loop_var.get());
    ICHECK(is_zero(op->min));
    PrintIndent();
    stream << "for (var " << vid << " : ";
    PrintType(op->loop_var.dtype(), stream);
-   stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
+   stream << " = 0; " << vid << " < " << extent << "; " << vid;
+   if (step.empty()) {
+     stream << "++";
+   } else {
+     stream << " += " << step;
+   }
+   stream << ") {\n";
    int for_scope = BeginScope();
    PrintStmt(op->body);
    this->EndScope(for_scope);

high

Similar to the C codegen, the WebGPU codegen for ForNode still assumes min is zero. This should be updated to support non-zero min for consistency with other backends like LLVM and SPIR-V.

  std::string min_str = PrintExpr(op->min);
  std::string extent_str = PrintExpr(op->extent);
  std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
  std::string vid = AllocVarID(op->loop_var.get());
  PrintIndent();
  stream << "for (var " << vid << " : ";
  PrintType(op->loop_var.dtype(), stream);
  stream << " = " << min_str << "; " << vid << " < (" << min_str << " + " << extent_str << "); " << vid;
  if (step_str.empty()) {
    stream << "++";
  } else {
    stream << " += " << step_str;
  }
  stream << ") {\n";
  int for_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(for_scope);

Comment on lines +333 to +334
new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard);
return new_loop;

medium

Using CopyOnWrite here is a great way to preserve the new step field and any other future fields on the ForNode. This pattern has been applied consistently across the PR, which is excellent.

Comment on lines 728 to 729
} else if (const auto scan_op = op.as<te::ScanOp>()) {


medium

This if branch for te::ScanOp is empty. If this is a placeholder for future work, it would be better to add a // TODO comment explaining the intent or remove it for now to avoid dead code.

@tqchen tqchen marked this pull request as ready for review November 5, 2025 17:33
@Hzfengsy
Member

Hzfengsy commented Nov 7, 2025

cc @LeiWang1999
