Skip to content

Commit ce870f7

Browse files
committed
Fix Enzyme tests again
1 parent 098ef8d commit ce870f7

File tree

2 files changed

+41
-60
lines changed

2 files changed

+41
-60
lines changed

test/ad/corr.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,10 @@ using Enzyme: ForwardMode
99
x = rand(dist)
1010
y = b(x)
1111

12-
if adtype isa AutoEnzyme{<:ForwardMode} && d == 1
13-
# For d == 1, y has length 0, and DI doesn't handle this well
14-
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802
15-
@test_throws DivideError test_ad(y -> sum(transform(b, binv(y))), adtype, y)
16-
@test_throws DivideError test_ad(y -> sum(transform(binv, y)), adtype, y)
17-
else
18-
# roundtrip
19-
test_ad(y -> sum(transform(b, binv(y))), adtype, y)
20-
# inverse only
21-
test_ad(y -> sum(transform(binv, y)), adtype, y)
22-
end
12+
# roundtrip
13+
test_ad(y -> sum(transform(b, binv(y))), adtype, y)
14+
# inverse only
15+
test_ad(y -> sum(transform(binv, y)), adtype, y)
2316
end
2417
end
2518

test/ad/enzyme.jl

Lines changed: 37 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,67 +5,55 @@ using Enzyme
55
using EnzymeTestUtils: test_forward, test_reverse
66
using Test
77

8-
@testset "Enzyme: Bijectors.find_alpha" begin
9-
x = randn()
10-
y = expm1(randn())
11-
z = randn()
8+
# This entire test suite is broken on 1.11.
9+
#
10+
# https://github.com/EnzymeAD/Enzyme.jl/issues/2121
11+
# https://github.com/TuringLang/Bijectors.jl/pull/350#issuecomment-2470766968
12+
#
13+
# Ideally we'd use `@test_throws`. However, that doesn't work because
14+
# `test_forward` itself calls `@test`, and the error is captured by that
15+
# `@test`, not our `@test_throws`. Consequently `@test_throws` doesn't actually
16+
# see any error. Weird Julia behaviour.
17+
18+
@static if VERSION < v"1.11"
19+
@testset "Enzyme: Bijectors.find_alpha" begin
20+
x = randn()
21+
y = expm1(randn())
22+
z = randn()
23+
24+
@testset "forward" begin
25+
# No batches
26+
@testset for RT in (Const, Duplicated, DuplicatedNoNeed),
27+
Tx in (Const, Duplicated),
28+
Ty in (Const, Duplicated),
29+
Tz in (Const, Duplicated)
1230

13-
@testset "forward" begin
14-
# No batches
15-
@testset for RT in (Const, Duplicated, DuplicatedNoNeed),
16-
Tx in (Const, Duplicated),
17-
Ty in (Const, Duplicated),
18-
Tz in (Const, Duplicated)
19-
20-
if VERSION >= v"1.11" &&
21-
(!(RT <: Const) || (Tx <: Const && Ty <: Const && Tz <: Const))
22-
# https://github.com/EnzymeAD/Enzyme.jl/issues/2121
23-
# https://github.com/TuringLang/Bijectors.jl/pull/350#issuecomment-2470766968
24-
#
25-
# Ideally we'd use `@test_throws`. However, that doesn't work
26-
# because `test_forward` itself calls `@test`, and the error is
27-
# captured by that `@test`, not our `@test_throws`.
28-
# Consequently `@test_throws` doesn't actually see any error.
29-
# Weird Julia behaviour.
30-
continue
31-
else
3231
test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
3332
end
34-
end
3533

36-
# Batches
37-
@testset for RT in (Const, BatchDuplicated, BatchDuplicatedNoNeed),
38-
Tx in (Const, BatchDuplicated),
39-
Ty in (Const, BatchDuplicated),
40-
Tz in (Const, BatchDuplicated)
34+
# Batches
35+
@testset for RT in (Const, BatchDuplicated, BatchDuplicatedNoNeed),
36+
Tx in (Const, BatchDuplicated),
37+
Ty in (Const, BatchDuplicated),
38+
Tz in (Const, BatchDuplicated)
4139

42-
if VERSION >= v"1.11" &&
43-
(!(RT <: Const) || (Tx <: Const && Ty <: Const && Tz <: Const))
44-
# See above
45-
continue
46-
else
4740
test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
4841
end
4942
end
50-
end
51-
@testset "reverse" begin
52-
# No batches
53-
@testset for RT in (Const, Active),
54-
Tx in (Const, Active),
55-
Ty in (Const, Active),
56-
Tz in (Const, Active)
43+
@testset "reverse" begin
44+
# No batches
45+
@testset for RT in (Const, Active),
46+
Tx in (Const, Active),
47+
Ty in (Const, Active),
48+
Tz in (Const, Active)
5749

58-
if VERSION >= v"1.11"
59-
# See above
60-
continue
61-
else
6250
test_reverse(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
6351
end
64-
end
6552

66-
# TODO: Test batch mode
67-
# This is a bit problematic since Enzyme does not support all combinations of activities currently
68-
# https://github.com/TuringLang/Bijectors.jl/pull/350#issuecomment-2480468728
53+
# TODO: Test batch mode
54+
# This is a bit problematic since Enzyme does not support all combinations of activities currently
55+
# https://github.com/TuringLang/Bijectors.jl/pull/350#issuecomment-2480468728
56+
end
6957
end
7058
end
7159

0 commit comments

Comments
 (0)