Skip to content

Commit 1bf9cf9

Browse files
Apply suggestions from code review
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 195e583 commit 1bf9cf9

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

ext/BijectorsMooncakeExt.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ using Bijectors: find_alpha, ChainRulesCore
1616
@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat})
1717

1818
function Mooncake.frule!!(
19-
f::Mooncake.Dual{typeof(find_alpha)}, x::Mooncake.Dual{P}, y::Mooncake.Dual{P}, z::Mooncake.Dual{I}
19+
f::Mooncake.Dual{typeof(find_alpha)},
20+
x::Mooncake.Dual{P},
21+
y::Mooncake.Dual{P},
22+
z::Mooncake.Dual{I},
2023
) where {P<:Base.IEEEFloat,I<:Integer}
2124
# Require that the integer is non-differentiable.
2225
if !(Mooncake.tangent(z) isa Mooncake.NoTangent)
@@ -25,8 +28,16 @@ function Mooncake.frule!!(
2528
end
2629
# Convert Mooncake.NoTangent to ChainRulesCore.NoTangent for the integer argument
2730
out, tangent_out = ChainRulesCore.frule(
28-
(ChainRulesCore.NoTangent(), Mooncake.tangent(x), Mooncake.tangent(y), ChainRulesCore.NoTangent()),
29-
find_alpha, Mooncake.primal(x), Mooncake.primal(y), Mooncake.primal(z)
31+
(
32+
ChainRulesCore.NoTangent(),
33+
Mooncake.tangent(x),
34+
Mooncake.tangent(y),
35+
ChainRulesCore.NoTangent(),
36+
),
37+
find_alpha,
38+
Mooncake.primal(x),
39+
Mooncake.primal(y),
40+
Mooncake.primal(z),
3041
)
3142
return Mooncake.Dual(out, tangent_out)
3243
end

0 commit comments

Comments
 (0)