Skip to content

Commit 5529310

Browse files
authored
Resolve issues with Zygote.jl and mean_vector (#415)
* resolve issues with AD of mean_vector
* add tests for mean_vector with ColVecs & RowVecs
* bump version to 0.5.22
1 parent 8d431a2 commit 5529310

File tree

3 files changed

+34
-15
lines changed

3 files changed

+34
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AbstractGPs"
22
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
33
authors = ["JuliaGaussianProcesses Team"]
4-
version = "0.5.21"
4+
version = "0.5.22"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/mean_function.jl

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,12 @@ mean_vector(m::ConstMean, x::AbstractVector) = Fill(m.c, length(x))
4444
4545
A wrapper around whatever unary function you fancy. Must be able to be mapped over an
4646
`AbstractVector` of inputs.
47-
48-
# Warning
49-
`CustomMean` is generally sufficient for testing purposes, but care should be taken if
50-
attempting to differentiate through `mean_vector` with a `CustomMean` when using
51-
`Zygote.jl`. In particular, `mean_vector(m::CustomMean, x)` is implemented as `map(m.f, x)`,
52-
which when `x` is a `ColVecs` or `RowVecs` will not differentiate correctly.
53-
54-
In such cases, you should implement `mean_vector` directly for your custom mean.
55-
For example, if `f(x) = sum(x)`, you might implement `mean_vector` as
56-
```julia
57-
mean_vector(::CustomMean{typeof(f)}, x::ColVecs) = vec(sum(x.X; dims=1))
58-
mean_vector(::CustomMean{typeof(f)}, x::RowVecs) = vec(sum(x.X; dims=2))
59-
```
60-
which avoids ever applying `map` to a `ColVecs` or `RowVecs`.
6147
"""
6248
struct CustomMean{Tf} <: MeanFunction
6349
f::Tf
6450
end
6551

6652
mean_vector(m::CustomMean, x::AbstractVector) = map(m.f, x)
53+
54+
mean_vector(m::CustomMean, x::ColVecs) = map(m.f, eachcol(x.X))
55+
mean_vector(m::CustomMean, x::RowVecs) = map(m.f, eachrow(x.X))

test/mean_function.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,34 @@
2424
differentiable_mean_function_tests(rng, m, x)
2525
end
2626
end
27+
28+
@testset "ColVecs & RowVecs" begin
29+
m = custom_mean_testcase.mean_function
30+
31+
@test mean_vector(m, xD_colvecs) == map(foo_mean, eachcol(xD_colvecs.X))
32+
@test mean_vector(m, xD_rowvecs) == map(foo_mean, eachrow(xD_rowvecs.X))
33+
end
34+
35+
# This test fails without the specialized methods
36+
# `mean_vector(m::CustomMean, x::ColVecs)`
37+
# `mean_vector(m::CustomMean, x::RowVecs)`
38+
@testset "Zygote gradients" begin
39+
X = [1.;; 2.;; 3.;;]
40+
y = [1., 2., 3.]
41+
foo_mean = x -> sum(abs2, x)
42+
43+
function construct_finite_gp(X, lengthscale, noise)
44+
mean = CustomMean(foo_mean)
45+
kernel = with_lengthscale(Matern52Kernel(), lengthscale)
46+
return GP(mean, kernel)(X, noise)
47+
end
48+
49+
function loglike(lengthscale, noise)
50+
gp = construct_finite_gp(X, lengthscale, noise)
51+
return logpdf(gp, y)
52+
end
53+
54+
@test Zygote.gradient(n -> loglike(1., n), 1.)[1] isa Real
55+
@test Zygote.gradient(l -> loglike(l, 1.), 1.)[1] isa Real
56+
end
2757
end

0 commit comments

Comments
 (0)