|
5 | 5 | xD_colvecs = ColVecs(randn(rng, D, N)) |
6 | 6 | xD_rowvecs = RowVecs(randn(rng, N, D)) |
7 | 7 |
|
8 | | - @testset "ZeroMean" begin |
9 | | - m = ZeroMean{Float64}() |
| 8 | + zero_mean_testcase = (; mean_function=ZeroMean(), calc_expected=_ -> zeros(N)) |
10 | 9 |
|
11 | | - for x in [x1, xD_colvecs, xD_rowvecs] |
12 | | - @test AbstractGPs._map_meanfunction(m, x) == zeros(N) |
13 | | - #differentiable_mean_function_tests(m, randn(rng, N), x) |
14 | | - |
15 | | - # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but |
16 | | - # currently ChainRulesTestUtils isn't up to handling this, so this will have to do |
17 | | - # for now. |
18 | | - y, pb = rrule(AbstractGPs._map_meanfunction, m, x) |
19 | | - @test y == AbstractGPs._map_meanfunction(m, x) |
20 | | - Δmap, Δf, Δx = pb(randn(rng, N)) |
21 | | - @test iszero(Δmap) |
22 | | - @test iszero(Δf) |
23 | | - @test iszero(Δx) |
24 | | - end |
25 | | - end |
26 | | - |
27 | | - @testset "ConstMean" begin |
28 | | - c = randn(rng) |
29 | | - m = ConstMean(c) |
30 | | - |
31 | | - for x in [x1, xD_colvecs, xD_rowvecs] |
32 | | - @test AbstractGPs._map_meanfunction(m, x) == fill(c, N) |
33 | | - #differentiable_mean_function_tests(m, randn(rng, N), x) |
34 | | - end |
35 | | - end |
| 10 | + c = randn(rng) |
| 11 | + const_mean_testcase = (; mean_function=ConstMean(c), calc_expected=_ -> fill(c, N)) |
36 | 12 |
|
37 | | - @testset "CustomMean" begin |
38 | | - foo_mean = x -> sum(abs2, x) |
39 | | - m = CustomMean(foo_mean) |
| 13 | + foo_mean = x -> sum(abs2, x) |
| 14 | + custom_mean_testcase = (; |
| 15 | + mean_function=CustomMean(foo_mean), calc_expected=x -> map(foo_mean, x) |
| 16 | + ) |
40 | 17 |
|
| 18 | + @testset "$(typeof(testcase.mean_function))" for testcase in [ |
| 19 | + zero_mean_testcase, const_mean_testcase, custom_mean_testcase |
| 20 | + ] |
41 | 21 | for x in [x1, xD_colvecs, xD_rowvecs] |
42 | | - @test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x) |
43 | | - #differentiable_mean_function_tests(m, randn(rng, N), x) |
| 22 | + m = testcase.mean_function |
| 23 | + @test AbstractGPs._map_meanfunction(m, x) == testcase.calc_expected(x) |
| 24 | + differentiable_mean_function_tests(rng, m, x) |
44 | 25 | end |
45 | 26 | end |
46 | 27 | end |
0 commit comments