Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ ChainRulesCore = "1.0.0"
julia = "1"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["Test", "ChainRulesTestUtils", "Zygote"]
3 changes: 2 additions & 1 deletion src/ChainRulesDeclarationHelpers.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
module ChainRulesDeclarationHelpers

export @rrule_from_frule
include("rrule_from_frule.jl")
end # module
32 changes: 32 additions & 0 deletions src/rrule_from_frule.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using ChainRulesCore

"""
@rrule_from_frule(signature_expression)

A helper to define an rrule by calling back into AD on an already defined frule.
"""
macro rrule_from_frule(signature_expression)
@assert Meta.isexpr(signature_expression, :call)
@assert length(signature_expression.args) == 2 "Only single-argument functions are implemented."
# TODO add support for multiple arguments, varargs, kwargs

f = signature_expression.args[1]
arg = signature_expression.args[2]
@assert Meta.isexpr(arg, :(::), 2)

return rrule_from_frule_expr(__source__, f, arg)
end

function rrule_from_frule_expr(__source__, f, arg)
f_instance_name = gensym(Symbol(:instance_, Symbol(f)))
return quote
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, $f_instance_name::Core.Typeof($(esc(f))), $arg)
$(__source__)
pushforward(Δfarg...) = frule(Δfarg, $f_instance_name, $arg)[2]
_, back = rrule_via_ad(config, pushforward, $f_instance_name, $arg)
y = $f_instance_name($arg) # TODO optimize away redundant primal computation
f_pullback(Δy) = back(Δy)[2:end]
return y, f_pullback
end
end
end
28 changes: 28 additions & 0 deletions test/rrule_from_frule.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using Test
using ChainRulesDeclarationHelpers
using ChainRulesCore
using ChainRulesTestUtils
using Zygote


@testset "rrule_from_frule" begin
function f(x)
a = sin.(x)
b = sum(a)
c = b * a
return c
end

function ChainRulesCore.frule((Δself, Δx), ::typeof(f), x)
a, ȧ = sin.(x), cos.(x) .* Δx
b, ḃ = sum(a), sum(ȧ)
c, ċ = b * a, ḃ * a + b * ȧ
return c, ċ
end

x = rand(3)
test_frule(f, x)

@rrule_from_frule f(x::AbstractArray{<:Real})
test_rrule(Zygote.ZygoteRuleConfig(), f, x; check_inferred=false)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ using ChainRulesDeclarationHelpers
using Test

@testset "ChainRulesDeclarationHelpers" begin

include("rrule_from_frule.jl")
end