55
66import copy
77import glob
8+ import json
89import os
10+ import pickle
911import subprocess
1012import sys
1113import time
1214from datetime import datetime
15+ from pathlib import Path
1316from typing import List , Optional
1417
1518from setuptools import Extension , find_packages , setup
19+ from setuptools .command .build_py import build_py as build_py_orig
1620
1721current_date = datetime .now ().strftime ("%Y%m%d" )
1822
@@ -40,6 +44,41 @@ def read_version(file_path="version.txt"):
4044 return file .readline ().strip ()
4145
4246
47+ SPINQUANT_REL_PATH = Path ("torchao" ) / "prototype" / "spinquant"
48+ HADAMARD_JSON = "_hadamard_matrices.json"
49+ HADAMARD_PKL = "_hadamard_matrices.pkl"
50+
51+
52+ def ensure_hadamard_pickle (root_dir : Optional [Path ] = None , * , quiet : bool = True ):
53+ """
54+ Guarantee that the Hadamard pickle exists (and is newer than the JSON source)
55+ so setup.py packaging has an observable, reproducible rule.
56+ """
57+
58+ base_dir = (
59+ Path (root_dir ) if root_dir is not None else Path (__file__ ).parent .resolve ()
60+ )
61+ spinquant_dir = base_dir / SPINQUANT_REL_PATH
62+ json_path = spinquant_dir / HADAMARD_JSON
63+ if not json_path .exists ():
64+ return
65+
66+ pkl_path = spinquant_dir / HADAMARD_PKL
67+ if pkl_path .exists () and pkl_path .stat ().st_mtime >= json_path .stat ().st_mtime :
68+ return
69+
70+ with json_path .open ("r" ) as source :
71+ raw_matrices = json .load (source )
72+
73+ pkl_path .parent .mkdir (parents = True , exist_ok = True )
74+ with pkl_path .open ("wb" ) as sink :
75+ pickle .dump (raw_matrices , sink , protocol = pickle .HIGHEST_PROTOCOL )
76+
77+ if not quiet :
78+ rel_path = pkl_path .relative_to (base_dir )
79+ print (f"[setup.py] regenerated { rel_path } from JSON source" )
80+
81+
4382# Use Git commit ID if VERSION_SUFFIX is not set
4483version_suffix = os .getenv ("VERSION_SUFFIX" )
4584if version_suffix is None :
@@ -763,6 +802,12 @@ def bool_to_on_off(value):
763802 return ext_modules
764803
765804
805+ class TorchAOBuildPy (build_py_orig ):
806+ def run (self ):
807+ ensure_hadamard_pickle ()
808+ super ().run ()
809+
810+
766811# Only check submodules if we're going to build C++ extensions
767812if use_cpp != "0" :
768813 check_submodules ()
@@ -774,13 +819,17 @@ def bool_to_on_off(value):
774819 include_package_data = True ,
775820 package_data = {
776821 "torchao.kernel.configs" : ["*.pkl" ],
822+ "torchao.prototype.spinquant" : [
823+ "_hadamard_matrices.json" ,
824+ "_hadamard_matrices.pkl" ,
825+ ],
777826 },
778827 ext_modules = get_extensions (),
779828 extras_require = {"dev" : read_requirements ("dev-requirements.txt" )},
780829 description = "Package for applying ao techniques to GPU models" ,
781830 long_description = open ("README.md" , encoding = "utf-8" ).read (),
782831 long_description_content_type = "text/markdown" ,
783832 url = "https://github.com/pytorch/ao" ,
784- cmdclass = {"build_ext" : TorchAOBuildExt },
833+ cmdclass = {"build_ext" : TorchAOBuildExt , "build_py" : TorchAOBuildPy },
785834 options = {"bdist_wheel" : {"py_limited_api" : "cp310" }},
786835)
0 commit comments