489489 spu:: AbstractStridedPointer ,
490490 M,
491491 N,
492- :: StaticInt{1} ,
493- :: StaticInt{1} ,
494492 :: Val{UNIT}
495493) where {T,UNIT}
496494 WS = pick_vector_width (T)
534532 nothing
535533end
536534
537- const buffer = Ref {Ptr{Cvoid}} (C_NULL )
538-
539- function __init__ ()
540- bp_size = 2 * sizeof (Int) * Threads. nthreads ()
541- buffer[] = bp = Libc. malloc (bp_size % UInt)
542- Libc. memset (bp, 0 , bp_size)
543- end
544-
545- function _get_buffer_pointer (:: StaticInt{UF} , N) where {UF}
546- RS = VectorizationBase. register_size ()
547- RSUF = StaticInt {UF} () * RS
548- L = RSUF * N
549- tid = Threads. threadid () - 1
550- bp = Ptr {Pair{Ptr{Cvoid},Int}} (buffer[]) + 2 sizeof (Int) * tid
551- (p, buff_current) = unsafe_load (bp)
552- if buff_current < L
553- p == C_NULL || Libc. free (p)
554- buff_size = max (RSUF * 128 , L)
555- p = Libc. malloc ((buff_size + RS - 1 ) % UInt)
556- unsafe_store! (bp, p => buff_size)
557- end
558- return VectorizationBase. align (p, RS)
559- end
560-
561- @inline function lubuffer (:: Val{T} , :: StaticInt{UF} , N) where {T,UF}
562- RS = VectorizationBase. register_size ()
563- RSUF = StaticInt {UF} () * RS
564- ptr = Ptr {T} (_get_buffer_pointer (StaticInt {UF} (), N))
565- si = StrideIndex {2,(1, 2),1} (
566- (VectorizationBase. static_sizeof (T), RSUF),
567- (StaticInt (0 ), StaticInt (0 ))
568- )
569- stridedpointer (ptr, si, StaticInt {0} ()), nothing
570- end
571- @inline function lubuffer (
572- :: Val{T} ,
573- :: StaticInt{UF} ,
574- :: StaticInt{N}
575- ) where {T,UF,N}
576- RSUF = StaticInt {UF} () * VectorizationBase. pick_vector_width (T)
577- L = RSUF * N
578- buf = Ref {NTuple{L,T}} ()
579- ptr = Base. unsafe_convert (Ptr{T}, buf)
580- si = StrideIndex {2,(1, 2),1} (
581- (
582- VectorizationBase. static_sizeof (T),
583- RSUF * VectorizationBase. static_sizeof (T)
584- ),
585- (StaticInt (0 ), StaticInt (0 ))
586- )
587- stridedpointer (ptr, si, StaticInt {0} ()), buf
588- end
589- @inline _free (p:: Ptr ) = Libc. free (p)
590535_canonicalize (x) = signed (x)
591536_canonicalize (:: StaticInt{N} ) where {N} = StaticInt {N} ()
592537function div_dispatch! (
@@ -606,17 +551,14 @@ function div_dispatch!(
606551 spa = zero_offsets (_spa)
607552 spc = zero_offsets (_spc)
608553 spu = zero_offsets (_spu)
609- XC = VectorizationBase. contiguous_axis (C)
610- XA = VectorizationBase. contiguous_axis (A)
611554 GC. @preserve spap spcp spup begin
612555 mtb = m_thread_block_size (M, N, nthread, Val (T))
613556 if nthread > 1
614- (M > mtb) &&
615- return multithread_rdiv! (spc, spa, spu, M, N, mtb, Val (UNIT), XC, XA)
557+ (M > mtb) && return multithread_rdiv! (spc, spa, spu, M, N, mtb, Val (UNIT))
616558 elseif N > block_size (Val (T))
617- return rdiv_block_MandN! (spc, spa, spu, M, N, Val (UNIT), XC, XA )
559+ return rdiv_block_MandN! (spc, spa, spu, M, N, Val (UNIT))
618560 end
619- return rdiv_U! (spc, spa, spu, M, N, XC, XA, Val (UNIT))
561+ return rdiv_U! (spc, spa, spu, M, N, Val (UNIT))
620562 end
621563end
622564
@@ -833,10 +775,8 @@ function rdiv_block_N!(
833775 M,
834776 N,
835777 :: Val{UNIT} ,
836- :: StaticInt{XC} ,
837- :: StaticInt{XA} ,
838778 Bsize = nothing
839- ) where {T,UNIT,XC,XA }
779+ ) where {T,UNIT}
840780 spa_rdiv = spa
841781 spc_base = spc
842782 n = 0
@@ -857,8 +797,6 @@ function rdiv_block_N!(
857797 gesp (spu, (n, StaticInt {0} ())),
858798 M,
859799 N_temp,
860- StaticInt {XC} (),
861- StaticInt {XA} (),
862800 Val {UNIT} ()
863801 )
864802 repeat || break
@@ -873,18 +811,16 @@ function rdiv_block_N!(
873811 end
874812end
875813function rdiv_block_MandN! (
876- spc:: AbstractStridedPointer{T} ,
877- spa,
878- spu,
814+ spc:: AbstractStridedPointer{T,<:Any,XC } ,
815+ spa:: AbstractStridedPointer{T,<:Any,XA} ,
816+ spu:: AbstractStridedPointer{T,<:Any,XU} ,
879817 M,
880818 N,
881- :: Val{UNIT} ,
882- :: StaticInt{XC} ,
883- :: StaticInt{XA}
884- ) where {T,UNIT,XC,XA}
819+ :: Val{UNIT}
820+ ) where {T,UNIT,XC,XA,XU}
885821 B = block_size (Val (T))
886822 W = VectorizationBase. pick_vector_width (T)
887- WUF = XC == XA == 2 ? W : W * unroll_factor (W)
823+ WUF = XC == XA == XA == 2 ? W : W * unroll_factor (W)
888824 B_m = VectorizationBase. vcld (M, VectorizationBase. vcld (M, B) * WUF) * WUF
889825 m = 0
890826 while m < M
@@ -897,8 +833,6 @@ function rdiv_block_MandN!(
897833 Mtemp,
898834 N,
899835 Val {UNIT} (),
900- StaticInt {XC} (),
901- StaticInt {XA} (),
902836 VectorizationBase. vcld (N, VectorizationBase. vcld (N, B) * W) * W
903837 )
904838 spa = gesp (spa, (B_m, StaticInt {0} ()))
@@ -913,12 +847,12 @@ function m_thread_block_size(M, N, nthreads, ::Val{T}) where {T}
913847 min (M, VectorizationBase. vcld (M, nb * W) * W)
914848end
915849
916- struct RDivBlockMandNv2{UNIT,XC,XA } end
917- function (f:: RDivBlockMandNv2{UNIT,XC,XA } )(
850+ struct RDivBlockMandNv2{UNIT} end
851+ function (f:: RDivBlockMandNv2{UNIT} )(
918852 allargs,
919853 blockstart,
920854 blockstop
921- ) where {UNIT,XC,XA }
855+ ) where {UNIT}
922856 spc, spa, spu, N, Mrem, Nblock, mtb = allargs
923857 for block = blockstart- 1 : blockstop- 1
924858 rdiv_block_MandN! (
@@ -927,9 +861,7 @@ function (f::RDivBlockMandNv2{UNIT,XC,XA})(
927861 spu,
928862 Core. ifelse (block == Nblock - 1 , Mrem, mtb),
929863 N,
930- Val {UNIT} (),
931- static (XC),
932- static (XA)
864+ Val {UNIT} ()
933865 )
934866 end
935867end
@@ -941,17 +873,14 @@ function multithread_rdiv!(
941873 M:: Int ,
942874 N:: Int ,
943875 mtb:: Int ,
944- :: Val{UNIT} ,
945- :: StaticInt{XC} ,
946- :: StaticInt{XA}
947- ) where {XC,XA,UNIT,TC,TA,TU}
876+ :: Val{UNIT}
877+ ) where {UNIT,TC,TA,TU}
948878 # Main._a[] = (spc, spa, spu, M, N, mtb, Val(UNIT), static(X));
949879 (Md, Mr) = VectorizationBase. vdivrem (M, mtb)
950880 Nblock = Md + (Mr ≠ 0 )
951881 Mrem = Core. ifelse (Mr ≠ 0 , Mr, mtb)
952- f = RDivBlockMandNv2 {UNIT,XC,XA} ()
953882 batch (
954- f ,
883+ RDivBlockMandNv2 {UNIT} () ,
955884 (Nblock, min (Nblock, Threads. nthreads ())),
956885 spc,
957886 spa,
@@ -977,60 +906,6 @@ function unroll_factor(::StaticInt{W}) where {W}
977906 ifelse (Static. lt (num_blocks, StaticInt {1} ()), StaticInt {1} (), num_blocks)
978907end
979908
980- function rdiv_U! (
981- spc:: AbstractStridedPointer{T} ,
982- spa:: AbstractStridedPointer ,
983- spu:: AbstractStridedPointer ,
984- M,
985- N,
986- :: StaticInt{var"#UNUSED1#"} ,
987- :: StaticInt{var"#UNUSED2#"} ,
988- :: Val{UNIT}
989- ) where {T,UNIT,var"#UNUSED1#" ,var"#UNUSED2#" }
990- WS = pick_vector_width (T)
991- W = Int (WS)
992- UF = unroll_factor (WS)
993- WU = UF * WS
994- Nd, Nr = VectorizationBase. vdivrem (N, WS)
995- spb, preserve = lubuffer (Val (T), UF, N)
996- m = 0
997- GC. @preserve preserve begin
998- if UF > 1
999- while m < M - WU + 1
1000- n = Nr
1001- if n > 0
1002- BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT), WS)
1003- end
1004- for _ ∈ 1 : Nd
1005- rdiv_solve_W_u! (spb, spc, spa, spu, n, WS, UF, Val (UNIT))
1006- n += W
1007- end
1008- m += WU
1009- spa = gesp (spa, (WU, StaticInt (0 )))
1010- spc = gesp (spc, (WU, StaticInt (0 )))
1011- end
1012- end
1013- finalmask = VectorizationBase. mask (WS, M)
1014- while m < M
1015- ubm = m + W
1016- nomaskiter = ubm < M
1017- mask = nomaskiter ? VectorizationBase. max_mask (WS) : finalmask
1018- n = Nr
1019- if n > 0
1020- BdivU_small_kern! (spb, spc, spa, spu, n, mask, Val (UNIT))
1021- end
1022- for i ∈ 1 : Nd
1023- rdiv_solve_W! (spb, spc, spa, spu, n, i ≠ Nd, mask, Val (UNIT))
1024- n += W
1025- end
1026- spa = gesp (spa, (WS, StaticInt (0 )))
1027- spc = gesp (spc, (WS, StaticInt (0 )))
1028- m = ubm
1029- end
1030- end
1031- nothing
1032- end
1033-
1034909@generated function _ldiv_remainder! (
1035910 spc,
1036911 spa,
@@ -1109,34 +984,50 @@ end
1109984) where {W,UNIT}
1110985 WS = static (W)
1111986 # US = static(U)
1112- quote
1113- # $(Expr(:meta, :inline))
1114- Base. Cartesian. @nif $ (W - 1 ) w -> m == M - w w -> _ldiv_remainder! (
1115- spc,
1116- spa,
1117- spu,
1118- M,
1119- N,
1120- m,
1121- Nr,
1122- $ WS,
1123- $ (Val (UNIT)),
1124- StaticInt (w)
1125- )
987+ if W == 2
988+ quote
989+ $ (Expr (:meta , :inline ))
990+ _ldiv_remainder! (
991+ spc,
992+ spa,
993+ spu,
994+ M,
995+ N,
996+ m,
997+ Nr,
998+ $ WS,
999+ $ (Val (UNIT)),
1000+ $ (static (1 ))
1001+ )
1002+ end
1003+ else
1004+ quote
1005+ # $(Expr(:meta, :inline))
1006+ Base. Cartesian. @nif $ (W - 1 ) w -> m == M - w w -> _ldiv_remainder! (
1007+ spc,
1008+ spa,
1009+ spu,
1010+ M,
1011+ N,
1012+ m,
1013+ Nr,
1014+ $ WS,
1015+ $ (Val (UNIT)),
1016+ StaticInt (w)
1017+ )
1018+ end
11261019 end
11271020end
11281021
11291022# spc = spa / spu
11301023# spc' = (spu' \ spa')'
11311024# This is ldiv
11321025function rdiv_U! (
1133- spc:: AbstractStridedPointer{T} ,
1134- spa:: AbstractStridedPointer ,
1135- spu:: AbstractStridedPointer ,
1026+ spc:: AbstractStridedPointer{T,2,2 } ,
1027+ spa:: AbstractStridedPointer{T,2,2} ,
1028+ spu:: AbstractStridedPointer{T,2,2} ,
11361029 M,
11371030 N,
1138- :: StaticInt{2} ,
1139- :: StaticInt{2} ,
11401031 :: Val{UNIT}
11411032) where {T,UNIT}
11421033 WS = pick_vector_width (T)
0 commit comments