This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit cc9aacf

Mobilebert QAT (#55)
* Remove duplicate quantization of vocabulary.
1 parent 78694d9 commit cc9aacf

File tree

1 file changed: +18 −1 lines changed


src/transformers/models/mobilebert/modeling_mobilebert.py

Lines changed: 18 additions & 1 deletion
@@ -152,6 +152,23 @@ def forward(self, input_tensor):
 
 NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}
 
+class QATEmbeddingTransformation(nn.Module):
+    def __init__(self, embedded_input_size, hidden_size):
+        super().__init__()
+
+        # Behaves like normal Linear module unless a SparseML QuantizationModifier
+        # is initialized.
+        # When initialized, does not quantize inputs.
+        # Only weights are quantized (inputs come quantized from embeddings)
+        self.linear = nn.Linear(embedded_input_size, hidden_size)
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 0,
+            "num_outputs": 1,
+        }
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
 
 class MobileBertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""
@@ -168,7 +185,7 @@ def __init__(self, config):
 
         embed_dim_multiplier = 3 if self.trigram_input else 1
         embedded_input_size = self.embedding_size * embed_dim_multiplier
-        self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
+        self.embedding_transformation = QATEmbeddingTransformation(embedded_input_size, config.hidden_size)
 
         self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
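For context, the `wrap_qat` and `qat_wrapper_kwargs` attributes set in the diff are the hook the in-code comment refers to: when a SparseML QuantizationModifier is active, modules flagged this way get wrapped so that only their weights (and, per `num_outputs`, their outputs) are fake-quantized, while inputs that already arrive quantized from the embedding lookup are left alone. The sketch below is a hypothetical plain-PyTorch illustration of such a wrapper, not SparseML's actual implementation; the names `QATWrapper` and `wrap_qat_modules` are invented for the example, and only `torch.quantization.default_fake_quant` is a standard PyTorch call.

```python
import torch
import torch.nn as nn
from torch.quantization import default_fake_quant


class QATWrapper(nn.Module):
    """Hypothetical QAT wrapper: attaches fake-quant observers around a module.

    With num_inputs=0, as QATEmbeddingTransformation requests, the input is
    passed through untouched (it already arrives quantized from the embedding
    lookup) and only an output observer is attached. Weight fake-quantization
    is not handled here; in a real QAT flow it would come from the wrapped
    module's qconfig (e.g. via torch.quantization.prepare_qat).
    """

    def __init__(self, module: nn.Module, num_inputs: int = 1, num_outputs: int = 1):
        super().__init__()
        self.module = module
        self.input_quant = nn.ModuleList([default_fake_quant() for _ in range(num_inputs)])
        self.output_quant = nn.ModuleList([default_fake_quant() for _ in range(num_outputs)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fake-quantize the input only if an input observer was requested.
        if len(self.input_quant) > 0:
            x = self.input_quant[0](x)
        out = self.module(x)
        # Fake-quantize the output if an output observer was requested.
        if len(self.output_quant) > 0:
            out = self.output_quant[0](out)
        return out


def wrap_qat_modules(model: nn.Module) -> None:
    """Recursively replace children flagged with wrap_qat=True by a QATWrapper."""
    for name, child in model.named_children():
        if getattr(child, "wrap_qat", False):
            setattr(model, name, QATWrapper(child, **child.qat_wrapper_kwargs))
        else:
            wrap_qat_modules(child)
```

Under this reading, setting `num_inputs` to 0 avoids double-quantizing activations that already passed through the quantized embedding lookup, which matches the commit's stated goal of removing duplicate quantization of the vocabulary.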

0 commit comments
