Skip to content

Commit 79552e3

Browse files
committed
DArray: Add matrix-vector multiply
1 parent cfc9bad commit 79552e3

File tree

3 files changed

+176
-6
lines changed

3 files changed

+176
-6
lines changed

docs/src/darray.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -693,12 +693,13 @@ From `Statistics`:
693693
- `std`
694694
695695
From `LinearAlgebra`:
696+
- `norm`
696697
- `transpose`/`adjoint` (Out-of-place transpose)
697698
- `*` (Out-of-place Matrix-(Matrix/Vector) multiply)
698-
- `mul!` (In-place Matrix-Matrix multiply)
699+
- `mul!` (In-place Matrix-Matrix and Matrix-Vector multiply)
699700
- `cholesky`/`cholesky!` (In-place/Out-of-place Cholesky factorization)
700701
- `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` only))
701702
702703
From `AbstractFFTs`:
703704
- `fft`/`fft!`
704-
- `ifft`/`ifft!`
705+
- `ifft`/`ifft!`

src/array/mul.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,3 +405,95 @@ end
405405
A[i, j] = C[i, j]
406406
end
407407
end
408+
# Matrix-vector multiply method of `LinearAlgebra.generic_matvecmul!` for
# distributed operands (`C::DVector`, `A::DMatrix`, `B::DVector`, same eltype).
#
# The block layouts wanted for the three operands come from
# `_repartition_matvecmul`. Each operand is paired with its layout and handed to
# `maybe_copy_buffered`, which invokes the kernel below on the operands it
# supplies (presumably re-blocked copies when the layout differs — confirm
# against `maybe_copy_buffered`). The kernel itself is `gemv_dagger!`.
function LinearAlgebra.generic_matvecmul!(
    C::DVector{T},
    transA::Char,
    A::DMatrix{T},
    B::DVector{T},
    _add::LinearAlgebra.MulAddMul,
) where {T}
    layoutC, layoutA, layoutB = _repartition_matvecmul(C, A, B, transA)
    kernel = (Cbuf, Abuf, Bbuf) -> gemv_dagger!(Cbuf, transA, Abuf, Bbuf, _add)
    return maybe_copy_buffered(kernel, C => layoutC, A => layoutA, B => layoutB)
end
420+
# Pick the block sizes that `C`, `A` and `B` should use for `C = op(A) * B`,
# where `op` is the identity for `transA == 'N'` and a (conjugate) transpose
# for `transA == 'T'` / `'C'`.
#
# Arguments:
# - `C`: output vector (unused here; kept for a signature parallel to callers)
# - `A`, `B`: operands; only `X.partitioning.blocksize` is read
# - `transA`: BLAS-style transpose flag for `A`
#
# Returns `(Blocks(partC...), Blocks(partA...), Blocks(partB...))`.
#
# The block size of `op(A)` along the contracted dimension has to equal the
# block size of `B`. When they differ, both are set to the smallest block size
# found among `A`'s two dimensions and `B`. `C` always takes the block size of
# `op(A)` along the output dimension.
function _repartition_matvecmul(C, A, B, transA::Char)
    partA = A.partitioning.blocksize
    partB = B.partitioning.blocksize
    istransA = transA == 'T' || transA == 'C'

    # Block size of op(A) along the output dimension and the contracted one
    dimA = !istransA ? partA[1] : partA[2]
    dimA_other = !istransA ? partA[2] : partA[1]
    dimB = partB[1]

    # If A and B rows/cols don't match, fix them
    # Uses the smallest blocking of all dimensions
    if dimA_other != dimB
        sz = minimum((partA[1], partA[2], partB[1]))
        partA = !istransA ? (partA[1], sz) : (sz, partA[2])
        # B must be re-blocked as well; otherwise only A changes and the
        # contracted block counts of A and B still disagree.
        partB = (sz,)
    end

    partC = (dimA,)
    return Blocks(partC...), Blocks(partA...), Blocks(partB...)
end
442+
# Blocked matrix-vector multiply `C = alpha * op(A) * B + beta * C` over the
# chunks of the distributed operands, where `alpha`/`beta` come from `_add` and
# `op` is selected by `transA` (`'N'`: identity, anything else: [conj]transpose,
# passed straight through to `BLAS.gemv!`).
#
# Arguments:
# - `C`: output vector, updated in place and returned
# - `transA`: BLAS-style transpose flag for `A`
# - `A`, `B`: matrix and vector operands
# - `_add`: `LinearAlgebra.MulAddMul` carrying `alpha` and `beta`
#
# Throws `DimensionMismatch` when the block grid of `op(A)` does not line up
# with the number of blocks of `B` (contracted dimension) or `C` (output
# dimension).
function gemv_dagger!(
    C::DVector{T},
    transA::Char,
    A::DMatrix{T},
    B::DVector{T},
    _add::LinearAlgebra.MulAddMul,
) where {T}
    Ac = A.chunks
    Bc = B.chunks
    Cc = C.chunks
    Amt, Ant = size(Ac)
    Bmt = size(Bc)[1]
    Cmt = size(Cc)[1]

    alpha = T(_add.alpha)
    beta = T(_add.beta)

    # Block grid of op(A): a transpose swaps which grid dimension is the output
    # one and which is contracted, so the checks must follow `transA`.
    notrans = transA == 'N'
    opAmt, opAnt = notrans ? (Amt, Ant) : (Ant, Amt)

    if opAnt != Bmt
        throw(DimensionMismatch(lazy"A has number of blocks ($Amt,$Ant) but B has number of blocks ($Bmt)"))
    end
    if opAmt != Cmt
        throw(DimensionMismatch(lazy"A has number of blocks ($Amt,$Ant) but C has number of blocks ($Cmt)"))
    end

    # NOTE(review): the In/InOut annotations are presumably what lets
    # `spawn_datadeps` order the accumulations into each `Cc[m]` — confirm
    # against the Dagger datadeps documentation.
    Dagger.spawn_datadeps() do
        for m in 1:Cmt
            for k in 1:opAnt
                # Apply `beta` to C's block only on the first partial product;
                # later products accumulate onto the running result.
                mzone = k == 1 ? beta : one(T)
                # A: NoTrans reads block (m, k); [Conj]Trans reads block (k, m)
                Ablock = notrans ? Ac[m, k] : Ac[k, m]
                Dagger.@spawn BLAS.gemv!(
                    transA,
                    alpha,
                    In(Ablock),
                    In(Bc[k]),
                    mzone,
                    InOut(Cc[m]),
                )
            end
        end
    end

    return C
