Skip to content

Commit 195e583

Browse files
Copilotyebai
andcommitted
Implement frule!! for find_alpha with integer arguments and enable forward mode tests
Co-authored-by: yebai <[email protected]>
1 parent 3b571d1 commit 195e583

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

ext/BijectorsMooncakeExt.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,22 @@ using Bijectors: find_alpha, ChainRulesCore
1515
# unusual Integer type is encountered.
1616
@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat})
1717

18-
# TODO: This needs a corresponding frule!! as well for it to work on forward-mode Mooncake.
18+
function Mooncake.frule!!(
19+
f::Mooncake.Dual{typeof(find_alpha)}, x::Mooncake.Dual{P}, y::Mooncake.Dual{P}, z::Mooncake.Dual{I}
20+
) where {P<:Base.IEEEFloat,I<:Integer}
21+
# Require that the integer is non-differentiable.
22+
if !(Mooncake.tangent(z) isa Mooncake.NoTangent)
23+
msg = "Integer argument has tangent type $(typeof(Mooncake.tangent(z))), should be NoTangent."
24+
throw(ArgumentError(msg))
25+
end
26+
# Convert Mooncake.NoTangent to ChainRulesCore.NoTangent for the integer argument
27+
out, tangent_out = ChainRulesCore.frule(
28+
(ChainRulesCore.NoTangent(), Mooncake.tangent(x), Mooncake.tangent(y), ChainRulesCore.NoTangent()),
29+
find_alpha, Mooncake.primal(x), Mooncake.primal(y), Mooncake.primal(z)
30+
)
31+
return Mooncake.Dual(out, tangent_out)
32+
end
33+
1934
function Mooncake.rrule!!(
2035
::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I}
2136
) where {P<:Base.IEEEFloat,I<:Integer}

test/ad/chainrules.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ end
3131

3232
if @isdefined Mooncake
3333
rng = Xoshiro(123456)
34-
# TODO: Enable Mooncake.ForwardMode as well.
35-
@testset "$mode" for mode in (Mooncake.ReverseMode,)
34+
@testset "$mode" for mode in (Mooncake.ReverseMode, Mooncake.ForwardMode)
3635
Mooncake.TestUtils.test_rule(
3736
rng,
3837
Bijectors.find_alpha,

0 commit comments

Comments
 (0)