@@ -120,6 +120,46 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int
         return super().forward(positions + self.offset)
 
 
+class BMMLeftInput_QK(nn.Identity):
+    ...
+
+
+class BMMRightInput_QK(nn.Identity):
+    ...
+
+
+class BMMOutput_QK(nn.Identity):
+    ...
+
+
+class BMMLeftInput_PV(nn.Identity):
+    ...
+
+
+class BMMRightInput_PV(nn.Identity):
+    ...
+
+
+class BMMOutput_PV(nn.Identity):
+    ...
+
+
+class QuantizableBatchMatMul(nn.Module):
+    """
+    Wrapper around torch.bmm with distinct input/output class
+    instances that can be quantized through a SparseML recipe
+    """
+
+    def __init__(self, left_input_cls, right_input_cls, output_cls):
+        super().__init__()
+        self.left_input = left_input_cls()
+        self.right_input = right_input_cls()
+        self.output = output_cls()
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return self.output(torch.bmm(self.left_input(a), self.right_input(b)))
+
+
 class OPTAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -150,6 +190,9 @@ def __init__(
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
+        self.attn_weights_bmm = QuantizableBatchMatMul(BMMLeftInput_QK, BMMRightInput_QK, BMMOutput_QK)
+        self.attn_output_bmm = QuantizableBatchMatMul(BMMLeftInput_PV, BMMRightInput_PV, BMMOutput_PV)
+
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
@@ -208,7 +251,7 @@ def forward(
         value_states = value_states.view(*proj_shape)
 
         src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+        attn_weights = self.attn_weights_bmm(query_states, key_states.transpose(1, 2))
 
         if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
             raise ValueError(
@@ -254,7 +297,7 @@ def forward(
 
         attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
-        attn_output = torch.bmm(attn_probs, value_states)
+        attn_output = self.attn_output_bmm(attn_probs, value_states)
 
         if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
             raise ValueError(
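A note on why the diff adds six empty nn.Identity subclasses: torch.bmm is a functional op with no module boundary, so a module-targeting quantization recipe has nothing to attach observers to. Wrapping each operand and the output in a uniquely named submodule gives every tensor a stable module path (e.g. "attn_weights_bmm.left_input" inside OPTAttention) that a recipe can match, and using distinct classes per role lets a recipe target, say, all QK left inputs across layers at once. Below is a minimal, self-contained sketch of that mechanism; the QuantStub swap at the end is an illustrative stand-in for what a quantization modifier does once a recipe matches these names, not SparseML's actual modifier API.

# Minimal sketch (not part of the diff). The classes mirror the ones added
# above; the QuantStub swap is a hypothetical illustration of recipe-driven
# module surgery, not SparseML's real mechanism.
import torch
import torch.nn as nn


class BMMLeftInput_QK(nn.Identity):
    ...


class BMMRightInput_QK(nn.Identity):
    ...


class BMMOutput_QK(nn.Identity):
    ...


class QuantizableBatchMatMul(nn.Module):
    def __init__(self, left_input_cls, right_input_cls, output_cls):
        super().__init__()
        self.left_input = left_input_cls()
        self.right_input = right_input_cls()
        self.output = output_cls()

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Each operand and the result pass through a named no-op module,
        # creating anchor points a quantization recipe can target.
        return self.output(torch.bmm(self.left_input(a), self.right_input(b)))


bmm = QuantizableBatchMatMul(BMMLeftInput_QK, BMMRightInput_QK, BMMOutput_QK)

# Every tensor entering/leaving the bmm now has an addressable submodule:
print([name for name, _ in bmm.named_modules() if name])
# ['left_input', 'right_input', 'output']

# Illustration only: replace the Identity anchors with fake-quant stubs, the
# kind of swap a quantization modifier performs when a recipe matches these
# module names. QuantStub is a pass-through until the model is prepared.
bmm.left_input = torch.ao.quantization.QuantStub()
bmm.right_input = torch.ao.quantization.QuantStub()

q = torch.randn(2, 4, 8)  # (batch * heads, tgt_len, head_dim)
k = torch.randn(2, 8, 4)  # (batch * heads, head_dim, src_len), i.e. K^T
out = bmm(q, k)           # still a plain bmm until observers are enabled
print(out.shape)          # torch.Size([2, 4, 4])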