@@ -152,6 +152,23 @@ def forward(self, input_tensor):
 
 NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}
 
+class QATEmbeddingTransformation(nn.Module):
+    def __init__(self, embedded_input_size, hidden_size):
+        super().__init__()
+
+        # Behaves like a normal Linear module unless a SparseML QuantizationModifier
+        # is initialized.
+        # When initialized, does not quantize inputs.
+        # Only weights are quantized (inputs come quantized from embeddings).
+        self.linear = nn.Linear(embedded_input_size, hidden_size)
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 0,
+            "num_outputs": 1,
+        }
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
 
 class MobileBertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""
@@ -168,7 +185,7 @@ def __init__(self, config):
 
         embed_dim_multiplier = 3 if self.trigram_input else 1
         embedded_input_size = self.embedding_size * embed_dim_multiplier
-        self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
+        self.embedding_transformation = QATEmbeddingTransformation(embedded_input_size, config.hidden_size)
 
         self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
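The wrap_qat and qat_wrapper_kwargs attributes are meant to be picked up by SparseML's QuantizationModifier so that this module is wrapped with QAT observers instead of being quantized like a plain Linear. Below is a minimal, illustrative sketch of what a wrapper honoring num_inputs=0 / num_outputs=1 could look like; IllustrativeQATWrapper and its internals are assumptions for illustration only, not SparseML's actual QATWrapper, and it uses torch.ao.quantization.FakeQuantize in place of SparseML's own observers.

# Illustrative only: a wrapper that fake-quantizes zero inputs and one output,
# leaving weight quantization to whatever handles the wrapped Linear.
import torch
import torch.nn as nn
from torch.ao.quantization import FakeQuantize


class IllustrativeQATWrapper(nn.Module):
    """Hypothetical stand-in for a QAT wrapper driven by qat_wrapper_kwargs."""

    def __init__(self, module: nn.Module, num_inputs: int, num_outputs: int):
        super().__init__()
        self.module = module
        # num_inputs=0 -> no fake-quant on the inputs (in this model they already
        # arrive quantized from the embedding lookup).
        self.input_fake_quants = nn.ModuleList([FakeQuantize() for _ in range(num_inputs)])
        # num_outputs=1 -> a single fake-quant observer on the output.
        self.output_fake_quants = nn.ModuleList([FakeQuantize() for _ in range(num_outputs)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for fq in self.input_fake_quants:
            x = fq(x)
        out = self.module(x)
        for fq in self.output_fake_quants:
            out = fq(out)
        return out


# Example usage with the module added in this commit (sizes are MobileBERT-like
# defaults, chosen here only for illustration):
embedding_transformation = QATEmbeddingTransformation(embedded_input_size=128 * 3, hidden_size=512)
wrapped = IllustrativeQATWrapper(
    embedding_transformation.linear,
    **embedding_transformation.qat_wrapper_kwargs,
)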