@@ -261,19 +261,59 @@ end
261261 nothing
262262end
263263
264- const LDIVBUFFERS = Vector{UInt8}[]
265- @inline function lubuffer (:: Val{T} , :: StaticInt{UF} , N) where {T,UF}
266- buff = LDIVBUFFERS[Threads. threadid ()]
267- RSUF = StaticInt {UF} () * VectorizationBase. register_size ()
# Base address of a C-heap table holding one `Pair{Ptr{Cvoid},Int}` record
# (scratch pointer => capacity) per thread id. Stays `C_NULL` until `__init__` runs.
const buffer = Ref{Ptr{Cvoid}}(C_NULL)

"""
    __init__()

Allocate the zero-initialised per-thread scratch table and publish its address in
`buffer`. Every record starts out as `C_NULL => 0`, i.e. "no scratch buffer yet".

# Throws
- `OutOfMemoryError` if the table cannot be allocated.
"""
function __init__()
    # `Threads.nthreads()` only counts the default pool, while `Threads.threadid()` may
    # be larger when interactive threads exist (Julia >= 1.9). Size the table by the
    # largest possible thread id when the runtime exposes it.
    # NOTE(review): threads adopted after module load can still raise `maxthreadid()`
    # beyond the value captured here.
    nslots = if isdefined(Threads, :maxthreadid)
        Threads.maxthreadid()
    else
        Threads.nthreads()
    end
    # `calloc` zero-fills, so no separate memset pass is needed.
    bp = Libc.calloc(nslots, sizeof(Pair{Ptr{Cvoid},Int}))
    bp == C_NULL && throw(OutOfMemoryError())
    buffer[] = bp
    return nothing
end
271+
"""
    _get_buffer_pointer(::StaticInt{UF}, N)

Return a pointer into the calling thread's scratch buffer, passed through
`VectorizationBase.align(p, RS)` with `RS = VectorizationBase.register_size()`.

The buffer is looked up in the table behind `buffer` using `Threads.threadid()`. When
the recorded capacity is smaller than `UF * RS * N`, the old allocation is freed and a
new one of `max(UF * RS * 128, UF * RS * N)` bytes (plus `RS - 1` bytes of slack for
the alignment step) is obtained with `Libc.malloc`. Previous contents are not kept.

# Throws
- `OutOfMemoryError` if the new allocation fails; the table record is left empty.
"""
function _get_buffer_pointer(::StaticInt{UF}, N) where {UF}
    RS = VectorizationBase.register_size()
    RSUF = StaticInt{UF}() * RS
    L = RSUF * N
    tid = Threads.threadid() - 1
    # Each record is a `Pair{Ptr{Cvoid},Int}`, i.e. two machine words.
    # NOTE(review): the record is keyed by thread id, so this presumably relies on the
    # calling task not migrating to another thread while the buffer is in use.
    bp = Ptr{Pair{Ptr{Cvoid},Int}}(buffer[]) + 2sizeof(Int) * tid
    (p, buff_current) = unsafe_load(bp)
    if buff_current < L
        if p != C_NULL
            Libc.free(p)
            # Clear the record immediately so that a failed allocation below cannot
            # leave a dangling pointer behind for the next call.
            unsafe_store!(bp, C_NULL => 0)
        end
        buff_size = max(RSUF * 128, L)
        # Over-allocate by `RS - 1` bytes; presumably `align` rounds the address up.
        p = Libc.malloc(Int(buff_size + RS - 1))
        p == C_NULL && throw(OutOfMemoryError())
        unsafe_store!(bp, p => buff_size)
    end
    return VectorizationBase.align(p, RS)
end
287+
"""
    lubuffer(::Val{T}, ::StaticInt{UF}, N)

Build a two-dimensional strided pointer with element type `T` over the scratch memory
returned by `_get_buffer_pointer(StaticInt{UF}(), N)`.

The byte strides are `static_sizeof(T)` along the first axis and
`UF * VectorizationBase.register_size()` along the second; both offsets are zero.

Returns the tuple `(strided_pointer, nothing)`; the second slot is `nothing` because
this method creates no object for the caller to keep alive.
"""
@inline function lubuffer(::Val{T}, ::StaticInt{UF}, N) where {T,UF}
    unroll = StaticInt{UF}()
    colstride = unroll * VectorizationBase.register_size()
    base = Ptr{T}(_get_buffer_pointer(unroll, N))
    strides = (VectorizationBase.static_sizeof(T), colstride)
    offsets = (StaticInt(0), StaticInt(0))
    index = StrideIndex{2,(1, 2),1}(strides, offsets)
    return stridedpointer(base, index, StaticInt{0}()), nothing
end
"""
    lubuffer(::Val{T}, ::StaticInt{UF}, ::StaticInt{N})

Variant for a statically known `N`: the scratch space is an uninitialised
`Ref{NTuple{L,T}}` with `L = UF * VectorizationBase.pick_vector_width(T) * N` elements
instead of memory taken from the shared buffer table.

The byte strides are `static_sizeof(T)` along the first axis and
`UF * pick_vector_width(T) * static_sizeof(T)` along the second; both offsets are zero.

Returns `(strided_pointer, ref)`. The pointer refers to memory owned by `ref`, so the
caller has to keep `ref` alive (e.g. with `GC.@preserve`) while the pointer is in use.
"""
@inline function lubuffer(::Val{T}, ::StaticInt{UF}, ::StaticInt{N}) where {T,UF,N}
    width = StaticInt{UF}() * VectorizationBase.pick_vector_width(T)
    len = width * N
    storage = Ref{NTuple{len,T}}()
    base = Base.unsafe_convert(Ptr{T}, storage)
    elbytes = VectorizationBase.static_sizeof(T)
    strides = (elbytes, width * elbytes)
    offsets = (StaticInt(0), StaticInt(0))
    index = StrideIndex{2,(1, 2),1}(strides, offsets)
    return stridedpointer(base, index, StaticInt{0}()), storage
