
Commit 715e3d0

replace dtype with precision in Int8DynActInt4WeightLinear
1 parent 70a961f

1 file changed: +3, -6 lines changed

torchao/quantization/linear_quant_modules.py

Lines changed: 3 additions & 6 deletions
@@ -427,26 +427,23 @@ def __init__(
         # that his module represents.
         self.precision = precision
 
-        if dtype is not None:
-            raise ValueError("Please specify 'precision' instead of 'dtype'")
-
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",
-            torch.zeros((out_features, in_features), dtype=torch.int8),
+            torch.zeros((out_features, in_features), precision=torch.int8),
         )
         self.register_buffer(
             "scales",
             torch.zeros(
                 (out_features, in_features // groupsize),
-                dtype=scales_precision,
+                precision=scales_precision,
             ),
         )
         self.register_buffer(
             "zeros",
             torch.zeros(
                 (out_features, in_features // groupsize),
-                dtype=scales_precision,
+                precision=scales_precision,
             ),
         )
 
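
For orientation, the buffers registered in this hunk encode group-wise quantization bookkeeping: one unpacked int8 value per weight element, plus one scale and one zero-point per output row for every group of groupsize input features. Below is a minimal plain-PyTorch sketch of the resulting shapes, using illustrative sizes that are not taken from the commit:

    import torch

    # Illustrative sizes (not from the commit): 8 output features,
    # 1024 input features, quantization groups of 256 input features.
    out_features, in_features, groupsize = 8, 1024, 256

    # Unpacked int8 weights: one entry per (output, input) pair.
    weight = torch.zeros((out_features, in_features), dtype=torch.int8)

    # One scale and one zero-point per output row per input group.
    scales = torch.zeros((out_features, in_features // groupsize))
    zeros = torch.zeros((out_features, in_features // groupsize))

    print(weight.shape)   # torch.Size([8, 1024])
    print(scales.shape)   # torch.Size([8, 4])
    print(zeros.shape)    # torch.Size([8, 4])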

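A hypothetical construction call is sketched below, assuming the parameter names visible in this hunk (in_features, out_features, groupsize, precision, scales_precision) and an import path derived from the changed file's location; the full constructor signature is not part of the diff, so treat this as an illustration rather than the definitive API:

    import torch
    # Import path assumed from the file shown in this commit.
    from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear

    # `precision` is the keyword this commit standardizes on in place of `dtype`;
    # the sizes here are illustrative only.
    linear = Int8DynActInt4WeightLinear(
        in_features=1024,
        out_features=1024,
        groupsize=256,
        precision=torch.float32,
        scales_precision=torch.float32,
    )

    print(linear.weight.dtype)   # int8 buffer registered in __init__
    print(linear.scales.shape)   # (out_features, in_features // groupsize)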