Skip to content

Commit 496f68b

Browse files
authored
Merge pull request #580 from JuliaParallel/jps/datadeps-no-haswritedep
datadeps: Don't skip copy on no writedep
2 parents 2e155c9 + 96b4f89 commit 496f68b

File tree

6 files changed

+85
-13
lines changed

6 files changed

+85
-13
lines changed

src/array/darray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ function Base.collect(d::DArray; tree=false)
191191
end
192192
end
193193

194+
Base.wait(A::DArray) = foreach(wait, A.chunks)
195+
194196
### show
195197

196198
#= FIXME

src/array/operators.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,17 @@ Base.last(A::DArray) = A[end]
114114

115115
# In-place operations
116116

117+
function imap!(f, A)
118+
for idx in eachindex(A)
119+
A[idx] = f(A[idx])
120+
end
121+
return A
122+
end
123+
117124
function Base.map!(f, a::DArray{T}) where T
118125
Dagger.spawn_datadeps() do
119126
for ca in chunks(a)
120-
Dagger.@spawn map!(f, InOut(ca), ca)
127+
Dagger.@spawn imap!(f, InOut(ca))
121128
end
122129
end
123130
return a

src/array/random.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function Random.rand!(rng::AbstractRNG, A::DArray{T}) where T
99
Dagger.spawn_datadeps() do
1010
for Ac in chunks(A)
1111
rng = randfork(rng, part_sz)
12-
Dagger.@spawn map!(_->rand(rng, T), InOut(Ac), Ac)
12+
Dagger.@spawn imap!(InOut(_->rand(rng, T)), InOut(Ac))
1313
end
1414
end
1515
return A
@@ -19,7 +19,7 @@ function Random.randn!(rng::AbstractRNG, A::DArray{T}) where T
1919
Dagger.spawn_datadeps() do
2020
for Ac in chunks(A)
2121
rng = randfork(rng, part_sz)
22-
Dagger.@spawn map!(_->randn(rng, T), InOut(Ac), Ac)
22+
Dagger.@spawn imap!(InOut(_->randn(rng, T)), InOut(Ac))
2323
end
2424
end
2525
return A

src/datadeps.jl

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,22 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
147147
# The mapping of memory space to remote argument copies
148148
remote_args::Dict{MemorySpace,IdDict{Any,Any}}
149149

150+
# Cache of whether arguments supports in-place move
151+
supports_inplace_cache::IdDict{Any,Bool}
152+
150153
# The aliasing analysis state
151154
alias_state::State
152155

153156
function DataDepsState(aliasing::Bool)
154157
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}[]
155158
remote_args = Dict{MemorySpace,IdDict{Any,Any}}()
159+
supports_inplace_cache = IdDict{Any,Bool}()
156160
if aliasing
157161
state = DataDepsAliasingState()
158162
else
159163
state = DataDepsNonAliasingState()
160164
end
161-
return new{typeof(state)}(aliasing, dependencies, remote_args, state)
165+
return new{typeof(state)}(aliasing, dependencies, remote_args, supports_inplace_cache, state)
162166
end
163167
end
164168

@@ -168,6 +172,12 @@ function aliasing(astate::DataDepsAliasingState, arg, dep_mod)
168172
end
169173
end
170174

175+
function supports_inplace_move(state::DataDepsState, arg)
176+
return get!(state.supports_inplace_cache, arg) do
177+
return supports_inplace_move(arg)
178+
end
179+
end
180+
171181
# Determine which arguments could be written to, and thus need tracking
172182

173183
"Whether `arg` has any writedep in this datadeps region."
@@ -323,6 +333,30 @@ function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, t
323333
astate.data_origin[task] = space
324334
end
325335

