File tree Expand file tree Collapse file tree 1 file changed +3
-6
lines changed
Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -108,9 +108,6 @@ def __init__(
108108 self .precision = precision
109109 self .scales_precision = scales_precision
110110
111- if dtype is not None :
112- raise ValueError ("Please specify 'precision' instead of 'dtype'" )
113-
114111 assert out_features % 8 == 0 , "require out_features % 8 == 0"
115112 assert in_features % (inner_k_tiles * 16 ) == 0 , (
116113 "require in_features % (innerKTiles * 16) == 0"
@@ -123,7 +120,7 @@ def __init__(
123120 out_features ,
124121 in_features // 2 ,
125122 ),
126- dtype = torch .uint8 ,
123+ precision = torch .uint8 ,
127124 device = device ,
128125 ),
129126 )
@@ -137,11 +134,11 @@ def __init__(
137134 32 ,
138135 inner_k_tiles // 2 ,
139136 ),
140- dtype = torch .int32 ,
137+ precision = torch .int32 ,
141138 device = device ,
142139 ),
143140 )
144- self .dtype = dtype
141+ self .precision = precision
145142 self .register_buffer (
146143 "scales_and_zeros" ,
147144 torch .zeros (
You can’t perform that action at this time.
0 commit comments