Skip to content

Commit 49d9ac7

Browse files
committed
Rewrite driver function using pmap
old : 744.331 ms (8252 allocations: 28.92 MiB) new : 667.085 ms (6791 allocations: 28.82 MiB)
1 parent bfb9424 commit 49d9ac7

File tree

3 files changed

+107
-99
lines changed

3 files changed

+107
-99
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ jobs:
1313
fail-fast: false
1414
matrix:
1515
version:
16-
- '1.6'
17-
- '1.1'
16+
- '1.10'
1817
- '1'
1918
# - 'nightly'
2019
os:

src/RegisterDriver.jl

Lines changed: 104 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -50,98 +50,106 @@ worker has been written to look for such settings:
5050
5151
which will save `extra` only if `:extra` is a key in `mon`.
5252
"""
53-
function driver(outfile::AbstractString, algorithm::Vector, img, mon::Vector)
54-
nworkers = length(algorithm)
55-
length(mon) == nworkers || error("Number of monitors must equal number of workers")
56-
use_workerprocs = nworkers > 1 || workerpid(algorithm[1]) != myid()
57-
rralgorithm = Array{RemoteChannel}(undef, nworkers)
58-
if use_workerprocs
59-
# Push the algorithm objects to the worker processes. This eliminates
60-
# per-iteration serialization penalties, and ensures that any
61-
# initialization state is retained.
62-
for i = 1:nworkers
63-
alg = algorithm[i]
64-
rralgorithm[i] = put!(RemoteChannel(workerpid(alg)), alg)
53+
function driver(outfile::AbstractString, algorithms::Vector, img, mon::Vector)
54+
numworkers = length(algorithms)
55+
length(mon) == numworkers || error("Number of monitors must equal number of algorithms")
56+
use_workerprocs = numworkers > 1 || workerpid(algorithms[1]) != myid()
57+
pool = use_workerprocs ? workers() : [myid()]
58+
wpool = CachingPool(pool) # worker pool for pmap
59+
60+
# Map worker ID to algorithm index
61+
aindices = Dict(map((alg,aidx)->(alg.workerpid=>aidx), algorithms, 1:length(algorithms))...)
62+
63+
# define 'walg' on each worker (algorithms[i].workerpid) with the algorithms[i] value
64+
# And, initialize algorithms on workers
65+
println("Initializing algorithm on workers")
66+
@sync for i = 1:numworkers
67+
p = workerpid(algorithms[i])
68+
use_workerprocs || (p = 1)
69+
alg = algorithms[i]
70+
@spawnat p begin
71+
global walg = alg
72+
init!(walg)
6573
end
66-
# Perform any needed worker initialization
67-
@sync for i = 1:nworkers
68-
p = workerpid(algorithm[i])
69-
@async remotecall_fetch(init!, p, rralgorithm[i])
70-
end
71-
else
72-
init!(algorithm[1])
7374
end
74-
try
75+
76+
println("Working on algorithm and saving the result")
77+
jldopen(outfile, "w") do file
78+
dsets = Dict{Symbol,Any}()
79+
firstsave = Ref(true)
80+
have_unpackable = Ref(false)
7581
n = nimages(img)
76-
fs = FormatSpec("0$(ndigits(n))d") # group names of unpackable objects
77-
jldopen(outfile, "w") do file
78-
dsets = Dict{Symbol,Any}()
79-
firstsave = SharedArray{Bool}(1)
80-
firstsave[1] = true
81-
have_unpackable = SharedArray{Bool}(1)
82-
have_unpackable[1] = false
83-
# Run the jobs
84-
nextidx = 0
85-
getnextidx() = nextidx += 1
86-
writing_mutex = RemoteChannel()
87-
@sync begin
88-
for i = 1:nworkers
89-
alg = algorithm[i]
90-
@async begin
91-
while (idx = getnextidx()) <= n
92-
if use_workerprocs
93-
remotecall_fetch(println, workerpid(alg), "Worker ", workerpid(alg), " is working on ", idx)
94-
# See https://github.com/JuliaLang/julia/issues/22139
95-
tmp = remotecall_fetch(worker, workerpid(alg), rralgorithm[i], img, idx, mon[i])
96-
copy_all_but_shared!(mon[i], tmp)
97-
else
98-
println("Working on ", idx)
99-
mon[1] = worker(algorithm[1], img, idx, mon[1])
100-
end
101-
# Save the results
102-
put!(writing_mutex, true) # grab the lock
103-
try
104-
local g
105-
if firstsave[]
106-
firstsave[] = false
107-
have_unpackable[] = initialize_jld!(dsets, file, mon[i], fs, n)
108-
end
109-
if fetch(have_unpackable[])
110-
g = file[string("stack", fmt(fs, idx))]
111-
end
112-
for (k,v) in mon[i]
113-
if isa(v, Number)
114-
dsets[k][idx] = v
115-
continue
116-
elseif isa(v, Array) || isa(v, SharedArray)
117-
vw = nicehdf5(v)
118-
if eltype(vw) <: BitsType
119-
colons = [Colon() for i = 1:ndims(vw)]
120-
dsets[k][colons..., idx] = vw
121-
continue
122-
end
123-
end
124-
g[string(k)] = v
125-
end
126-
finally
127-
take!(writing_mutex) # release the lock
128-
end
82+
fs = FormatSpec("0$(ndigits(n))d")
83+
84+
# Channel for passing results from workers to master
85+
results_ch = RemoteChannel(()->Channel{Tuple{Int,Dict}}(32), myid())
86+
87+
# Writer task (runs on master)
88+
writer_task = @async begin
89+
while true
90+
data = try
91+
take!(results_ch)
92+
catch
93+
break
94+
end
95+
movidx, monres = data
96+
97+
# Initialize datasets on first save
98+
if firstsave[]
99+
firstsave[] = false
100+
have_unpackable[] = initialize_jld!(dsets, file, monres, fs, n)
101+
end
102+
103+
g = have_unpackable[] ? file[string("stack", fmt(fs, movidx))] : nothing
104+
105+
# Write all values into the file
106+
for (k,v) in monres
107+
# isa(v, SharedArray) && (@show k)
108+
if isa(v, Number)
109+
dsets[k][movidx] = v
110+
elseif isa(v, Array) || isa(v, SharedArray)
111+
vw = nicehdf5(v)
112+
if eltype(vw) <: BitsType
113+
colons = [Colon() for _=1:ndims(vw)]
114+
dsets[k][colons..., movidx] = vw
115+
else
116+
g[string(k)] = v
129117
end
118+
else
119+
g[string(k)] = v
130120
end
131121
end
122+
# yield() # briefly yield control between @async iterations
132123
end
133124
end
134-
finally
135-
# Perform any needed worker cleanup
136-
if use_workerprocs
137-
@sync for i = 1:nworkers
138-
p = workerpid(algorithm[i])
139-
@async remotecall_fetch(close!, p, rralgorithm[i])
140-
end
141-
else
142-
close!(algorithm[1])
125+
126+
# Main computation with pmap
127+
pmap(wpool, 1:n) do movidx
128+
wid = myid()
129+
println("Worker $wid processing $movidx")
130+
131+
# Perform computation
132+
tmp = worker(walg, img, movidx, mon[aindices[wid]])
133+
134+
# Send result back to master for writing
135+
put!(results_ch, (movidx, tmp))
136+
!use_workerprocs && yield() # this needed if single process
137+
return nothing
143138
end
139+
140+
# Close channel and wait for writer to finish
141+
close(results_ch) # This will cause take!(results_ch) to throw an error
142+
wait(writer_task)
144143
end
144+
145+
# Closing algorithms on workers
146+
println("Closing algorithms on Workers")
147+
pmap(wpool, 1:numworkers) do _
148+
close!(walg)
149+
return nothing
150+
end
151+
152+
return nothing
145153
end
146154

147155
driver(outfile::AbstractString, algorithm::AbstractWorker, img, mon::Dict) = driver(outfile, [algorithm], img, [mon])
@@ -214,20 +222,20 @@ end
214222

215223
mm_package_loader(algorithm::AbstractWorker) = mm_package_loader([algorithm])
216224
function mm_package_loader(algorithms::Vector)
217-
nworkers = length(algorithms)
218-
use_workerprocs = nworkers > 1 || workerpid(algorithms[1]) != myid()
219-
rrdev = Array{RemoteChannel}(undef, nworkers)
220-
if use_workerprocs
221-
for i = 1:nworkers
222-
dev = algorithms[i].dev
223-
rrdev[i] = put!(RemoteChannel(workerpid(algorithms[i])), dev)
224-
end
225-
@sync for i = 1:nworkers
226-
p = workerpid(algorithms[i])
227-
@async remotecall_fetch(load_mm_package, p, rrdev[i])
228-
end
229-
else
230-
load_mm_package(algorithms[1].dev)
225+
numworkers = length(algorithms)
226+
use_workerprocs = numworkers > 1 || workerpid(algorithms[1]) != myid()
227+
pool = use_workerprocs ? workers() : [myid()]
228+
wpool = CachingPool(pool) # worker pool for pmap
229+
230+
# wdev is defined on each worker
231+
@sync for i = 1:numworkers
232+
p = workerpid(algorithms[i])
233+
dev = algorithms[i].dev
234+
@spawnat p eval(:(global wdev = $dev))
235+
end
236+
237+
pmap(wpool, 1:numworkers) do _
238+
load_mm_package(wdev)
231239
end
232240
nothing
233241
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
using Test, Distributed, SharedArrays
22
using ImageCore, JLD
3-
using RegisterDriver, RegisterWorkerShell
3+
using RegisterWorkerShell
44
using AxisArrays: AxisArray
55

66
driverprocs = addprocs(2)
77
push!(LOAD_PATH, pwd())
88
@sync for p in driverprocs
99
@spawnat p push!(LOAD_PATH, pwd())
1010
end
11+
@everywhere using RegisterDriver
1112
using WorkerDummy
1213

1314
workdir = tempname()

0 commit comments

Comments
 (0)