11module TriangularSolve
2- using Base: @nexprs
2+ using Base: @nexprs , @ntuple
33if isdefined (Base, :Experimental ) &&
44 isdefined (Base. Experimental, Symbol (" @max_methods" ))
55 @eval Base. Experimental. @max_methods 1
5353@inline maybestore! (p, v, i, m) = vstore! (p, v, i, m)
5454@inline maybestore! (:: Nothing , v, i, m) = nothing
5555
56- @inline function store_small_kern! (spa, sp, v, _, i, n, mask, :: Val{true} )
56+ @inline function store_small_kern! (spa, sp, v, i, mask)
5757 vstore! (spa, v, i, mask)
5858 vstore! (sp, v, i, mask)
5959end
60- @inline store_small_kern! (spa, :: Nothing , v, spu, i, n, mask, :: Val{true} ) =
61- vstore! (spa, v, i, mask)
62-
63- @inline function store_small_kern! (spa, sp, v, spu, i, n, mask, :: Val{false} )
64- x = v / vload (spu, (n, n))
65- vstore! (spa, x, i, mask)
66- vstore! (sp, x, i, mask)
67- end
68- @inline store_small_kern! (spa, :: Nothing , v, spu, i, n, mask, :: Val{false} ) =
69- vstore! (spa, v / vload (spu, (n, n)), i, mask)
60+ @inline store_small_kern! (spa, :: Nothing , v, i, mask) = vstore! (spa, v, i, mask)
7061
71- @inline function store_small_kern! (spa, sp, v, spu, i, n, :: Val{true} )
62+ @inline function store_small_kern! (spa, sp, v, i )
7263 vstore! (spa, v, i)
7364 vstore! (sp, v, i)
7465end
75- @inline store_small_kern! (spa, :: Nothing , v, spu, i, n, :: Val{true} ) =
76- vstore! (spa, v, i)
66+ @inline store_small_kern! (spa, :: Nothing , v, i) = vstore! (spa, v, i)
7767
78- @inline function store_small_kern! (spa, sp, v, spu, i, n, :: Val{false} )
79- x = v / vload (spu, (n, n))
80- vstore! (spa, x, i)
81- vstore! (sp, x, i)
82- end
83- @inline store_small_kern! (spa, :: Nothing , v, spu, i, n, :: Val{false} ) =
84- vstore! (spa, v / vload (spu, (n, n)), i)
85-
86- @inline function BdivU_small_kern! (
68+ @generated function BdivU_small_kern! (
8769 spa:: AbstractStridedPointer{T} ,
8870 sp,
8971 spb:: AbstractStridedPointer{T} ,
9072 spu:: AbstractStridedPointer{T} ,
91- N ,
73+ :: StaticInt{N} ,
9274 mask:: AbstractMask{W} ,
9375 :: Val{UNIT}
94- ) where {T,UNIT,W}
95- # W = VectorizationBase.pick_vector_width(T)
96- for n ∈ CloseOpen (N)
97- Amn = vload (spb, (MM {W} (StaticInt (0 )), n), mask)
98- for k ∈ SafeCloseOpen (n)
99- Amn = vfnmadd_fast (
100- vload (spa, (MM {W} (StaticInt (0 )), k), mask),
101- vload (spu, (k, n)),
102- Amn
103- )
76+ ) where {T,UNIT,W,N}
77+ z = static (0 )
78+ if N == 1
79+ i = (MM {W} (z), z)
80+ Amn = :(vload (spb, $ i, mask))
81+ if ! UNIT
82+ Amn = :($ Amn / vload (spu, $ ((z, z))))
83+ end
84+ quote
85+ $ (Expr (:meta , :inline ))
86+ store_small_kern! (spa, sp, $ Amn, $ i, mask)
87+ end
88+ else
89+ unroll = Unroll {2,1,N,1,W,(-1 % UInt),1} ((z, z))
90+ tostore = :(VecUnroll (Base. Cartesian. @ntuple $ N Amn))
91+ scale = UNIT ? nothing : :(Amn_n /= vload (spu, (n - 1 , n - 1 )))
92+ quote
93+ $ (Expr (:meta , :inline ))
94+ Amn = getfield (vload (spb, $ unroll, mask), :data )
95+ Base. Cartesian. @nexprs $ N n -> begin
96+ Amn_n = getfield (Amn, n)
97+ Base. Cartesian. @nexprs (n - 1 ) k -> begin
98+ Amn_n = vfnmadd_fast (Amn_k, vload (spu, (k - 1 , n - 1 )), Amn_n)
99+ end
100+ $ scale
101+ end
102+ store_small_kern! (spa, sp, $ tostore, $ unroll, mask)
104103 end
105- store_small_kern! (
106- spa,
107- sp,
108- Amn,
109- spu,
110- (MM {W} (StaticInt (0 )), n),
111- n,
112- mask,
113- Val {UNIT} ()
114- )
115104 end
116105end
117- @inline function BdivU_small_kern_u! (
106+ @generated function BdivU_small_kern_u! (
118107 spa:: AbstractStridedPointer{T} ,
119108 sp,
120109 spb:: AbstractStridedPointer{T} ,
121110 spu:: AbstractStridedPointer{T} ,
122- N ,
111+ :: StaticInt{N} ,
123112 :: StaticInt{U} ,
124- :: Val{UNIT}
125- ) where {T,U,UNIT}
126- W = Int (VectorizationBase. pick_vector_width (T))
127- for n ∈ CloseOpen (N)
128- Amn = vload (spb, Unroll {1,W,U,1,W,zero(UInt),1} ((StaticInt (0 ), n)))
129- for k ∈ SafeCloseOpen (n)
130- Amk = vload (spa, Unroll {1,W,U,1,W,zero(UInt),1} ((StaticInt (0 ), k)))
131- Amn = vfnmadd_fast (Amk, vload (spu, (k, n)), Amn)
113+ :: Val{UNIT} ,
114+ :: StaticInt{W}
115+ ) where {T,U,UNIT,N,W}
116+ z = static (0 )
117+ if N == 1
118+ unroll = Unroll {1,W,U,1,W,zero(UInt),1} ((z, z))
119+ Amn = :(vload (spb, $ unroll))
120+ if ! UNIT
121+ Amn = :($ Amn / vload (spu, $ ((z, z))))
132122 end
133- store_small_kern! (
134- spa,
135- sp,
136- Amn,
137- spu,
138- Unroll {1,W,U,1,W,zero(UInt),1} ((StaticInt (0 ), n)),
139- n,
140- Val {UNIT} ()
141- )
123+ quote
124+ $ (Expr (:meta , :inline ))
125+ store_small_kern! (spa, sp, $ Amn, $ unroll)
126+ end
127+ else
128+ double_unroll =
129+ Unroll {2,1,N,1,W,zero(UInt),1} (Unroll {1,W,U,1,W,zero(UInt),1} ((z, z)))
130+ tostore = :(VecUnroll (Base. Cartesian. @ntuple $ N Amn))
131+ scale = UNIT ? nothing : :(Amn_n /= vload (spu, (n - 1 , n - 1 )))
132+ quote
133+ $ (Expr (:meta , :inline ))
134+ Amn = getfield (vload (spb, $ double_unroll), :data )
135+ Base. Cartesian. @nexprs $ N n -> begin
136+ Amn_n = getfield (Amn, n)
137+ Base. Cartesian. @nexprs (n - 1 ) k -> begin
138+ Amn_n = vfnmadd_fast (Amn_k, vload (spu, (k - 1 , n - 1 )), Amn_n)
139+ end
140+ $ scale
141+ end
142+ store_small_kern! (spa, sp, $ tostore, $ double_unroll)
143+ end
144+ end
145+ end
146+ @generated function BdivU_small_kern! (
147+ spa:: AbstractStridedPointer{T} ,
148+ sp,
149+ spb:: AbstractStridedPointer{T} ,
150+ spu:: AbstractStridedPointer{T} ,
151+ Nr:: Int ,
152+ mask:: AbstractMask{W} ,
153+ :: Val{UNIT}
154+ ) where {T,UNIT,W}
155+ quote
156+ # $(Expr(:meta, :inline))
157+ Base. Cartesian. @nif $ (W - 1 ) n -> n == Nr n ->
158+ BdivU_small_kern! (spa, sp, spb, spu, static (n), mask, $ (Val (UNIT)))
159+ end
160+ end
161+ @generated function BdivU_small_kern_u! (
162+ spa:: AbstractStridedPointer{T} ,
163+ sp,
164+ spb:: AbstractStridedPointer{T} ,
165+ spu:: AbstractStridedPointer{T} ,
166+ Nr:: Int ,
167+ :: StaticInt{U} ,
168+ :: Val{UNIT} ,
169+ :: StaticInt{W}
170+ ) where {T,U,UNIT,W}
171+ su = static (U)
172+ vu = Val (UNIT)
173+ sw = static (W)
174+ quote
175+ # $(Expr(:meta, :inline))
176+ Base. Cartesian. @nif $ (W - 1 ) n -> n == Nr n ->
177+ BdivU_small_kern_u! (spa, sp, spb, spu, static (n), $ su, $ vu, $ sw)
142178 end
143179end
144180
232268) where {W,U,UNIT}
233269 z = static (0 )
234270 quote
235- $ (Expr (:meta , :inline ))
271+ # $(Expr(:meta, :inline))
236272 # C = L \ A; L * C = A
237273 # A_{i,j} = L_{i,i}*C_{i,j} + \sum_{k=1}^{i-1}L_{i,k}C_{k,j}
238274 # C_{i,j} = L_{i,i} \ (A_{i,j} - \sum_{k=1}^{i-1}L_{i,k}C_{k,j})
328364) where {W,UNIT}
329365 z = static (0 )
330366 quote
331- $ (Expr (:meta , :inline ))
367+ # $(Expr(:meta, :inline))
332368 # Like `ldiv_solve_W_u!`, except no unrolling, just a `W`x`W` block
333369 #
334370 # C = L \ A; L * C = A
382418 R <= 1 && throw (" Remainder of `<= 1` shouldn't be called, but had $R ." )
383419 R >= W && throw (" Reaminderof `>= $W ` shouldn't be called, but had $R ." )
384420 z = static (0 )
385- WS = static (W)
386421 q = quote
387- $ (Expr (:meta , :inline ))
422+ # $(Expr(:meta, :inline))
388423 # Like `ldiv_solve_W_u!`, except no unrolling, just a `W`x`W` block
389424 #
390425 # C = L \ A; L * C = A
447482 push! (q. args, q3)
448483 return q
449484end
485+
450486@inline function rdiv_U! (
451487 spc:: AbstractStridedPointer{T} ,
452488 spa:: AbstractStridedPointer ,
467503 while m < M - WU + 1
468504 n = Nr
469505 if n > 0
470- BdivU_small_kern_u! (spc, nothing , spa, spu, n, UF, Val (UNIT))
506+ BdivU_small_kern_u! (spc, nothing , spa, spu, n, UF, Val (UNIT), WS )
471507 end
472508 for _ ∈ 1 : Nd
473509 rdiv_solve_W_u! (spc, nothing , spa, spu, n, WS, UF, Val (UNIT))
@@ -963,7 +999,7 @@ function rdiv_U!(
963999 while m < M - WU + 1
9641000 n = Nr
9651001 if n > 0
966- BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT))
1002+ BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT), WS )
9671003 end
9681004 for _ ∈ 1 : Nd
9691005 rdiv_solve_W_u! (spb, spc, spa, spu, n, WS, UF, Val (UNIT))
@@ -995,7 +1031,7 @@ function rdiv_U!(
9951031 nothing
9961032end
9971033
998- @generated function ldiv_remainder ! (
1034+ @generated function _ldiv_remainder ! (
9991035 spc,
10001036 spa,
10011037 spu,
@@ -1011,16 +1047,20 @@ end
10111047 r >= W && throw (" Reaminderof `>= $W ` shouldn't be called, but had $r ." )
10121048 if r == 1
10131049 z = static (0 )
1050+ sub = Base. FastMath. sub_fast
1051+ mul = Base. FastMath. mul_fast
1052+ div = Base. FastMath. div_fast
10141053 vlxj = :(vload (spc, ($ z, j)))
10151054 if UNIT
10161055 vlxj = :(xj = $ vlxj)
10171056 else
10181057 vlxj = quote
1019- xj = $ vlxj / vload (spu, (j, j))
1058+ xj = $ div ( $ vlxj, vload (spu, (j, j) ))
10201059 vstore! (spc, xj, ($ z, j))
10211060 end
10221061 end
10231062 quote
1063+ $ (Expr (:meta , :inline ))
10241064 if pointer (spc) != pointer (spa)
10251065 for n = 0 : N- 1
10261066 vstore! (spc, vload (spa, ($ z, n)), ($ z, n))
@@ -1031,7 +1071,7 @@ end
10311071 for i = (j+ 1 ): N- 1
10321072 xi = vload (spc, ($ z, i))
10331073 Uji = vload (spu, (j, i))
1034- vstore! (spc, xi - xj * Uji, ($ z, i))
1074+ vstore! (spc, $ sub (xi, $ mul (xj, Uji)) , ($ z, i))
10351075 end
10361076 end
10371077 end
@@ -1070,8 +1110,8 @@ end
10701110 WS = static (W)
10711111 # US = static(U)
10721112 quote
1073- $ (Expr (:meta , :inline ))
1074- Base. Cartesian. @nif $ W w -> m == M - w w -> ldiv_remainder ! (
1113+ # $(Expr(:meta, :inline))
1114+ Base. Cartesian. @nif $ (W - 1 ) w -> m == M - w w -> _ldiv_remainder ! (
10751115 spc,
10761116 spa,
10771117 spu,
@@ -1111,7 +1151,16 @@ function rdiv_U!(
11111151 while m < M - WS + 1
11121152 n = Nr # non factor of W remainder
11131153 if n > 0
1114- BdivU_small_kern_u! (spc, nothing , spa, spu, n, StaticInt (1 ), Val (UNIT))
1154+ BdivU_small_kern_u! (
1155+ spc,
1156+ nothing ,
1157+ spa,
1158+ spu,
1159+ n,
1160+ StaticInt (1 ),
1161+ Val (UNIT),
1162+ WS
1163+ )
11151164 end
11161165 while n < N - (WU - 1 )
11171166 ldiv_solve_W_u! (spc, spa, spu, n, WS, UF, Val (UNIT))
0 commit comments