@@ -15,7 +15,22 @@ using Bijectors: find_alpha, ChainRulesCore
1515# unusual Integer type is encountered.
1616@is_primitive (MinimalCtx, Tuple{typeof (find_alpha),P,P,Integer} where {P<: Base.IEEEFloat })
1717
18- # TODO : This needs a corresponding frule!! as well for it to work on forward-mode Mooncake.
18+ function Mooncake. frule!! (
19+ f:: Mooncake.Dual{typeof(find_alpha)} , x:: Mooncake.Dual{P} , y:: Mooncake.Dual{P} , z:: Mooncake.Dual{I}
20+ ) where {P<: Base.IEEEFloat ,I<: Integer }
21+ # Require that the integer is non-differentiable.
22+ if ! (Mooncake. tangent (z) isa Mooncake. NoTangent)
23+ msg = " Integer argument has tangent type $(typeof (Mooncake. tangent (z))) , should be NoTangent."
24+ throw (ArgumentError (msg))
25+ end
26+ # Convert Mooncake.NoTangent to ChainRulesCore.NoTangent for the integer argument
27+ out, tangent_out = ChainRulesCore. frule (
28+ (ChainRulesCore. NoTangent (), Mooncake. tangent (x), Mooncake. tangent (y), ChainRulesCore. NoTangent ()),
29+ find_alpha, Mooncake. primal (x), Mooncake. primal (y), Mooncake. primal (z)
30+ )
31+ return Mooncake. Dual (out, tangent_out)
32+ end
33+
1934function Mooncake. rrule!! (
2035 :: CoDual{typeof(find_alpha)} , x:: CoDual{P} , y:: CoDual{P} , z:: CoDual{I}
2136) where {P<: Base.IEEEFloat ,I<: Integer }
0 commit comments