Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.25.0"
version = "1.25.1"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
16 changes: 15 additions & 1 deletion src/tangent_types/thunks.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Disable thunks for 2nd order AD.
_usethunks() = true
rrule(::typeof(_usethunks)) = false, (NoTangent(),)

abstract type AbstractThunk <: AbstractTangent end

struct MutateThunkException <: Exception end
Expand Down Expand Up @@ -141,7 +145,11 @@ macro thunk(body)
# Basically `:(Thunk(() -> $(esc(body))))` but use the location where it is defined.
# so we get useful stack traces if it errors.
func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
Comment on lines 145 to 147
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could go further and generate a name for the thunk, based on source location, instead of anon. () ->. I think the name will be seen sometimes where the code line info is not?

But not this PR perhaps...

return :(Thunk($(esc(func))))
return quote
$(esc(_usethunks))() ?
Thunk($(esc(func))) :
$(esc(func))()
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments on this approach over :(_usethunks() ? Thunk($(esc(func))) : $(esc(body))), from #568? My hope there was that including the body directly, instead of making and calling a function, might be slightly simpler for e.g. Zygote to understand.

end

"""
Expand Down Expand Up @@ -233,6 +241,12 @@ and destroy its inplacability.
struct InplaceableThunk{T<:Thunk,F} <: AbstractThunk
add!::F
val::T

function InplaceableThunk(add!::F, val::T) where {F, T}
_usethunks() ?
new{T, F}(add!, val) :
val
end
end

unthunk(x::InplaceableThunk) = unthunk(x.val)
Expand Down
Loading