Skip to content

Commit 70a961f

Browse files
committed
remove dtype and replace with precision
1 parent 1b5bd06 commit 70a961f

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

torchao/quantization/linear_quant_modules.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,6 @@ def __init__(
108108
self.precision = precision
109109
self.scales_precision = scales_precision
110110

111-
if dtype is not None:
112-
raise ValueError("Please specify 'precision' instead of 'dtype'")
113-
114111
assert out_features % 8 == 0, "require out_features % 8 == 0"
115112
assert in_features % (inner_k_tiles * 16) == 0, (
116113
"require in_features % (innerKTiles * 16) == 0"
@@ -123,7 +120,7 @@ def __init__(
123120
out_features,
124121
in_features // 2,
125122
),
126-
dtype=torch.uint8,
123+
precision=torch.uint8,
127124
device=device,
128125
),
129126
)
@@ -137,11 +134,11 @@ def __init__(
137134
32,
138135
inner_k_tiles // 2,
139136
),
140-
dtype=torch.int32,
137+
precision=torch.int32,
141138
device=device,
142139
),
143140
)
144-
self.dtype = dtype
141+
self.precision = precision
145142
self.register_buffer(
146143
"scales_and_zeros",
147144
torch.zeros(

0 commit comments

Comments
 (0)