@@ -136,10 +136,87 @@ part_sets_to_test = map(_sizes_to_test) do sz
136136 ]
137137end
138138parts_to_test = vcat (part_sets_to_test... )
139- @testset " Size=$szA *$szB " for (szA, szB) in sizes_to_test
140- @testset " Partitioning=$partA *$partB " for (partA,partB) in parts_to_test
141- @testset " T=$T " for T in (Float32, Float64, ComplexF32, ComplexF64)
142- test_gemm! (T, szA, szB, partA, partB)
139+ @testset " GEMM" begin
140+ @testset " Size=$szA *$szB " for (szA, szB) in sizes_to_test
141+ @testset " Partitioning=$partA *$partB " for (partA,partB) in parts_to_test
142+ @testset " T=$T " for T in (Float32, Float64, ComplexF32, ComplexF64)
143+ test_gemm! (T, szA, szB, partA, partB)
144+ end
145+ end
146+ end
147+ end
148+
149+ function test_gemv! (T, szA, szB, partA, partB)
150+ @assert szA[2 ] == szB[1 ]
151+ szC = (szA[1 ],)
152+ @assert partA. blocksize[2 ] == partB. blocksize[1 ]
153+ partC = Blocks (partA. blocksize[1 ],)
154+
155+ A = rand (T, szA... )
156+ B = rand (T, szB... )
157+
158+ DA = distribute (A, partA)
159+ DB = distribute (B, partB)
160+
161+ # # Out-of-place gemm
162+ # No transA
163+ DC = DA * DB
164+ C = A * B
165+ @test collect (DC) ≈ C
166+
167+ if szA[1 ] == szB[1 ]
168+ # transA
169+ DC = DA' * DB
170+ C = A' * B
171+ @test collect (DC) ≈ C
172+ end
173+
174+ # # In-place gemm
175+ # No transA
176+ C = zeros (T, szC... )
177+ DC = distribute (C, partC)
178+ mul! (C, A, B)
179+ mul! (DC, DA, DB)
180+ @test collect (DC) ≈ C
181+
182+ if szA[1 ] == szB[1 ]
183+ # transA
184+ C = zeros (T, szC... )
185+ DC = distribute (C, partC)
186+ mul! (C, A' , B)
187+ mul! (DC, DA' , DB)
188+ @test collect (DC) ≈ C
189+ end
190+ end
191+
192+ _sizes_to_test = [
193+ (4 , 4 ),
194+ (7 , 7 ),
195+ (12 , 12 ),
196+ (16 , 16 ),
197+ ]
198+ size_sets_to_test = map (_sizes_to_test) do sz
199+ rows, cols = sz
200+ return [
201+ (rows, cols) => (cols,),
202+ (rows, cols ÷ 2 ) => (cols ÷ 2 ,),
203+ ]
204+ end
205+ sizes_to_test = vcat (size_sets_to_test... )
206+ part_sets_to_test = map (_sizes_to_test) do sz
207+ rows, cols = sz
208+ return [
209+ Blocks (rows, cols) => Blocks (cols,),
210+ Blocks (rows, cols ÷ 2 ) => Blocks (cols ÷ 2 ,),
211+ ]
212+ end
213+ parts_to_test = vcat (part_sets_to_test... )
214+ @testset " GEMV" begin
215+ @testset " Size=$szA *$szB " for (szA, szB) in sizes_to_test
216+ @testset " Partitioning=$partA *$partB " for (partA,partB) in parts_to_test
217+ @testset " T=$T " for T in (Float32, Float64, ComplexF32, ComplexF64)
218+ test_gemv! (T, szA, szB, partA, partB)
219+ end
143220 end
144221 end
145222end
0 commit comments