Skip to content
This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit 38ae788

Browse files
authored
OPT with quantizable MatMuls (#85)
1 parent 2aca427 commit 38ae788

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed

src/transformers/models/opt/modeling_opt.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,46 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int
120120
return super().forward(positions + self.offset)
121121

122122

123+
class BMMLeftInput_QK(nn.Identity):
124+
...
125+
126+
127+
class BMMRightInput_QK(nn.Identity):
128+
...
129+
130+
131+
class BMMOutput_QK(nn.Identity):
132+
...
133+
134+
135+
class BMMLeftInput_PV(nn.Identity):
136+
...
137+
138+
139+
class BMMRightInput_PV(nn.Identity):
140+
...
141+
142+
143+
class BMMOutput_PV(nn.Identity):
144+
...
145+
146+
147+
class QuantizableBatchMatMul(nn.Module):
148+
"""
149+
Wrapper around torch.bmm with distinct inputs/output class
150+
instances that could be quantized through SparseML recipe
151+
"""
152+
153+
def __init__(self, left_input_cls, right_input_cls, output_cls):
154+
super().__init__()
155+
self.left_input = left_input_cls()
156+
self.right_input = right_input_cls()
157+
self.output = output_cls()
158+
159+
def forward(self, a: torch.Tensor, b: torch.Tensor):
160+
return self.output(torch.bmm(self.left_input(a), self.right_input(b)))
161+
162+
123163
class OPTAttention(nn.Module):
124164
"""Multi-headed attention from 'Attention Is All You Need' paper"""
125165

@@ -150,6 +190,9 @@ def __init__(
150190
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
151191
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
152192

193+
self.attn_weights_bmm = QuantizableBatchMatMul(BMMLeftInput_QK, BMMRightInput_QK, BMMOutput_QK)
194+
self.attn_output_bmm = QuantizableBatchMatMul(BMMLeftInput_PV, BMMRightInput_PV, BMMOutput_PV)
195+
153196
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
154197
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
155198

@@ -208,7 +251,7 @@ def forward(
208251
value_states = value_states.view(*proj_shape)
209252

210253
src_len = key_states.size(1)
211-
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
254+
attn_weights = self.attn_weights_bmm(query_states, key_states.transpose(1, 2))
212255

213256
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
214257
raise ValueError(
@@ -254,7 +297,7 @@ def forward(
254297

255298
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
256299

257-
attn_output = torch.bmm(attn_probs, value_states)
300+
attn_output = self.attn_output_bmm(attn_probs, value_states)
258301

259302
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
260303
raise ValueError(

0 commit comments

Comments
 (0)