Skip to content

Commit a243955

Browse files
authored
Fixed editable install to depend on CuTeDSL/requirements.txt (#2768)
To guarantee wheel version alignment of the source code.
1 parent bd96096 commit a243955

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

python/CuTeDSL/prep_editable_install.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,27 @@ class CutlassDSLSetupError(Exception):
4040
pass
4141

4242

43+
def get_package_spec(requirements_path: Optional[Path] = None) -> str:
44+
"""
45+
Return the pip requirement spec for nvidia-cutlass-dsl from requirements.txt.
46+
47+
If anything goes wrong (file not found, parse failure, line missing),
48+
return PACKAGE_NAME as a safe default.
49+
"""
50+
try:
51+
req_path = requirements_path or Path(__file__).with_name("requirements.txt")
52+
with open(req_path, "r", encoding="utf-8") as f:
53+
for raw_line in f:
54+
line = raw_line.strip()
55+
if not line or line.startswith("#"):
56+
continue
57+
if line.lower().startswith(PACKAGE_NAME):
58+
return line.split("#", 1)[0].strip()
59+
except Exception:
60+
pass
61+
return PACKAGE_NAME
62+
63+
4364
def download_wheel(temp_dir: Path) -> Path:
4465
"""
4566
Download the nvidia-cutlass-dsl wheel to a temporary directory.
@@ -53,7 +74,10 @@ def download_wheel(temp_dir: Path) -> Path:
5374
Raises:
5475
CutlassDSLSetupError: If download fails or wheel not found
5576
"""
56-
logger.info(f"Downloading {PACKAGE_NAME} wheel to {temp_dir}")
77+
# Resolve package spec from requirements, or fall back to PACKAGE_NAME
78+
package_spec = get_package_spec()
79+
80+
logger.info(f"Downloading {package_spec} wheel to {temp_dir}")
5781

5882
try:
5983
subprocess.check_call(
@@ -63,7 +87,7 @@ def download_wheel(temp_dir: Path) -> Path:
6387
"pip",
6488
"download",
6589
"--no-deps",
66-
PACKAGE_NAME,
90+
package_spec,
6791
"--dest",
6892
str(temp_dir),
6993
],
@@ -79,7 +103,7 @@ def download_wheel(temp_dir: Path) -> Path:
79103
raise CutlassDSLSetupError(error_msg)
80104

81105
# Find the downloaded wheel file
82-
wheel_pattern = f"{PACKAGE_NAME.replace('-', '_')}-*.whl"
106+
wheel_pattern = f"*.whl"
83107
wheel_files = list(temp_dir.glob(wheel_pattern))
84108
if not wheel_files:
85109
raise CutlassDSLSetupError(
@@ -108,7 +132,7 @@ def extract_version_from_wheel(wheel_path: Path) -> str:
108132
# Construct version regex from package name
109133
# Wheel filename format: {package_name_with_underscores}-{version}-{python}-{abi}-{platform}.whl
110134
package_pattern = PACKAGE_NAME.replace("-", "_")
111-
version_regex = rf"{re.escape(package_pattern)}-([^-]+)-"
135+
version_regex = rf"{re.escape(package_pattern)}-([^-]+)"
112136
version_match = re.match(version_regex, wheel_filename)
113137

114138
if version_match:
@@ -132,10 +156,7 @@ def extract_version_from_wheel(wheel_path: Path) -> str:
132156

133157
return dev_version
134158
else:
135-
raise CutlassDSLSetupError(
136-
f"Could not parse version from wheel filename: {wheel_filename}"
137-
)
138-
159+
return "9.9.9.dev0"
139160

140161
def extract_wheel_contents(wheel_path: Path, extract_dir: Path) -> None:
141162
"""

0 commit comments

Comments
 (0)