diff --git a/Project.toml b/Project.toml index 192427634..34dc8c76a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.72.5" +version = "1.72.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index b79615f7b..9895b7549 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -249,8 +249,10 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu y = map(first, hobbits) num_xs = Val(length(xs)) paddings = map(x -> ntuple(Returns(NoTangent()), (length(x) - length_y)), xs) - all(isempty, paddings) || @error """map(f, xs::Tuple...) does not allow mistmatched lengths! - But its `rrule` does; when JuliaLang/julia #42216 is fixed this warning should be removed.""" + @static if VERSION < v"1.10.0-DEV.1194" + all(isempty, paddings) || @error """map(f, xs::Tuple...) does not allow mismatched lengths in Julia <1.10! + But its `rrule` does.""" + end function map_pullback(dy_raw) dy = unthunk(dy_raw) # We want to call the pullbacks in `rrule_via_ad` in reverse sequence to the forward pass: diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 09479828c..3019c89df 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -253,8 +253,8 @@ end test_rrule(map, make_two_vec, (4.0, 5.0 + 6im), check_inferred=false) test_rrule(map, Multiplier(rand() + im), Tuple(rand(3)), check_inferred=false) - if try map(+, (1,), (2,3)); true catch e; false end - # True when https://github.com/JuliaLang/julia/issues/42216 has been fixed + if VERSION >= v"1.10.0-DEV.1194" + # Mismatched lengths were not allowed before 1.10 test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false) end end