Skip to content

Commit cd8f069

Browse files
st--devmotionwilltebbutt
authored
reactivate AD tests: mean functions (#313)
* reactivate mean function AD tests * extend mean function tests to ColVecs/RowVecs * unify testcases * remove rrules and ChainRulesCore Co-authored-by: David Widmann <[email protected]> Co-authored-by: willtebbutt <[email protected]>
1 parent d99311e commit cd8f069

File tree

7 files changed

+23
-67
lines changed

7 files changed

+23
-67
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
name = "AbstractGPs"
22
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
33
authors = ["JuliaGaussianProcesses Team"]
4-
version = "0.5.11"
4+
version = "0.5.12"
55

66
[deps]
7-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
87
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
98
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
109
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
@@ -19,7 +18,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1918
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2019

2120
[compat]
22-
ChainRulesCore = "1"
2321
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
2422
FillArrays = "0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
2523
IrrationalConstants = "0.1"

src/AbstractGPs.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module AbstractGPs
22

3-
using ChainRulesCore
43
using Distributions
54
using FillArrays
65
using LinearAlgebra

src/mean_function.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@ This is an AbstractGPs-internal workaround for AD issues; ideally we would just
1212
"""
1313
_map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x))
1414

15-
function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ZeroMean, x::AbstractVector)
16-
map_ZeroMean_pullback(Δ) = (NoTangent(), NoTangent(), ZeroTangent())
17-
return _map_meanfunction(m, x), map_ZeroMean_pullback
18-
end
19-
2015
ZeroMean() = ZeroMean{Float64}()
2116

2217
"""
@@ -40,4 +35,4 @@ struct CustomMean{Tf} <: MeanFunction
4035
f::Tf
4136
end
4237

43-
_map_meanfunction(f::CustomMean, x::AbstractVector) = map(f.f, x)
38+
_map_meanfunction(m::CustomMean, x::AbstractVector) = map(m.f, x)

test/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
[deps]
2-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
32
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
43
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
54
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -14,7 +13,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1413
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1514

1615
[compat]
17-
ChainRulesCore = "1"
1816
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
1917
Documenter = "0.24, 0.25, 0.26, 0.27"
2018
FillArrays = "0.11, 0.12, 0.13"

test/mean_function.jl

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,42 +5,23 @@
55
xD_colvecs = ColVecs(randn(rng, D, N))
66
xD_rowvecs = RowVecs(randn(rng, N, D))
77

8-
@testset "ZeroMean" begin
9-
m = ZeroMean{Float64}()
8+
zero_mean_testcase = (; mean_function=ZeroMean(), calc_expected=_ -> zeros(N))
109

11-
for x in [x1, xD_colvecs, xD_rowvecs]
12-
@test AbstractGPs._map_meanfunction(m, x) == zeros(N)
13-
#differentiable_mean_function_tests(m, randn(rng, N), x)
14-
15-
# Manually verify the ChainRule. Really, this should employ FiniteDifferences, but
16-
# currently ChainRulesTestUtils isn't up to handling this, so this will have to do
17-
# for now.
18-
y, pb = rrule(AbstractGPs._map_meanfunction, m, x)
19-
@test y == AbstractGPs._map_meanfunction(m, x)
20-
Δmap, Δf, Δx = pb(randn(rng, N))
21-
@test iszero(Δmap)
22-
@test iszero(Δf)
23-
@test iszero(Δx)
24-
end
25-
end
26-
27-
@testset "ConstMean" begin
28-
c = randn(rng)
29-
m = ConstMean(c)
30-
31-
for x in [x1, xD_colvecs, xD_rowvecs]
32-
@test AbstractGPs._map_meanfunction(m, x) == fill(c, N)
33-
#differentiable_mean_function_tests(m, randn(rng, N), x)
34-
end
35-
end
10+
c = randn(rng)
11+
const_mean_testcase = (; mean_function=ConstMean(c), calc_expected=_ -> fill(c, N))
3612

37-
@testset "CustomMean" begin
38-
foo_mean = x -> sum(abs2, x)
39-
m = CustomMean(foo_mean)
13+
foo_mean = x -> sum(abs2, x)
14+
custom_mean_testcase = (;
15+
mean_function=CustomMean(foo_mean), calc_expected=x -> map(foo_mean, x)
16+
)
4017

18+
@testset "$(typeof(testcase.mean_function))" for testcase in [
19+
zero_mean_testcase, const_mean_testcase, custom_mean_testcase
20+
]
4121
for x in [x1, xD_colvecs, xD_rowvecs]
42-
@test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x)
43-
#differentiable_mean_function_tests(m, randn(rng, N), x)
22+
m = testcase.mean_function
23+
@test AbstractGPs._map_meanfunction(m, x) == testcase.calc_expected(x)
24+
differentiable_mean_function_tests(rng, m, x)
4425
end
4526
end
4627
end

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ using AbstractGPs:
2323
TestUtils
2424

2525
using Documenter
26-
using ChainRulesCore
2726
using Distributions: MvNormal, PDMat, loglikelihood, Distributions
2827
using FillArrays
2928
using FiniteDifferences

test/test_util.jl

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@ end
7373
Test _very_ basic consistency properties of the mean function `m`.
7474
"""
7575
function mean_function_tests(m::MeanFunction, x::AbstractVector)
76-
@test AbstractGPs._map_meanfunction(m, x) isa AbstractVector
77-
@test length(ew(m, x)) == length(x)
76+
mean = AbstractGPs._map_meanfunction(m, x)
77+
@test mean isa AbstractVector
78+
@test length(mean) == length(x)
7879
end
7980

8081
"""
@@ -87,34 +88,19 @@ end
8788
Ensure that the gradient w.r.t. the inputs of `MeanFunction` `m` are approximately correct.
8889
"""
8990
function differentiable_mean_function_tests(
90-
m::MeanFunction,
91-
::AbstractVector{<:Real},
92-
x::AbstractVector{<:Real};
93-
rtol=_rtol,
94-
atol=_atol,
91+
m::MeanFunction, ȳ::AbstractVector, x::AbstractVector; rtol=_rtol, atol=_atol
9592
)
9693
# Run forward tests.
9794
mean_function_tests(m, x)
9895

9996
# Check adjoint.
10097
@assert length(ȳ) == length(x)
101-
return adjoint_test(x -> ew(m, x), ȳ, x; rtol=rtol, atol=atol)
98+
adjoint_test(
99+
x -> collect(AbstractGPs._map_meanfunction(m, x)), ȳ, x; rtol=rtol, atol=atol
100+
)
101+
return nothing
102102
end
103103

104-
# function differentiable_mean_function_tests(
105-
# m::MeanFunction,
106-
# ȳ::AbstractVector{<:Real},
107-
# x::ColVecs{<:Real};
108-
# rtol=_rtol,
109-
# atol=_atol,
110-
# )
111-
# # Run forward tests.
112-
# mean_function_tests(m, x)
113-
114-
# @assert length(ȳ) == length(x)
115-
# adjoint_test(X->ew(m, ColVecs(X)), ȳ, x.X; rtol=rtol, atol=atol)
116-
# end
117-
118104
function differentiable_mean_function_tests(
119105
rng::AbstractRNG, m::MeanFunction, x::AbstractVector; rtol=_rtol, atol=_atol
120106
)

0 commit comments

Comments
 (0)