-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[TIR] Add step attribute to ForNode (Initial codes) #18421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[TIR] Add step attribute to ForNode (Initial codes) #18421
Conversation
Summary of ChangesHello @wrongtest-intellif, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This PR introduces a step attribute to ForNode, which is a significant and welcome enhancement to TIR. The changes are extensive, touching many parts of the codebase from IR definition, parsers, to various codegen backends and transformation passes. The use of For::ForSimple and CopyOnWrite patterns to manage the new field is well-executed and helps maintain backward compatibility and code clarity. The addition of tests for parsing and round-tripping is also appreciated.
I've found a few issues, mainly related to inconsistencies in codegen backends and a potential bug in the SPIR-V codegen. Addressing these will make this foundational PR more robust.
| if (op->HasTrivialStep()) { | ||
| step = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) | ||
| : builder_->UIntImm(loop_var.stype, 1); | ||
| } else { | ||
| step = MakeValue(tir::cast(op->extent->dtype, *op->step)); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There seems to be a bug in the step calculation. The code builder_->IntImm(loop_var.stype, 1) uses loop_var, but the local loop_var of spirv::PhiValue type is defined later in the function. This means it's likely using the member variable this->loop_var from an outer loop, which is incorrect. The type of the current loop variable op->loop_var should be used instead.
| if (op->HasTrivialStep()) { | |
| step = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) | |
| : builder_->UIntImm(loop_var.stype, 1); | |
| } else { | |
| step = MakeValue(tir::cast(op->extent->dtype, *op->step)); | |
| } | |
| if (op->HasTrivialStep()) { | |
| spirv::SType stype = builder_->GetSType(op->loop_var.dtype()); | |
| step = op->loop_var.dtype().is_int() ? builder_->IntImm(stype, 1) | |
| : builder_->UIntImm(stype, 1); | |
| } else { | |
| step = MakeValue(tir::cast(op->extent->dtype, *op->step)); | |
| } |
| void CodeGenC::VisitStmt_(const ForNode* op) { | ||
| std::string extent = PrintExpr(op->extent); | ||
| std::string step = op->step.has_value()? PrintExpr(*op->step) : ""; | ||
| PrintIndent(); | ||
| std::string vid = AllocVarID(op->loop_var.get()); | ||
| ICHECK(is_zero(op->min)); | ||
| stream << "for ("; | ||
| PrintType(op->loop_var.dtype(), stream); | ||
| stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n"; | ||
| stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; "; | ||
| if (step.empty()) { | ||
| stream << "++" << vid; | ||
| } else { | ||
| stream << vid << " += " << step; | ||
| } | ||
| stream << ") {\n"; | ||
| int for_scope = BeginScope(); | ||
| PrintStmt(op->body); | ||
| this->EndScope(for_scope); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The C codegen for ForNode still assumes that the loop starts from 0, as indicated by the ICHECK(is_zero(op->min)) and the hardcoded vid = 0. Other backends like LLVM and SPIR-V have been updated to support non-zero min. For consistency and to fully support generalized loops, this backend should also be updated to handle non-zero min. This would involve removing the ICHECK and using op->min for loop initialization and in the loop condition.
std::string min_str = PrintExpr(op->min);
std::string extent_str = PrintExpr(op->extent);
std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = " << min_str << "; " << vid << " < (" << min_str << " + " << extent_str << "); ";
if (step_str.empty()) {
stream << "++" << vid;
} else {
stream << vid << " += " << step_str;
}
stream << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);| void CodeGenWebGPU::VisitStmt_(const ForNode* op) { | ||
| std::string extent = PrintExpr(op->extent); | ||
| std::string step = op->step.has_value()? PrintExpr(*op->step) : ""; | ||
| std::string vid = AllocVarID(op->loop_var.get()); | ||
| ICHECK(is_zero(op->min)); | ||
| PrintIndent(); | ||
| stream << "for (var " << vid << " : "; | ||
| PrintType(op->loop_var.dtype(), stream); | ||
| stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n"; | ||
| stream << " = 0; " << vid << " < " << extent << "; " << vid; | ||
| if (step.empty()) { | ||
| stream << "++"; | ||
| } else { | ||
| stream << " += " << step; | ||
| } | ||
| stream << ") {\n"; | ||
| int for_scope = BeginScope(); | ||
| PrintStmt(op->body); | ||
| this->EndScope(for_scope); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the C codegen, the WebGPU codegen for ForNode still assumes min is zero. This should be updated to support non-zero min for consistency with other backends like LLVM and SPIR-V.
std::string min_str = PrintExpr(op->min);
std::string extent_str = PrintExpr(op->extent);
std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
std::string vid = AllocVarID(op->loop_var.get());
PrintIndent();
stream << "for (var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
stream << " = " << min_str << "; " << vid << " < (" << min_str << " + " << extent_str << "); " << vid;
if (step_str.empty()) {
stream << "++";
} else {
stream << " += " << step_str;
}
stream << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);| new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard); | ||
| return new_loop; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/te/operation/create_primfunc.cc
Outdated
| } else if (const auto scan_op = op.as<te::ScanOp>()) { | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
cc @LeiWang1999 |
efe1a4e to
fd57ac5
Compare
A prototyping change to add
ForNode::stepIt try add minimal codes to run following naive test cases.
The aspects to check for a real roadmap may be
ForNode::stepis not dropped by mutations onForNode.ForNodeis an important construction in TVM.