336+
"""
337+
supports_inplace_move(x) -> Bool
338+
339+
Returns `false` if `x` doesn't support being copied into from another object
340+
like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting
341+
to copy between values which don't support mutation or otherwise don't have an
342+
implemented `move!` and want to skip in-place copies. When this returns
343+
`false`, datadeps will instead perform out-of-place copies for each non-local
344+
use of `x`, and the data in `x` will not be updated when the `spawn_datadeps`
345+
region returns.
346+
"""
347+
supports_inplace_move(x) = true
348+
supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true))
349+
function supports_inplace_move(c::Chunk)
350+
# FIXME: Use MemPool.access_ref
351+
pid = root_worker_id(c.processor)
352+
if pid == myid()
353+
return supports_inplace_move(poolget(c.handle))
354+
else
355+
return remotecall_fetch(supports_inplace_move, pid, c)
356+
end
357+
end
358+
supports_inplace_move(::Function) = false
359+
326360
# Read/write dependency management
327361
function get_write_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps)
328362
_get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps)
@@ -677,8 +711,15 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
677711
# Is the data written previously or now?
678712
arg, deps = unwrap_inout(arg)
679713
arg = arg isa DTask ? fetch(arg; raw=true) : arg
680-
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task)
681-
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)"
714+
if !type_may_alias(typeof(arg))
715+
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (immutable)"
716+
spec.args[idx] = pos => arg
717+
continue
718+
end
719+
720+
# Is the data writeable?
721+
if !supports_inplace_move(state, arg)
722+
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (non-writeable)"
682723
spec.args[idx] = pos => arg
683724
continue
684725
end
@@ -738,7 +779,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
738779
# Validate that we're not accidentally performing a copy
739780
for (idx, (_, arg)) in enumerate(spec.args)
740781
_, deps = unwrap_inout(task_args[idx][2])
741-
if is_writedep(arg, deps, task)
782+
# N.B. We only do this check when the argument supports in-place
783+
# moves, because for the moment, we are not guaranteeing updates or
784+
# write-back of results
785+
if is_writedep(arg, deps, task) && supports_inplace_move(state, arg)
742786
arg_space = memory_space(arg)
743787
@assert arg_space == our_space "($(repr(spec.f)))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space"
744788
end
@@ -750,6 +794,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
750794
arg, deps = unwrap_inout(arg)
751795
arg = arg isa DTask ? fetch(arg; raw=true) : arg
752796
type_may_alias(typeof(arg)) || continue
797+
supports_inplace_move(state, arg) || continue
753798
if queue.aliasing
754799
for (dep_mod, _, writedep) in deps
755800
ainfo = aliasing(astate, arg, dep_mod)
@@ -830,6 +875,12 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
830875
continue
831876
end
832877

878+
# Skip non-writeable arguments
879+
if !supports_inplace_move(state, arg)
880+
@dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)"
881+
continue
882+
end
883+
833884
# Get the set of writers
834885
ainfo_writes = get!(Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}, arg_writes, arg)
835886

@@ -877,8 +928,13 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
877928
for arg in keys(astate.data_origin)
878929
# Is the data previously written?
879930
arg, deps = unwrap_inout(arg)
880-
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps)
881-
@dagdebug nothing :spawn_datadeps "Skipped copy-from (unwritten)"
931+
if !type_may_alias(typeof(arg))
932+
@dagdebug nothing :spawn_datadeps "Skipped copy-from (immutable)"
933+
end
934+
935+
# Can the data be written back to?
936+
if !supports_inplace_move(state, arg)
937+
@dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)"
882938
end
883939

884940
# Is the source of truth elsewhere?
@@ -912,7 +968,7 @@ Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or
912968
argument, respectively. These argument dependencies will be used to specify
913969
which tasks depend on each other based on the following rules:
914970
915-
- Dependencies across different arguments are independent; only dependencies on the same argument synchronize with each other ("same-ness" is determined based on `isequal`)
971+
- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other
916972
- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects
917973
- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel
918974
- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies

src/thunk.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,14 @@ function show_thunk(io::IO, t)
570570
end
571571
print(io, ")")
572572
end
573-
Base.show(io::IO, t::Thunk) = show_thunk(io, t)
573+
function Base.show(io::IO, t::Thunk)
574+
lazy_level = parse(Int, get(ENV, "JULIA_DAGGER_SHOW_THUNK_VERBOSITY", "0"))
575+
if lazy_level == 0
576+
show_thunk(io, t)
577+
else
578+
show_thunk(IOContext(io, :lazy_level => lazy_level), t)
579+
end
580+
end
574581
Base.summary(t::Thunk) = repr(t)
575582

576583
inputs(x::Thunk) = x.inputs

test/datadeps.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function taskdeps_for_task(logs::Dict{Int,<:Dict}, tid::Int)
4444
end
4545
error("Task $tid not found in logs")
4646
end
47-
function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=true)
47+
function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=false)
4848
g = SimpleDiGraph()
4949
tid_to_v = Dict{Int,Int}()
5050
seen = Set{Int}()
@@ -165,7 +165,7 @@ function test_datadeps(;args_chunks::Bool,
165165
end
166166
tid_1, tid_2 = task_id.(ts)
167167
test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2])
168-
test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2])
168+
test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2], nondom_check=false)
169169

170170
# R->W Aliasing
171171
ts = []

0 commit comments

Comments
 (0)