@@ -31,89 +31,13 @@ using Polyester
3131 end
3232end
3333
34- # @generated function nmuladd(A::VecUnroll{Nm1},B::AbstractStridedPointer,C::VecUnroll{Nm1}) where {Nm1}
35- # N = Nm1 + 1
36- # quote
37- # $(Expr(:meta,:inline))
38- # Ad = VectorizationBase.data(A);
39- # Cd = VectorizationBase.data(C);
40- # bp = stridedpointer(B)
41- # Base.Cartesian.@nexprs $N n -> C_n = Cd[n]
42- # Base.Cartesian.@nexprs $N k -> begin
43- # A_k = Ad[k]
44- # Base.Cartesian.@nexprs $N n -> begin
45- # C_n = Base.FastMath.sub_fast(C_n, Base.FastMath.mul_fast(A_k, vload(B, (k-1,n-1))))
46- # end
47- # end
48- # VecUnroll(Base.Cartesian.@ntuple $N C)
49- # end
50- # end
51-
52- # @inline function solve_Wx3W(A11::V, A12::V, A13::V, U::AbstractMatrix, ::StaticInt{W}) where {V<:VecUnroll,W}
53- # WS = StaticInt{W}()
54-
55- # U11 = view(U,StaticInt(1):WS,StaticInt(1):WS)
56- # A11 = solve_AU(A11, U11)
57-
58- # U12 = view(U,StaticInt(1):WS, StaticInt(1)+WS:WS*StaticInt(2))
59- # A12 = nmuladd(A11, U12, A12)
60- # U22 = view(U,StaticInt(1)+WS:WS*StaticInt(2),StaticInt(1)+WS:WS*StaticInt(2))
61- # A12 = solve_AU(A12, U22)
62-
63- # U13 = view(U,StaticInt(1):WS, StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
64- # A13 = nmuladd(A11, U13, A13)
65- # U23 = view(U,StaticInt(1)+WS:WS*StaticInt(2),StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
66- # A13 = nmuladd(A12, U23, A13)
67- # U33 = view(U,StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3),StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
68- # A13 = solve_AU(A13, U33)
69-
70- # return A11, A12, A13
71- # end
72-
73- # @inline function solve_Wx3W!(ap::AbstractStridedPointer{T}, bp::AbstractStridedPointer{T}, U, rowoffset, coloffset) where {T}
74- # WS = VectorizationBase.pick_vector_width(T)
75- # W = Int(WS)
76- # A11 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset)))
77- # A12 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS)))
78- # A13 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS+WS)))
79-
80- # A11, A12, A13 = solve_Wx3W(A11, A12, A13, U, WS)
81-
82- # vstore!(ap, A11, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset)))
83- # vstore!(ap, A12, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS)))
84- # vstore!(ap, A13, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS+WS)))
85- # end
86- # @inline function solve_Wx3W!(ap::AbstractStridedPointer{T}, bp::AbstractStridedPointer{T}, U, rowoffset, coloffset, m::VectorizationBase.AbstractMask) where {T}
87- # WS = VectorizationBase.pick_vector_width(T)
88- # W = Int(WS)
89- # A11 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset)), m)
90- # A12 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS)), m)
91- # A13 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS+WS)), m)
92-
93- # A11, A12, A13 = solve_Wx3W(A11, A12, A13, U, WS)
94-
95- # vstore!(ap, A11, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset)), m)
96- # vstore!(ap, A12, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS)), m)
97- # vstore!(ap, A13, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS+WS)), m)
98- # end
99-
100- # solve_3Wx3W!(A,B,U::UpperTriangular) = solve_3Wx3W!(A,B,parent(U))
101- # function solve_3Wx3W!(A::AbstractMatrix{T},B,U) where {T}
102- # W = VectorizationBase.pick_vector_width(T)
103- # ap = stridedpointer(A);
104- # bp = stridedpointer(B);
105- # solve_Wx3W!(ap, bp, U, StaticInt(1), StaticInt(1))
106- # solve_Wx3W!(ap, bp, U, StaticInt(1) + W, StaticInt(1))
107- # solve_Wx3W!(ap, bp, U, StaticInt(1) + W + W, StaticInt(1))
108- # end
109-
11034@inline maybestore! (p, v, i) = vstore! (p, v, i)
11135@inline maybestore! (:: Nothing , v, i) = nothing
11236
11337@inline maybestore! (p, v, i, m) = vstore! (p, v, i, m)
11438@inline maybestore! (:: Nothing , v, i, m) = nothing
11539
116- @inline function store_small_kern! (spa, sp, v, spu , i, n, mask, :: Val{true} )
40+ @inline function store_small_kern! (spa, sp, v, _ , i, n, mask, :: Val{true} )
11741 vstore! (spa, v, i, mask)
11842 vstore! (sp, v, i, mask)
11943end
16084 store_small_kern! (spa, sp, Amn, spu, Unroll {1,W,U,1,W,zero(UInt),1} ((StaticInt (0 ),n)), n, Val {UNIT} ())
16185 end
16286end
163- # function BdivU_small!(A::AbstractMatrix{T}, B::AbstractMatrix{T}, U::AbstractMatrix{T}) where {T}
164- # W = VectorizationBase.pick_vector_width(T)
165- # M, N = size(A)
166- # m = 0
167- # spa = stridedpointer(A)
168- # spb = stridedpointer(B)
169- # spu = stridedpointer(U)
170- # while m < M
171- # ml = m+1
172- # mu = m+W
173- # maskiter = mu > M
174- # mask = maskiter ? VectorizationBase.mask(W, M) : VectorizationBase.max_mask(W)
175- # for n ∈ 1:N
176- # Amn = vload(spb, (MM(W, ml),n), mask)
177- # for k ∈ 1:n-1
178- # Amn = vfnmadd_fast(vload(spa, (MM(W, ml),k), mask), vload(spu, (k,n)), Amn)
179- # end
180- # vstore!(spa, Amn / vload(spu, (n,n)), (MM(W, ml),n), mask)
181- # end
182- # m = mu
183- # end
184- # # @inbounds @fastmath for m ∈ 1:M
185- # # for n ∈ 1:N
186- # # Amn = B[m,n]
187- # # for k ∈ 1:n-1
188- # # Amn -= A[m,k]*U[k,n]
189- # # end
190- # # A[m,n] = Amn / U[n,n]
191- # # end
192- # # end
193- # end
194- # function nmuladd!(C,A,B,D)
195- # @turbo for n ∈ axes(C,2), m ∈ axes(C,1)
196- # Cmn = D[m,n]
197- # for k ∈ axes(B,1)
198- # Cmn -= A[m,k]*B[k,n]
199- # end
200- # C[m,n] = Cmn
201- # end
202- # end
20387
20488@generated function rdiv_solve_W_u! (spc, spb, spa, spu, n, :: StaticInt{W} , :: StaticInt{U} , :: Val{UNIT} ) where {W, U, UNIT}
20589 quote
@@ -286,7 +170,7 @@ const LDIVBUFFERS = Vector{UInt8}[]
286170 buff = LDIVBUFFERS[Threads. threadid ()]
287171 RSUF = StaticInt {UF} ()* VectorizationBase. register_size ()
288172 L = RSUF* N
289- L > length (buff) && resize! (buff, L)
173+ L > length (buff) && resize! (buff, L% UInt )
290174 ptr = Base. unsafe_convert (Ptr{T}, buff)
291175 si = StrideIndex {2,(1,2),1} ((VectorizationBase. static_sizeof (T), RSUF), (StaticInt (0 ),StaticInt (0 )))
292176 stridedpointer (ptr, si, StaticInt {0} ())
@@ -412,24 +296,14 @@ function rdiv_block_N!(
412296 N_temp = Core. ifelse (repeat, B_normalized, N)
413297 while true
414298 # println("Solve with N_temp = $N_temp and n = $n")
415- rdiv_U! (spc, spa_rdiv, gesp (spu, (n,StaticInt {0} ())), M, N_temp, StaticInt {X} (), Val ( UNIT))
299+ rdiv_U! (spc, spa_rdiv, gesp (spu, (n,StaticInt {0} ())), M, N_temp, StaticInt {X} (), Val { UNIT} ( ))
416300 repeat || break
417301 spa = gesp (spa, (StaticInt (0 ), B_normalized))
418302 spc = gesp (spc, (StaticInt (0 ), B_normalized))
419303 spu = gesp (spu, (StaticInt (0 ), B_normalized))
420- nnext = n + B_normalized
421- # N_temp =
422304 n += B_normalized
423305 repeat = n + B_normalized < N
424306 N_temp = repeat ? N_temp : N - n
425- # N_temp = min(n + B_normalized, N) - n
426- # println("nmuladd with N_temp = $N_temp and n = $n")
427- # mul!(
428- # copyto!(view(C, :, n+1:n+N_temp), view(A, :, n+1:n+N_temp)),
429- # view(C, :, 1:n),
430- # view(U, 1:n, n+1:n+N_temp),
431- # -1.0, 1.0
432- # )
433307 nmuladd! (spc_base, spa, spu, M, n, N_temp)
434308 spa_rdiv = spc
435309 end
@@ -439,15 +313,14 @@ function rdiv_block_MandN!(
439313) where {T,UNIT,X}
440314 B = block_size (Val (T))
441315 W = VectorizationBase. pick_vector_width (T)
442- B_normalized = VectorizationBase. vcld (N, VectorizationBase. vcld (N, B)* W)* W
443316 WUF = W* unroll_factor (W)
444317 B_m = VectorizationBase. vcld (M, VectorizationBase. vcld (M, B)* WUF)* WUF
445318 m = 0
446319 while m < M
447320 mu = m + B_m
448321 Mtemp = min (M, mu) - m
449322 rdiv_block_N! (
450- spc, spa, spu, Mtemp, N, Val ( UNIT), StaticInt {X} (),
323+ spc, spa, spu, Mtemp, N, Val { UNIT} ( ), StaticInt {X} (),
451324 VectorizationBase. vcld (N, VectorizationBase. vcld (N, B)* W)* W
452325 )
453326 spa = gesp (spa, (B_m, StaticInt {0} ()))
@@ -458,42 +331,33 @@ function rdiv_block_MandN!(
458331end
459332function m_thread_block_size (M, N, nthreads, :: Val{T} ) where {T}
460333 W = VectorizationBase. pick_vector_width (T)
461- WUF = W * unroll_factor (W)
462334 nb = clamp (VectorizationBase. vdiv (M * N, StaticInt {256} () * W), 1 , nthreads)
463335 min (M, VectorizationBase. vcld (M, nb* W)* W)
464336end
465337
338+ struct RDivBlockMandNv2{UNIT,X} end
339+ function (f:: RDivBlockMandNv2{UNIT,X} )(allargs, blockstart, blockstop) where {UNIT,X}
340+ spc, spa, spu, N, Mrem, Nblock, mtb = allargs
341+ for block = blockstart- 1 : blockstop- 1
342+ rdiv_block_MandN! (
343+ gesp (spc, (mtb* block, StaticInt {0} ())),
344+ gesp (spa, (mtb* block, StaticInt {0} ())),
345+ spu, Core. ifelse (block == Nblock- 1 , Mrem, mtb), N, Val {UNIT} (), static (X)
346+ )
347+ end
348+ end
349+
350+
466351function multithread_rdiv! (
467- spc:: AbstractStridedPointer{T } , spa, spu, M, N, mtb, :: Val{UNIT} , :: StaticInt{X}
468- ) where {X,T, UNIT}
469- mtb = 8
352+ spc:: AbstractStridedPointer{TC } , spa:: AbstractStridedPointer{TA} , spu:: AbstractStridedPointer{TU} , M:: Int , N:: Int , mtb:: Int , :: Val{UNIT} , :: StaticInt{X}
353+ ) where {X,UNIT,TC,TA,TU }
354+ # Main._a[] = (spc, spa, spu, M, N, mtb, Val(UNIT), static(X));
470355 (Md, Mr) = VectorizationBase. vdivrem (M, mtb)
471356 Nblock = Md + (Mr ≠ 0 )
472357 Mrem = Core. ifelse (Mr ≠ 0 , Mr, mtb)
473- # @show mtb, Nblock, Mrem, Md, Mr
474- # return
475- let Md = Md, Mr = Mr, Nblock = Md + (Mr ≠ 0 ), Mrem = Core. ifelse (Mr ≠ 0 , Mr, mtb), VUNIT = Val {UNIT} (), StaticX = StaticInt {X} ()
476- @batch for block in CloseOpen (Nblock)
477- # for block in CloseOpen(Nblock)
478- # let block = 0
479- rdiv_block_MandN! (
480- # rdiv_block_N!(
481- gesp (spc, (mtb* block, StaticInt {0} ())),
482- gesp (spa, (mtb* block, StaticInt {0} ())),
483- spu, Core. ifelse (block == Nblock- 1 , Mrem, mtb), N, VUNIT, StaticX
484- # spu, M, N, Val{UNIT}(), StaticInt{X}()
485- )
486- end
487- end
358+ f = RDivBlockMandNv2 {UNIT,X} ()
359+ batch (f, (Nblock,min (Nblock,Threads. nthreads ())), spc, spa, spu, N, Mrem, Nblock, mtb)
488360 nothing
489- # nlaunch = Md - (Mr == 0)
490- # threads, torelease = Polyester.request_threads(Base.Threads.threadid(), nlaunch)
491- # nthread = length(threads)
492- # if (nthread % Int32) ≤ zero(Int32)
493- # return rdiv_block_MandN!(spc, spa, spu, M, N, Val(UNIT), StaticInt{X}())
494- # end
495- # nbatch = nthread + one(nthread)
496-
497361end
498362
499363# We're using `W x W` blocks, consuming `W` registers
@@ -521,7 +385,7 @@ function rdiv_U!(spc::AbstractStridedPointer{T}, spa::AbstractStridedPointer, sp
521385 if n > 0
522386 BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT))
523387 end
524- for i ∈ 1 : Nd
388+ for _ ∈ 1 : Nd
525389 rdiv_solve_W_u! (spb, spc, spa, spu, n, WS, UF, Val (UNIT))
526390 n += W
527391 end
0 commit comments