Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ ProjectTo(::Any) = identity
# Any `AbstractZero` seen as the primal in the forward pass gets the `NoTangent` projector.
function ProjectTo(::AbstractZero)
    return ProjectTo{NoTangent}()
end

# That projector maps an arbitrary (nonzero) gradient to `NoTangent()` ...
function (::ProjectTo{NoTangent})(dx)
    return NoTangent()
end

# ... while a gradient that is already an `AbstractZero` is handed back unchanged.
# This method also resolves a dispatch ambiguity.
function (::ProjectTo{NoTangent})(dx::AbstractZero)
    return dx
end

# Thunked gradients are answered with `NoTangent()` directly, ignoring the argument.
# These two methods resolve dispatch ambiguities, #685.
function (::ProjectTo{NoTangent})(::InplaceableThunk)
    return NoTangent()
end

function (::ProjectTo{NoTangent})(::Thunk)
    return NoTangent()
end

# Also, any explicit construction with fields, where all fields project to zero, itself
# projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
Expand Down Expand Up @@ -277,7 +279,7 @@ end
# but as `Ref{Any}((x=val,))`. Here we use a Tangent; there is at present no mutable version, but see
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105
function ProjectTo(x::Ref)
    # Build the projector for the value currently stored in `x`; it is computed once and
    # kept as the `x` field of the returned structural projector.
    sub = ProjectTo(x[])  # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
    return ProjectTo{Tangent{typeof(x)}}(; x=sub)
end
# Applying a `Ref` projector to a structural `Tangent`: take the first backing field,
# wrap it in a `Ref`, and call the projector again on that `Ref`.
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent)
    inner = first(backing(dx))
    return project(Ref(inner))
end
Expand Down
14 changes: 11 additions & 3 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct NoSuperType end

# Projector built from a complex 1×3 matrix, applied to lazy row vectors.
prow = ProjectTo([1im 2 3im])
@test prow(transpose([1, 2, 3 + 4.0im])) == [1 2 3 + 4im]
@test prow(transpose([1, 2, 3 + 4.0im])) isa Matrix  # row vectors may not pass through
@test prow(adjoint([1, 2, 3 + 5im])) == [1 2 3 - 5im]
@test prow(adjoint([1, 2, 3])) isa Matrix

Expand Down Expand Up @@ -145,7 +145,7 @@ struct NoSuperType end

# A `Ref` holding Bool-valued contents should produce the `NoTangent` projector.
for contents in (true, [false]')
    @test ProjectTo(Ref(contents)) isa ProjectTo{NoTangent}
end

@test NoTangent() === ProjectTo(Ref(1.0))(Ref(NoTangent()))  # collapse all-zero
end

Expand Down Expand Up @@ -376,7 +376,7 @@ struct NoSuperType end

# Projector built from a plain 3-element vector, applied to offset-indexed gradients.
pvec3 = ProjectTo([1, 2, 3])
@test axes(pvec3(OffsetArray(rand(3), 0:2))) == (1:3,)
@test pvec3(OffsetArray(rand(3), 0:2)) isa Vector  # relies on axes === axes test
@test pvec3(OffsetArray(rand(3,1), 0:2, 0:0)) isa Vector
end

Expand Down Expand Up @@ -463,4 +463,12 @@ struct NoSuperType end
# Applies a `Symmetric` projector to a fresh `Symmetric` gradient and compares the measured
# allocation with zero; the comparison is marked broken, i.e. it is currently expected to fail.
# NOTE(review): the trailing `# 64` presumably records the allocation observed today — confirm.
psymm = ProjectTo(Symmetric(rand(10^3, 10^3)))
@test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
end

@testset "#685" begin
    # Regression test for #685: the projector built from a `BitArray` must return
    # `NoTangent()` for plain arrays, thunks and in-placeable thunks alike.
    pbits = ProjectTo(BitArray([0]))
    @test pbits([1.0]) == NoTangent()
    @test pbits(@thunk [1.0]) == NoTangent()

    # The accumulation function of an `InplaceableThunk` has to add the thunk's value
    # into its argument in place, so it mutates `x` rather than returning a new array.
    it = InplaceableThunk(x -> (x .+= [1.0]), @thunk [1.0])
    @test pbits(it) == NoTangent()
end
end
Loading