Skip to content

Commit 28d0ca8

Browse files
committed
flows only fail on 1.11 apparently
1 parent d67f0c0 commit 28d0ca8

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

test/ad/flows.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
using Enzyme: ForwardMode
22

33
@testset "PlanarLayer: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
4+
ENZYME_FWD_AND_1p11 = VERSION >= v"1.11" && adtype isa AutoEnzyme{<:ForwardMode}
5+
46
# logpdf of a flow with a planar layer and two-dimensional inputs
57
function f(θ)
68
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
79
flow = transformed(MvNormal(zeros(2), I), layer)
810
x = θ[6:7]
911
return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)
1012
end
11-
if adtype isa AutoEnzyme{<:ForwardMode}
13+
if ENZYME_FWD_AND_1p11
1214
@test_throws Enzyme.Compiler.EnzymeInternalError test_ad(f, adtype, randn(7))
1315
else
1416
test_ad(f, adtype, randn(7))
@@ -20,7 +22,7 @@ using Enzyme: ForwardMode
2022
x = reshape(θ[6:end], 2, :)
2123
return sum(logpdf(flow.dist, x) - logabsdetjac(flow.transform, x))
2224
end
23-
if adtype isa AutoEnzyme{<:ForwardMode}
25+
if ENZYME_FWD_AND_1p11
2426
@test_throws Enzyme.Compiler.EnzymeInternalError test_ad(g, adtype, randn(11))
2527
else
2628
test_ad(g, adtype, randn(11))
@@ -33,7 +35,7 @@ using Enzyme: ForwardMode
3335
x = θ[6:7]
3436
return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)
3537
end
36-
if adtype isa AutoEnzyme{<:ForwardMode}
38+
if ENZYME_FWD_AND_1p11
3739
@test_throws Enzyme.Compiler.EnzymeInternalError test_ad(f, adtype, randn(7))
3840
else
3941
test_ad(f, adtype, randn(7))
@@ -45,7 +47,7 @@ using Enzyme: ForwardMode
4547
x = reshape(θ[6:end], 2, :)
4648
return sum(logpdf(flow.dist, x) - logabsdetjac(flow.transform, x))
4749
end
48-
if adtype isa AutoEnzyme{<:ForwardMode}
50+
if ENZYME_FWD_AND_1p11
4951
@test_throws Enzyme.Compiler.EnzymeInternalError test_ad(g, adtype, randn(11))
5052
else
5153
test_ad(g, adtype, randn(11))

0 commit comments

Comments
 (0)