end

test/array/linalg/matmul.jl

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,87 @@ part_sets_to_test = map(_sizes_to_test) do sz
136136
]
137137
end
138138
parts_to_test = vcat(part_sets_to_test...)
139-
@testset "Size=$szA*$szB" for (szA, szB) in sizes_to_test
140-
@testset "Partitioning=$partA*$partB" for (partA,partB) in parts_to_test
141-
@testset "T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
142-
test_gemm!(T, szA, szB, partA, partB)
139+
# Run `test_gemm!` for every (size pair, partitioning pair, element type)
# combination, one nested testset per axis so failures are reported per case.
@testset "GEMM" begin
    eltypes = (Float32, Float64, ComplexF32, ComplexF64)
    @testset "Size=$szA*$szB" for (szA, szB) in sizes_to_test
        @testset "Partitioning=$partA*$partB" for (partA, partB) in parts_to_test
            @testset "T=$T" for T in eltypes
                test_gemm!(T, szA, szB, partA, partB)
            end
        end
    end
end
148+
149+
# Check distributed matrix-vector multiply against the dense result.
#
# Arguments:
# - `T`: element type of the random operands
# - `szA`, `szB`: sizes of the matrix `A` and the vector `B`
# - `partA`, `partB`: partitionings used to distribute `A` and `B`
#
# Covers out-of-place `*` and in-place `mul!`, each without a transpose and,
# when `szA[1] == szB[1]` makes `A' * B` well-formed, with an adjoint `A`.
function test_gemv!(T, szA, szB, partA, partB)
    # Internal invariants of the test tables: inner sizes and inner block
    # sizes of A and B must agree.
    @assert szA[2] == szB[1]
    szC = (szA[1],)
    @assert partA.blocksize[2] == partB.blocksize[1]
    partC = Blocks(partA.blocksize[1],)

    A = rand(T, szA...)
    B = rand(T, szB...)

    DA = distribute(A, partA)
    DB = distribute(B, partB)

    ## Out-of-place gemv
    # No transA
    DC = DA * DB
    C = A * B
    @test collect(DC) ≈ C

    if szA[1] == szB[1]
        # transA
        DC = DA' * DB
        C = A' * B
        @test collect(DC) ≈ C
    end

    ## In-place gemv
    # No transA
    C = zeros(T, szC...)
    DC = distribute(C, partC)
    mul!(C, A, B)
    mul!(DC, DA, DB)
    @test collect(DC) ≈ C

    if szA[1] == szB[1]
        # transA
        C = zeros(T, szC...)
        DC = distribute(C, partC)
        mul!(C, A', B)
        mul!(DC, DA', DB)
        @test collect(DC) ≈ C
    end
end
191+
192+
# Base (rows, cols) shapes from which the GEMV cases are derived.
_sizes_to_test = [(4, 4), (7, 7), (12, 12), (16, 16)]

# For each base shape: the full matrix, and one with half the columns, each
# paired with a vector whose length matches the matrix's column count.
size_sets_to_test = map(_sizes_to_test) do (rows, cols)
    half = cols ÷ 2
    return [
        (rows, cols) => (cols,),
        (rows, half) => (half,),
    ]
end
sizes_to_test = reduce(vcat, size_sets_to_test)

# Matching partitionings: one block size pair per entry above, built from the
# same base shapes.
part_sets_to_test = map(_sizes_to_test) do (rows, cols)
    half = cols ÷ 2
    return [
        Blocks(rows, cols) => Blocks(cols),
        Blocks(rows, half) => Blocks(half),
    ]
end
parts_to_test = reduce(vcat, part_sets_to_test)
214+
# Run `test_gemv!` for every (size pair, partitioning pair, element type)
# combination, one nested testset per axis so failures are reported per case.
@testset "GEMV" begin
    eltypes = (Float32, Float64, ComplexF32, ComplexF64)
    @testset "Size=$szA*$szB" for (szA, szB) in sizes_to_test
        @testset "Partitioning=$partA*$partB" for (partA, partB) in parts_to_test
            @testset "T=$T" for T in eltypes
                test_gemv!(T, szA, szB, partA, partB)
            end
        end
    end
end

0 commit comments

Comments
 (0)