Skip to content

Commit 7d5e2f6

Browse files
authored
Refactor Hadamard matrices loading (#3344)
* Refactor Hadamard matrices loading * lint * Generate Hadamard pickle during setup * lint
1 parent 8d546c8 commit 7d5e2f6

File tree

4 files changed

+145
-99297
lines changed

4 files changed

+145
-99297
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
122122
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
123123
torch/version.py
124124
minifier_launcher.py
125+
torchao/prototype/spinquant/_hadamard_matrices.pkl
125126
# Root level file used in CI to specify certain env configs.
126127
# E.g., see .circleci/config.yaml
127128
env

setup.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55

66
import copy
77
import glob
8+
import json
89
import os
10+
import pickle
911
import subprocess
1012
import sys
1113
import time
1214
from datetime import datetime
15+
from pathlib import Path
1316
from typing import List, Optional
1417

1518
from setuptools import Extension, find_packages, setup
19+
from setuptools.command.build_py import build_py as build_py_orig
1620

1721
current_date = datetime.now().strftime("%Y%m%d")
1822

@@ -40,6 +44,41 @@ def read_version(file_path="version.txt"):
4044
return file.readline().strip()
4145

4246

47+
SPINQUANT_REL_PATH = Path("torchao") / "prototype" / "spinquant"
48+
HADAMARD_JSON = "_hadamard_matrices.json"
49+
HADAMARD_PKL = "_hadamard_matrices.pkl"
50+
51+
52+
def ensure_hadamard_pickle(root_dir: Optional[Path] = None, *, quiet: bool = True):
53+
"""
54+
Guarantee that the Hadamard pickle exists (and is newer than the JSON source)
55+
so setup.py packaging has an observable, reproducible rule.
56+
"""
57+
58+
base_dir = (
59+
Path(root_dir) if root_dir is not None else Path(__file__).parent.resolve()
60+
)
61+
spinquant_dir = base_dir / SPINQUANT_REL_PATH
62+
json_path = spinquant_dir / HADAMARD_JSON
63+
if not json_path.exists():
64+
return
65+
66+
pkl_path = spinquant_dir / HADAMARD_PKL
67+
if pkl_path.exists() and pkl_path.stat().st_mtime >= json_path.stat().st_mtime:
68+
return
69+
70+
with json_path.open("r") as source:
71+
raw_matrices = json.load(source)
72+
73+
pkl_path.parent.mkdir(parents=True, exist_ok=True)
74+
with pkl_path.open("wb") as sink:
75+
pickle.dump(raw_matrices, sink, protocol=pickle.HIGHEST_PROTOCOL)
76+
77+
if not quiet:
78+
rel_path = pkl_path.relative_to(base_dir)
79+
print(f"[setup.py] regenerated {rel_path} from JSON source")
80+
81+
4382
# Use Git commit ID if VERSION_SUFFIX is not set
4483
version_suffix = os.getenv("VERSION_SUFFIX")
4584
if version_suffix is None:
@@ -763,6 +802,12 @@ def bool_to_on_off(value):
763802
return ext_modules
764803

765804

805+
class TorchAOBuildPy(build_py_orig):
806+
def run(self):
807+
ensure_hadamard_pickle()
808+
super().run()
809+
810+
766811
# Only check submodules if we're going to build C++ extensions
767812
if use_cpp != "0":
768813
check_submodules()
@@ -774,13 +819,17 @@ def bool_to_on_off(value):
774819
include_package_data=True,
775820
package_data={
776821
"torchao.kernel.configs": ["*.pkl"],
822+
"torchao.prototype.spinquant": [
823+
"_hadamard_matrices.json",
824+
"_hadamard_matrices.pkl",
825+
],
777826
},
778827
ext_modules=get_extensions(),
779828
extras_require={"dev": read_requirements("dev-requirements.txt")},
780829
description="Package for applying ao techniques to GPU models",
781830
long_description=open("README.md", encoding="utf-8").read(),
782831
long_description_content_type="text/markdown",
783832
url="https://github.com/pytorch/ao",
784-
cmdclass={"build_ext": TorchAOBuildExt},
833+
cmdclass={"build_ext": TorchAOBuildExt, "build_py": TorchAOBuildPy},
785834
options={"bdist_wheel": {"py_limited_api": "cp310"}},
786835
)

torchao/prototype/spinquant/_hadamard_matrices.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)