 
 #include "arch-fallback.h"
 
+#include <algorithm>
 #include <cmath>
 #include <cstring>
 #include <cassert>
@@ -1600,29 +1601,48 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         return false;
     }
 
-    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
+    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end, int64_t src1_start, int64_t src1_end) {
         const ggml_tensor * src0 = op->src[0];
         const ggml_tensor * src1 = op->src[1];
         ggml_tensor * dst = op;
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-        const void * src1_wdata = params->wdata;
         const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
 
+        GGML_ASSERT(ne03 == 1 && ne13 == 1);
+        GGML_ASSERT(ne12 % ne02 == 0);
+        const int64_t r2 = ne12 / ne02;
+
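+        // src1_start/src1_end index the flattened (ne1 * ne2) dst rows for this chunk;
+        // split the start index into a row within the plane (i11) and a plane index (i12).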
+        const int64_t i12 = src1_start / ne1;
+        const int64_t i11 = src1_start - i12 * ne1;
+
+        // Determine batch index
+        const int64_t i02 = i12 / r2;
+
+        const int64_t i1 = i11;
+        const int64_t i2 = i12;
+
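+        // Base pointers for this plane: the src0 batch i02, the quantized src1 rows in wdata,
+        // and the dst row/plane where the results are written.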
+        const char *src0_ptr = (const char *)src0->data + i02 * nb02;
+        const char *src1_ptr = (const char *)params->wdata + (i11 + i12 * ne11) * src1_col_stride;
+        float *dst_ptr = (float *)((char *)dst->data + (i1 * nb1 + i2 * nb2));
+
+        const int64_t nrows = src1_end - src1_start;
+        const int64_t ncols = src0_end - src0_start;
+
         // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (ne11 > 3) {
+        if (nrows > 3) {
             gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+                dst_ptr + src0_start, nb1 / nb0,
+                src0_ptr + src0_start * nb01,
+                src1_ptr, nrows - (nrows % 4), ncols);
         }
-        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+        for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
             gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                src0_end - src0_start);
+                dst_ptr + (iter * nb1) + src0_start, ne01,
+                src0_ptr + src0_start * nb01,
+                src1_ptr + (src1_col_stride * iter), 1 /* nrows */,
+                ncols);
         }
     }
 
@@ -1647,54 +1667,72 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         GGML_ASSERT(nb1 <= nb2);
         GGML_ASSERT(nb2 <= nb3);
 
+        // TODO: General batched mul mat for 4D tensors
+        // Currently only supports 3D tensors
+        GGML_ASSERT(ne13 == 1);
+
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
         GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
         // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
 
         char * wdata = static_cast<char *>(params->wdata);
-        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
+        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
+        const size_t nbw2 = nbw1 * ne11;
 
-        assert(params->wsize >= nbw1 * ne11);
+        assert(params->wsize >= nbw2 * ne12);
 
         const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
 
-        int64_t i11_processed = 0;
-        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-            ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
-        }
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            char * data_ptr  = (char *) src1->data + i12 * nb12;
+            char * wdata_ptr = wdata + i12 * nbw2;
+
+            for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+                ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
+                                                            (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
+            }
 
-        i11_processed = ne11 - ne11 % 4;
-        for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
-            from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
+            const int64_t i11_processed = ne11 - ne11 % 4;
+            for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+                from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
+            }
         }
 
         // disable for NUMA
         const bool disable_chunking = ggml_is_numa();
 
         // 4x chunks per thread
-        int64_t nr = ggml_nrows(op->src[0]);
-        int nth_scaled = nth * 4;
-        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
-        int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+        const int64_t nr0 = ggml_nrows(op->src[0]);
+        const int64_t nr1 = ne1 * ne2 * ne3;
+
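+        // Chunk the work along two dimensions: nr0 rows of src0 and nr1 flattened dst rows
+        // (src1 rows across all batches).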
+        int nth_scaled = nth * 4;
+        int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
+        // avoid too small chunks for narrow src1
+        int64_t chunk_size1 = std::max<int64_t>(16, (nr1 + nth - 1) / nth);
+        int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
+        int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1;
 
         // Ensure minimum chunk size to avoid alignment issues with high thread counts
         // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
         const int64_t min_chunk_size = NB_COLS;
-        if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
-            nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+        if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
+            nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
         }
 
-        if (nth == 1 || nchunk < nth || disable_chunking) {
-            nchunk = nth;
+
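+        // If there are fewer chunks than threads (or chunking is disabled for NUMA),
+        // fall back to one chunk per thread along the larger dimension.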
+        if (nth == 1 || nchunk0 * nchunk1 < nth || disable_chunking) {
+            nchunk0 = nr0 > nr1 ? nth : 1;
+            nchunk1 = nr0 > nr1 ? 1 : nth;
         }
 
+        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
+
         // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
         // This prevents creating too many tiny chunks that could overlap after alignment
-        const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
-        if (nchunk > max_nchunk) {
-            nchunk = max_nchunk;
-        }
+        const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
+        nchunk0 = std::min(nchunk0, max_nchunk);
 
         if (ith == 0) {
             // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
@@ -1706,23 +1744,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         // The first chunk comes from our thread_id, the rest will get auto-assigned.
         int current_chunk = ith;
 
-        while (current_chunk < nchunk) {
-            int64_t src0_start = (current_chunk * ne01) / nchunk;
-            int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
+        while (current_chunk < nchunk0 * nchunk1) {
+            const int64_t ith0 = current_chunk % nchunk0; // rows chunk
+            const int64_t ith1 = current_chunk / nchunk0; // (N * batch) chunk
+
+            int64_t src0_start = dr0 * ith0;
+            int64_t src0_end = MIN(src0_start + dr0, nr0);
+
+            int64_t src1_start = dr1 * ith1;
+            int64_t src1_end = MIN(src1_start + dr1, nr1);
 
             // Align boundaries to NB_COLS - round up to ensure all data is included
             // The chunk size limiting above ensures chunks are large enough to prevent overlaps
             src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
             src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
-            if (src0_end > ne01) {
-                src0_end = ne01;
-            }
+            src0_end = std::min(src0_end, ne01);
 
+            // Make sure current plane is the last one before exiting
             if (src0_start >= src0_end) {
-                break;
+                if (nth >= nchunk0 * nchunk1) {
+                    break;
+                }
+                current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+                continue;
             }
 
-            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
+            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
 
             current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
         }