@@ -87,15 +87,6 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
8787 const auto wei_zero_points_d = ctx.memory_mdw (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
8888 int wei_scales_oc_stride = wei_scales_d.dims ()[0 ] > 1 ? 1 : 0 ;
8989 int wei_zero_points_oc_stride = wei_zero_points_d.dims ()[0 ] > 1 ? 1 : 0 ;
90- int wei_scales_ic_group_size, wei_zero_points_ic_group_size;
91- if (jbgp.with_grouped_weights_decompression ) {
92- int wei_scales_ic_group_num = wei_scales_d.dims ()[1 ];
93- int wei_zero_points_ic_group_num = wei_zero_points_d.dims ()[1 ];
94- wei_scales_ic_group_size = wei_scales_ic_group_num ? div_up (jbgp.ic , wei_scales_ic_group_num) : jbgp.ic ;
95- wei_zero_points_ic_group_size = wei_zero_points_ic_group_num ? div_up (jbgp.ic , wei_zero_points_ic_group_num) : jbgp.ic ;
96- } else {
97- wei_scales_ic_group_size = wei_zero_points_ic_group_size = jbgp.ic ;
98- }
9990
10091 const float *oscales = nullptr ;
10192 if (jbgp.weights_decompression ) {
@@ -170,8 +161,6 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
170161 const auto wei_ic_stride
171162 = types::data_type_size (jbgp.wei_dt ) * weights_d.off_v (ic_dims);
172163
173- int typesize_scale = one_of (jbgp.wei_dt , data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1 ;
174-
175164 const auto ker = [&](int ithr_oc_mb, int nthr_oc_mb, int ithr_ic, int osb,
176165 int osb_s, int ocb, int ocb_s, int icc, int icc_s,
177166 bool copy_buffer_a, int &prev_ker_idx) {
@@ -269,7 +258,7 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
269258 int brg_ker_idx = brgemm_inner_product_utils::get_brg_kernel_index (
270259 is_bs_tail, kernel_init, is_os_tail, is_oc_tail, false );
271260 auto brg_kernel = brg_kernels_[brg_ker_idx].get ();
272- const int ic_blocks_per_batch = jbgp.K / jbgp.ic_block ;
261+ const int ic_blocks_per_batch = div_up ( jbgp.K , jbgp.ic_block ) ;
273262 const dim_t wei_cur_ocb
274263 = get_blk_off (weights_d, jbgp.wei_dt , cur_ocb, 0 );
275264
@@ -290,7 +279,7 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
290279 ic + b * jbgp.K ));
291280 addr_batch[b].ptr .A = A_ptr;
292281 const dim_t wei_offset = (wei_cur_ocb
293- + wei_ic_stride * (icb + b * ic_blocks_per_batch)) / typesize_scale ;
282+ + wei_ic_stride * (icb + b * ic_blocks_per_batch));
294283 if (jbgp.weights_compressed ) {
295284 using comp_tile_len_type = int ;
296285 const comp_tile_len_type *compressed_tile_lengths_ptr
@@ -311,30 +300,35 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
311300 (*brg_decomp_kernel_)(&dcomp_params);
312301 addr_batch[b].ptr .B = decomp_buf;
313302 } else if (jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t ::prepack) {
314- auto w_off = wei_offset * types::data_type_size (jbgp.orig_wei_dt ) / types::data_type_size (jbgp.wei_dt );
303+ int typesize_scale = one_of (jbgp.orig_wei_dt , data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1 ;
304+ auto w_off = wei_offset * types::data_type_size (jbgp.orig_wei_dt ) / types::data_type_size (jbgp.wei_dt ) / typesize_scale;
315305 auto weights_ptr = reinterpret_cast <const uint8_t *>(&weights[w_off]);
316306
317307 const size_t decomp_buf_per_thr = jbgp.ic_block * jbgp.nb_ic_blocking * jbgp.oc_block * types::data_type_size (jbgp.wei_dt );
318308 auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr + wei_ic_stride * b * ic_blocks_per_batch;
319309
320- const int ic_internal_block = is_amx ? 2 : 1 ;
321- auto wei_zero_points_ptr = wei_zero_points + oc;
322- auto wei_scales_ptr = wei_scales + oc;
310+ const int ic_internal_block = is_amx || one_of ( pd ()-> jbgp_ . orig_wei_dt , data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1 ;
311+ auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc;
312+ auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc;
323313
324314 if (jbgp.with_grouped_weights_decompression ) {
325315 weights_decompression_runtime_params_t rt_params = {};
326316 auto ic_size = jbgp.ic_block * ic_blocks_per_batch / ic_internal_block;
327- auto wei_scales_ic_group_size_local = wei_scales_ic_group_size / ic_internal_block;
328- auto wei_zero_points_ic_group_size_local = wei_zero_points_ic_group_size / ic_internal_block;
317+ auto wei_scales_ic_group_size_local = jbgp. wei_scales_ic_group_size / ic_internal_block;
318+ auto wei_zero_points_ic_group_size_local = jbgp. wei_zero_points_ic_group_size / ic_internal_block;
329319 auto group_size = nstl::min (wei_scales_ic_group_size_local, wei_zero_points_ic_group_size_local);
330320 auto group_ic_blocks = div_up (ic_size, group_size);
321+ auto start_group_scales = ic / jbgp.wei_scales_ic_group_size ;
322+ auto start_group_zero_points = ic / jbgp.wei_zero_points_ic_group_size ;
331323 for (int icb_idx = 0 ; icb_idx < group_ic_blocks; icb_idx++) {
332324 auto ic_idx = icb_idx * group_size;
325+ auto scales_idx = ic_idx / wei_scales_ic_group_size_local + start_group_scales;
326+ auto zero_points_idx = ic_idx / wei_zero_points_ic_group_size_local + start_group_zero_points;
333327
334- rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size (jbgp.orig_wei_dt );
328+ rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size (jbgp.orig_wei_dt ) / typesize_scale ;
335329 rt_params.decomp_buffer_ptr = decomp_buf + ic_idx * ic_internal_block *jbgp.oc_block * types::data_type_size (jbgp.wei_dt );
336- rt_params.scales_ptr = wei_scales_ptr + (ic_idx * wei_scales_d.dims ()[0 ]) / wei_scales_ic_group_size_local ;
337- rt_params.zero_points_ptr = wei_zero_points_ptr + (ic_idx * wei_zero_points_d.dims ()[0 ]) / wei_zero_points_ic_group_size_local ;
330+ rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims ()[0 ];
331+ rt_params.zero_points_ptr = wei_zero_points_ptr + zero_points_idx * wei_zero_points_d.dims ()[0 ];
338332 rt_params.ic_size = nstl::min (group_size, ic_size - icb_idx * group_size);
339333 (*brg_weights_decomp_kernel_)(&rt_params);
340334 }
@@ -350,15 +344,16 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
350344
351345 addr_batch[b].ptr .B = decomp_buf;
352346 } else {
353- addr_batch[b].ptr .B = weights + wei_offset;
347+ int typesize_scale = one_of (jbgp.wei_dt , data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1 ;
348+ addr_batch[b].ptr .B = weights + wei_offset / typesize_scale;
354349 }
355350 }
356351
357352 int wei_scales_offset = 0 ;
358353 int wei_zero_points_offset = 0 ;
359354 if (jbgp.weights_decompression ) {
360- wei_scales_offset = (ic / wei_scales_ic_group_size) * wei_scales_d.dims ()[0 ] + wei_scales_oc_stride * oc;
361- wei_zero_points_offset = (ic / wei_zero_points_ic_group_size) * wei_zero_points_d.dims ()[0 ] + wei_zero_points_oc_stride * oc;
355+ wei_scales_offset = (ic / jbgp. wei_scales_ic_group_size ) * wei_scales_d.dims ()[0 ] + wei_scales_oc_stride * oc;
356+ wei_zero_points_offset = (ic / jbgp. wei_zero_points_ic_group_size ) * wei_zero_points_d.dims ()[0 ] + wei_zero_points_oc_stride * oc;
362357 }
363358
364359 auto ptr_D = dst + dst_off;
@@ -382,10 +377,10 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
382377
383378 brgemm_kernel_execute_postops (brg_kernel, gemm_batch,
384379 addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data,
385- scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic );
380+ scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0 );
386381 } else {
387382 brgemm_kernel_execute (brg_kernel, gemm_batch, addr_batch,
388- (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr , &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic );
383+ (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr , &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0 );
389384 }
390385 }
391386
@@ -403,33 +398,38 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
403398 + get_blk_off (src_d, jbgp.src_dt , n,
404399 ic + ic_block * jbgp.ic_block );
405400 const dim_t wei_offset
406- = (wei_cur_ocb + wei_ic_stride * (icb + ic_block)) / typesize_scale ;
401+ = (wei_cur_ocb + wei_ic_stride * (icb + ic_block));
407402
408403 if (jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t ::prepack) {
409- auto w_off = wei_offset * types::data_type_size (jbgp.orig_wei_dt ) / types::data_type_size (jbgp.wei_dt );
404+ int typesize_scale = one_of (jbgp.orig_wei_dt , data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1 ;
405+ auto w_off = wei_offset * types::data_type_size (jbgp.orig_wei_dt ) / types::data_type_size (jbgp.wei_dt ) / typesize_scale;
410406 auto weights_ptr = reinterpret_cast <const uint8_t *>(&weights[w_off]);
411407
412408 const size_t decomp_buf_per_thr = jbgp.ic_block * jbgp.nb_ic_blocking * jbgp.oc_block * types::data_type_size (jbgp.wei_dt );
413409 auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr;
414410
415- const int ic_internal_block = is_amx ? 2 : 1 ;
416- auto wei_zero_points_ptr = wei_zero_points + oc;
417- auto wei_scales_ptr = wei_scales + oc;
411+ const int ic_internal_block = is_amx || one_of ( pd ()-> jbgp_ . orig_wei_dt , data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1 ;
412+ auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc;
413+ auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc;
418414
419415 if (jbgp.with_grouped_weights_decompression ) {
416+ weights_decompression_runtime_params_t rt_params = {};
420417 auto ic_size = (jbgp.ic - (ic + ic_block * jbgp.ic_block )) / ic_internal_block;
421- auto wei_scales_ic_group_size_local = wei_scales_ic_group_size / ic_internal_block;
422- auto wei_zero_points_ic_group_size_local = wei_zero_points_ic_group_size / ic_internal_block;
418+ auto wei_scales_ic_group_size_local = jbgp. wei_scales_ic_group_size / ic_internal_block;
419+ auto wei_zero_points_ic_group_size_local = jbgp. wei_zero_points_ic_group_size / ic_internal_block;
423420 auto group_size = nstl::min (wei_scales_ic_group_size_local, wei_zero_points_ic_group_size_local);
424421 auto group_ic_blocks = div_up (ic_size, group_size);
425- weights_decompression_runtime_params_t rt_params = {};
422+ auto start_group_scales = ic / jbgp.wei_scales_ic_group_size ;
423+ auto start_group_zero_points = ic / jbgp.wei_zero_points_ic_group_size ;
426424 for (int icb_idx = 0 ; icb_idx < group_ic_blocks; icb_idx++) {
427425 auto ic_idx = icb_idx * group_size;
426+ auto scales_idx = ic_idx / wei_scales_ic_group_size_local + start_group_scales;
427+ auto zero_points_idx = ic_idx / wei_zero_points_ic_group_size_local + start_group_zero_points;
428428
429- rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size (jbgp.orig_wei_dt );
429+ rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size (jbgp.orig_wei_dt ) / typesize_scale ;
430430 rt_params.decomp_buffer_ptr = decomp_buf + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size (jbgp.wei_dt );
431- rt_params.scales_ptr = wei_scales_ptr + (ic_idx * wei_scales_d.dims ()[0 ]) / wei_scales_ic_group_size_local ;
432- rt_params.zero_points_ptr = wei_zero_points_ptr + (ic_idx * wei_zero_points_d.dims ()[0 ]) / wei_zero_points_ic_group_size_local ;
431+ rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims ()[0 ];
432+ rt_params.zero_points_ptr = wei_zero_points_ptr + zero_points_idx * wei_zero_points_d.dims ()[0 ];
433433 rt_params.ic_size = nstl::min (group_size, ic_size - icb_idx * group_size);
434434 (*brg_weights_decomp_kernel_)(&rt_params);
435435 }
@@ -445,14 +445,15 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
445445
446446 addr_batch[0 ].ptr .B = decomp_buf;
447447 } else {
448- addr_batch[0 ].ptr .B = weights + wei_offset;
448+ int typesize_scale = one_of (jbgp.wei_dt , data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1 ;
449+ addr_batch[0 ].ptr .B = weights + wei_offset / typesize_scale;
449450 }
450451
451452 int wei_scales_offset = 0 ;
452453 int wei_zero_points_offset = 0 ;
453454 if (jbgp.weights_decompression ) {
454- wei_scales_offset = (ic / wei_scales_ic_group_size) * wei_scales_d.dims ()[0 ] + wei_scales_oc_stride * oc;
455- wei_zero_points_offset = (ic / wei_zero_points_ic_group_size) * wei_zero_points_d.dims ()[0 ] + wei_zero_points_oc_stride * oc;
455+ wei_scales_offset = (ic / jbgp. wei_scales_ic_group_size ) * wei_scales_d.dims ()[0 ] + wei_scales_oc_stride * oc;
456+ wei_zero_points_offset = (ic / jbgp. wei_zero_points_ic_group_size ) * wei_zero_points_d.dims ()[0 ] + wei_zero_points_oc_stride * oc;
456457 }
457458
458459 auto brg_kernel_ic_tail = brg_kernels_[brg_ker_ic_tail_idx].get ();
@@ -474,10 +475,10 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
474475 nullptr , false , 1 , false , false , dst_scales};
475476
476477 brgemm_kernel_execute_postops (brg_kernel_ic_tail, 1 , addr_batch,
477- (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic );
478+ (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0 );
478479 } else {
479480 brgemm_kernel_execute (brg_kernel_ic_tail, 1 , addr_batch,
480- (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr , &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic );
481+ (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr , &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0 );
481482 }
482483 }
483484 };
0 commit comments