@@ -29,6 +29,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929#include "../common.h"
3030#define SGEMM BLASFUNC(sgemm)
3131#define SBGEMM BLASFUNC(sbgemm)
32+ #define SGEMV BLASFUNC(sgemv)
33+ #define SBGEMV BLASFUNC(sbgemv)
3234typedef union
3335{
3436 unsigned short v ;
@@ -187,7 +189,79 @@ main (int argc, char *argv[])
187189 free (CC );
188190 }
189191
190- if (ret != 0 )
192+ if (ret != 0 ) {
191193 fprintf (stderr , "FATAL ERROR SBGEMM - Return code: %d\n" , ret );
194+ return ret ;
195+ }
196+
197+ k = 1 ;
198+ for (x = 1 ; x <= loop ; x ++ )
199+ {
200+ float * A = (float * )malloc (x * x * sizeof (FLOAT ));
201+ float * B = (float * )malloc (x * sizeof (FLOAT ));
202+ float * C = (float * )malloc (x * sizeof (FLOAT ));
203+ bfloat16_bits * AA = (bfloat16_bits * )malloc (x * x * sizeof (bfloat16_bits ));
204+ bfloat16_bits * BB = (bfloat16_bits * )malloc (x * sizeof (bfloat16_bits ));
205+ float * DD = (float * )malloc (x * sizeof (FLOAT ));
206+ float * CC = (float * )malloc (x * sizeof (FLOAT ));
207+ if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
208+ (DD == NULL ) || (CC == NULL ))
209+ return 1 ;
210+ bfloat16 atmp , btmp ;
211+ blasint one = 1 ;
212+
213+ for (j = 0 ; j < x ; j ++ )
214+ {
215+ for (i = 0 ; i < x ; i ++ )
216+ {
217+ A [j * x + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
218+ sbstobf16_ (& one , & A [j * x + i ], & one , & atmp , & one );
219+ AA [j * x + i ].v = atmp ;
220+ }
221+ B [j ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
222+ sbstobf16_ (& one , & B [j ], & one , & btmp , & one );
223+ BB [j ].v = btmp ;
224+ }
225+ for (y = 0 ; y < 2 ; y ++ )
226+ {
227+ if (y == 0 ) {
228+ transA = 'N' ;
229+ } else {
230+ transA = 'T' ;
231+ }
232+
233+ memset (CC , 0 , x * sizeof (FLOAT ));
234+ memset (DD , 0 , x * sizeof (FLOAT ));
235+ memset (C , 0 , x * sizeof (FLOAT ));
236+
237+ SGEMV (& transA , & x , & x , & alpha , A , & x , B , & k , & beta , C , & k );
238+ SBGEMV (& transA , & x , & x , & alpha , (bfloat16 * ) AA , & x , (bfloat16 * ) BB , & k , & beta , CC , & k );
239+
240+ for (j = 0 ; j < x ; j ++ )
241+ for (i = 0 ; i < x ; i ++ )
242+ if (transA == 'N' ) {
243+ DD [i ] += float16to32 (AA [j * x + i ]) * float16to32 (BB [j ]);
244+ } else if (transA == 'T' ) {
245+ DD [j ] += float16to32 (AA [j * x + i ]) * float16to32 (BB [i ]);
246+ }
247+
248+ for (j = 0 ; j < x ; j ++ ) {
249+ if (fabs (CC [j ] - C [j ]) > 1.0 )
250+ ret ++ ;
251+ if (fabs (CC [j ] - DD [j ]) > 1.0 )
252+ ret ++ ;
253+ }
254+ }
255+ free (A );
256+ free (B );
257+ free (C );
258+ free (AA );
259+ free (BB );
260+ free (DD );
261+ free (CC );
262+ }
263+
264+ if (ret != 0 )
265+ fprintf (stderr , "FATAL ERROR SBGEMV - Return code: %d\n" , ret );
192266 return ret ;
193267}
0 commit comments