Continuing https://discourse.julialang.org/t/behavior-of-threads-threads-for-loop/76042/20, here's how I'd implement the code you quoted (IIUC).
## Mechanical transformation that almost always works
First of all, here's a trick you can use almost always. If you have this pattern

```julia
Threads.@threads for x in xs
    i = Threads.threadid()
    f(x, i)
end
```

you can mechanically convert it to

```julia
n = cld(length(xs), Threads.nthreads())
@sync for (i, chunk) in enumerate(Iterators.partition(xs, n))
    Threads.@spawn for x in chunk
        f(x, i)
    end
end
```

This is very likely correct if the loop body `f` only uses `threadid()` to index arrays allocated specifically for this parallel loop (e.g., the pre-1.3 reduction pattern).
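To make the transformation concrete, here is a minimal self-contained toy (a sum of squares standing in for `f`; the names `xs`, `acc`, and `total` are mine, not from the code above):

```julia
using Base.Threads

# Toy reduction: sum of squares of 1:100, using the chunk-based pattern
# instead of indexing an accumulator by threadid().
xs = 1:100
n = cld(length(xs), nthreads())
chunks = collect(enumerate(Iterators.partition(xs, n)))
acc = zeros(Int, length(chunks))   # one slot per chunk, not per thread
@sync for (i, chunk) in chunks
    @spawn for x in chunk
        acc[i] += x^2              # safe: only the task owning chunk i touches acc[i]
    end
end
total = sum(acc)
```

The key point is that `i` identifies a chunk (and hence a single task), not a thread, so the code stays correct even when tasks migrate between threads.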
## Array reduction with `@reduce acc .+= x`
I think
Lines 24 to 36 in a1791f6:

```julia
# Build the partial density ρk_real for this k-point
ρk_real = [zeros(eltype(basis), basis.fft_size) for it = 1:Threads.nthreads()]
ψnk_real = [zeros(complex(eltype(basis)), basis.fft_size) for it = 1:Threads.nthreads()]
Threads.@threads for n = 1:size(ψk, 2)
    ψnk = @views ψk[:, n]
    tid = Threads.threadid()
    G_to_r!(ψnk_real[tid], basis, kpt, ψnk)
    ρk_real[tid] .+= occupation[n] .* abs2.(ψnk_real[tid])
end
for it = 2:Threads.nthreads()
    ρk_real[1] .+= ρk_real[it]
end
ρk_real = ρk_real[1]
```
can be re-written using

```julia
@floop for n = 1:size(ψk, 2)
    @init ψnk_real = zeros(complex(eltype(basis)), basis.fft_size)
    ψnk = @views ψk[:, n]
    G_to_r!(ψnk_real, basis, kpt, ψnk)
    @reduce ρk_real .+= occupation[n] .* abs2.(ψnk_real)
end
```

This requires FLoops 0.1.12 or above.
## Pre-allocated scratch space and TimerOutput

DFTK.jl/src/terms/Hamiltonian.jl, Lines 124 to 148 in a1791f6:

```julia
Threads.@threads for iband = 1:n_bands
    to = TimerOutput()  # Thread-local timer output
    tid = Threads.threadid()
    ψ_real = H.scratch.ψ_reals[tid]
    @timeit to "local+kinetic" begin
        G_to_r!(ψ_real, H.basis, H.kpoint, ψ[:, iband]; normalize=false)
        ψ_real .*= potential
        r_to_G!(Hψ[:, iband], H.basis, H.kpoint, ψ_real; normalize=false)  # overwrites ψ_real
        Hψ[:, iband] .+= H.fourier_op.multiplier .* ψ[:, iband]
    end
    if have_divAgrad
        @timeit to "divAgrad" begin
            apply!((fourier=Hψ[:, iband], real=nothing),
                   H.divAgrad_op,
                   (fourier=ψ[:, iband], real=nothing),
                   ψ_real)  # ψ_real used as scratch
        end
    end
    if tid == 1
        merge!(DFTK.timer, to; tree_point=[t.name for t in DFTK.timer.timer_stack])
    end
end
```
If you are OK with allocating about `nthreads()` arrays every time this code executes, `ψ_real = H.scratch.ψ_reals[tid]` can simply be rewritten as `@init ψ_real = similar(H.scratch.ψ_reals[1])` (or something equivalent on the RHS). If you must reuse `H.scratch.ψ_reals`, the easiest approach is probably the mechanical transformation I noted above.
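For concreteness, here is a minimal self-contained sketch of reusing pre-allocated scratch arrays via chunk indices, with a dummy workload standing in for `G_to_r!` and the per-band computation (the names `scratch`, `out`, and `buf` are mine, not DFTK's):

```julia
using Base.Threads

n_bands = 32
buf_len = 8
scratch = [zeros(buf_len) for _ in 1:nthreads()]  # pre-allocated, like H.scratch.ψ_reals
out = zeros(n_bands)

n = cld(n_bands, nthreads())
@sync for (i, bands) in enumerate(Iterators.partition(1:n_bands, n))
    @spawn begin
        buf = scratch[i]  # indexed by chunk id, so correct even if the task migrates threads
        for iband in bands
            buf .= iband           # stand-in for G_to_r! filling the scratch buffer
            out[iband] = sum(buf)  # stand-in for the real per-band work
        end
    end
end
```

Since `n = cld(n_bands, nthreads())`, the partition yields at most `nthreads()` chunks, so the existing scratch arrays suffice.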
I am not sure why you are throwing away the timer info from the non-primary threads, but if you want to merge all of them, you can do
```julia
@reduce to = merge!(TimerOutput(), to; tree_point=[t.name for t in DFTK.timer.timer_stack])
```

in the loop body and then

```julia
merge!(DFTK.timer, to; tree_point=[t.name for t in DFTK.timer.timer_stack])
```

outside to merge all the timer outputs, provided that `merge!` on `TimerOutput` acts like the method on `Dict`.
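The `@reduce acc = op(init, x)` pattern itself is easy to demo with a plain `Dict` standing in for `TimerOutput` (a toy, assuming FLoops ≥ 0.1.12 is installed; `mergewith!(+, ...)` plays the role `merge!` would play for timers):

```julia
using FLoops

words = ["alpha", "beta", "gamma", "apple"]
@floop for w in words
    counts = Dict(first(w) => 1)  # per-iteration partial result, like the per-task TimerOutput
    @reduce merged = mergewith!(+, Dict{Char,Int}(), counts)
end
# merged maps each first letter to how many words start with it
```

Each task accumulates into its own `Dict` created from the `init` expression, and FLoops merges the per-task results at the end, exactly the shape needed for combining per-task `TimerOutput`s.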