@@ -16,7 +16,10 @@ using Bijectors: find_alpha, ChainRulesCore
1616@is_primitive (MinimalCtx, Tuple{typeof (find_alpha),P,P,Integer} where {P<: Base.IEEEFloat })
1717
1818function Mooncake. frule!! (
19- f:: Mooncake.Dual{typeof(find_alpha)} , x:: Mooncake.Dual{P} , y:: Mooncake.Dual{P} , z:: Mooncake.Dual{I}
19+ f:: Mooncake.Dual{typeof(find_alpha)} ,
20+ x:: Mooncake.Dual{P} ,
21+ y:: Mooncake.Dual{P} ,
22+ z:: Mooncake.Dual{I} ,
2023) where {P<: Base.IEEEFloat ,I<: Integer }
2124 # Require that the integer is non-differentiable.
2225 if ! (Mooncake. tangent (z) isa Mooncake. NoTangent)
@@ -25,8 +28,16 @@ function Mooncake.frule!!(
2528 end
2629 # Convert Mooncake.NoTangent to ChainRulesCore.NoTangent for the integer argument
2730 out, tangent_out = ChainRulesCore. frule (
28- (ChainRulesCore. NoTangent (), Mooncake. tangent (x), Mooncake. tangent (y), ChainRulesCore. NoTangent ()),
29- find_alpha, Mooncake. primal (x), Mooncake. primal (y), Mooncake. primal (z)
31+ (
32+ ChainRulesCore. NoTangent (),
33+ Mooncake. tangent (x),
34+ Mooncake. tangent (y),
35+ ChainRulesCore. NoTangent (),
36+ ),
37+ find_alpha,
38+ Mooncake. primal (x),
39+ Mooncake. primal (y),
40+ Mooncake. primal (z),
3041 )
3142 return Mooncake. Dual (out, tangent_out)
3243end
0 commit comments