end
# Release the memory at `p` by forwarding to `Libc.free`.
@inline function _free(p::Ptr)
    return Libc.free(p)
end
# Generic fallback: return `x` reinterpreted as its signed counterpart via `signed`.
function _canonicalize(x)
    return signed(x)
end
# Static integers pass through unchanged: return the `StaticInt{N}` instance itself.
function _canonicalize(::StaticInt{N}) where {N}
    return StaticInt{N}()
end
279319function div_dispatch! (
@@ -528,12 +568,12 @@ function block_size(::Val{T}) where {T}
528568end
529569
"""
    nmuladd!(C, A, U, M, K, N)

For every `n` in `CloseOpen(N)` and `m` in `CloseOpen(M)`, start from `A[m, n]`,
subtract `C[m, k] * U[k, n]` for each `k` in `CloseOpen(K)`, and store the result in
`C[m, K + n]`. The loop nest is handed to `@turbo`.
"""
function nmuladd!(C, A, U, M, K, N)
    @turbo for n ∈ CloseOpen(N), m ∈ CloseOpen(M)
        acc = A[m, n]
        for k ∈ CloseOpen(K)
            acc -= C[m, k] * U[k, n]
        end
        C[m, K+n] = acc
    end
end
537577
538578function rdiv_block_N! (
539579 spc:: AbstractStridedPointer{T} ,
@@ -695,50 +735,44 @@ function rdiv_U!(
695735 WU = UF * WS
696736 MU = UF > 1 ? M : 0
697737 Nd, Nr = VectorizationBase. vdivrem (N, WS)
698- spb = lubuffer (Val (T), UF, N)
738+ spb, preserve = lubuffer (Val (T), UF, N)
699739 m = 0
700- while m < MU - WU + 1
701- n = Nr
702- if n > 0
703- BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT))
704- end
705- for _ ∈ 1 : Nd
706- rdiv_solve_W_u! (spb, spc, spa, spu, n, WS, UF, Val (UNIT))
707- n += W
708- end
709- m += WU
710- spa = gesp (spa, (WU, StaticInt (0 )))
711- spc = gesp (spc, (WU, StaticInt (0 )))
712- end
713- finalmask = VectorizationBase. mask (WS, M)
714- while m < M
715- ubm = m + W
716- nomaskiter = ubm < M
717- mask = nomaskiter ? VectorizationBase. max_mask (WS) : finalmask
718- n = Nr
719- if n > 0
720- BdivU_small_kern! (spb, spc, spa, spu, n, mask, Val (UNIT))
740+ GC. @preserve preserve begin
741+ while m < MU - WU + 1
742+ n = Nr
743+ if n > 0
744+ BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT))
745+ end
746+ for _ ∈ 1 : Nd
747+ rdiv_solve_W_u! (spb, spc, spa, spu, n, WS, UF, Val (UNIT))
748+ n += W
749+ end
750+ m += WU
751+ spa = gesp (spa, (WU, StaticInt (0 )))
752+ spc = gesp (spc, (WU, StaticInt (0 )))
721753 end
722- for i ∈ 1 : Nd
723- # @show C, n
724- rdiv_solve_W! (spb, spc, spa, spu, n, i ≠ Nd, mask, Val (UNIT))
725- n += W
754+ finalmask = VectorizationBase. mask (WS, M)
755+ while m < M
756+ ubm = m + W
757+ nomaskiter = ubm < M
758+ mask = nomaskiter ? VectorizationBase. max_mask (WS) : finalmask
759+ n = Nr
760+ if n > 0
761+ BdivU_small_kern! (spb, spc, spa, spu, n, mask, Val (UNIT))
762+ end
763+ for i ∈ 1 : Nd
764+ # @show C, n
765+ rdiv_solve_W! (spb, spc, spa, spu, n, i ≠ Nd, mask, Val (UNIT))
766+ n += W
767+ end
768+ spa = gesp (spa, (WS, StaticInt (0 )))
769+ spc = gesp (spc, (WS, StaticInt (0 )))
770+ m = ubm
726771 end
727- spa = gesp (spa, (WS, StaticInt (0 )))
728- spc = gesp (spc, (WS, StaticInt (0 )))
729- m = ubm
730772 end
731773 nothing
732774end
733775
734- function __init__ ()
735- nthread = Threads. nthreads ()
736- resize! (LDIVBUFFERS, nthread)
737- for i ∈ 1 : nthread
738- LDIVBUFFERS[i] =
739- Vector {UInt8} (undef, 3 VectorizationBase. register_size () * 128 )
740- end
741- end
742776#=
743777using PrecompileTools
744778@static if VERSION >= v"1.8.0-beta1"
0 commit comments