
De-threadid-ing parallel loops #588

@tkf


Continuing from https://discourse.julialang.org/t/behavior-of-threads-threads-for-loop/76042/20, here's how I'd implement the code you quoted (IIUC).

Mechanical transformation that almost always works

First of all, here's a trick you can almost always use. If you have this pattern

Threads.@threads for x in xs
    i = Threads.threadid()
    f(x, i)
end

you can mechanically convert this to

n = cld(length(xs), Threads.nthreads())
@sync for (i, chunk) in enumerate(Iterators.partition(xs, n))
    Threads.@spawn for x in chunk
        f(x, i)
    end
end

This is very likely correct if the loop body f uses threadid() only to index arrays allocated specifically for this parallel loop (e.g., the pre-1.3 reduction pattern).
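
For concreteness, here's the transformation applied to a hypothetical pre-1.3-style sum reduction (xs, partials, and total are made-up names for illustration):

using Base.Threads: @spawn, nthreads

xs = rand(1000)
n = cld(length(xs), nthreads())
partials = zeros(cld(length(xs), n))  # one accumulator per chunk
@sync for (i, chunk) in enumerate(Iterators.partition(xs, n))
    @spawn for x in chunk
        partials[i] += x  # safe: only this task touches partials[i]
    end
end
total = sum(partials)

Since partials is indexed by the chunk number i rather than threadid(), the result does not depend on how tasks are scheduled across threads.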

Array reduction with @reduce acc .+= x

I think

DFTK.jl/src/densities.jl, lines 24 to 36 at a1791f6:

# Build the partial density ρk_real for this k-point
ρk_real = [zeros(eltype(basis), basis.fft_size) for it = 1:Threads.nthreads()]
ψnk_real = [zeros(complex(eltype(basis)), basis.fft_size) for it = 1:Threads.nthreads()]
Threads.@threads for n = 1:size(ψk, 2)
    ψnk = @views ψk[:, n]
    tid = Threads.threadid()
    G_to_r!(ψnk_real[tid], basis, kpt, ψnk)
    ρk_real[tid] .+= occupation[n] .* abs2.(ψnk_real[tid])
end
for it = 2:Threads.nthreads()
    ρk_real[1] .+= ρk_real[it]
end
ρk_real = ρk_real[1]

can be re-written using

    @floop for n = 1:size(ψk, 2)
        @init ψnk_real = zeros(complex(eltype(basis)), basis.fft_size)
        ψnk = @views ψk[:, n]
        G_to_r!(ψnk_real, basis, kpt, ψnk)
        @reduce ρk_real .+= occupation[n] .* abs2.(ψnk_real)
    end

This requires FLoops 0.1.12 or above.
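
In case a self-contained version helps, here's a minimal sketch of the same @init/@reduce pattern (the matrix A and its sizes are made up, and the in-place copy stands in for G_to_r!):

using FLoops

A = rand(ComplexF64, 8, 100)
@floop for n in 1:size(A, 2)
    @init scratch = zeros(ComplexF64, 8)  # task-local scratch, like ψnk_real
    scratch .= view(A, :, n)              # stands in for G_to_r!
    @reduce ρ .+= abs2.(scratch)          # element-wise array reduction
end

After the loop, ρ holds the element-wise sum over all columns, with no threadid() bookkeeping.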

Pre-allocated scratch space and TimerOutput

Threads.@threads for iband = 1:n_bands
    to = TimerOutput()  # Thread-local timer output
    tid = Threads.threadid()
    ψ_real = H.scratch.ψ_reals[tid]
    @timeit to "local+kinetic" begin
        G_to_r!(ψ_real, H.basis, H.kpoint, ψ[:, iband]; normalize=false)
        ψ_real .*= potential
        r_to_G!(Hψ[:, iband], H.basis, H.kpoint, ψ_real; normalize=false)  # overwrites ψ_real
        Hψ[:, iband] .+= H.fourier_op.multiplier .* ψ[:, iband]
    end
    if have_divAgrad
        @timeit to "divAgrad" begin
            apply!((fourier=Hψ[:, iband], real=nothing),
                   H.divAgrad_op,
                   (fourier=ψ[:, iband], real=nothing),
                   ψ_real)  # ψ_real used as scratch
        end
    end
    if tid == 1
        merge!(DFTK.timer, to; tree_point=[t.name for t in DFTK.timer.timer_stack])
    end
end

If you are OK with allocating about nthreads() arrays every time this code executes, ψ_real = H.scratch.ψ_reals[tid] can simply be re-written as @init ψ_real = similar(H.scratch.ψ_reals[1]) (or something equivalent on the RHS). If you must reuse H.scratch.ψ_reals, the easiest approach is probably the mechanical transformation I noted above; see the sketch below.
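
Here's a hedged sketch of that transformation with the scratch arrays reused; n_bands, scratch_arrays, and process_band! are made-up stand-ins for DFTK's n_bands, H.scratch.ψ_reals, and the loop body above:

using Base.Threads: @spawn, nthreads

n_bands = 16
scratch_arrays = [zeros(ComplexF64, 8) for _ in 1:nthreads()]  # stand-in for H.scratch.ψ_reals
process_band!(ψ_real, iband) = (ψ_real .= iband)               # placeholder for the real body

n = cld(n_bands, nthreads())
@sync for (i, bands) in enumerate(Iterators.partition(1:n_bands, n))
    @spawn begin
        ψ_real = scratch_arrays[i]  # indexed by the chunk number, not threadid()
        for iband in bands
            process_band!(ψ_real, iband)
        end
    end
end

Since Iterators.partition yields at most nthreads() chunks here, scratch_arrays[i] is always in bounds and each scratch array is owned by exactly one task.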

I am not sure why you are throwing away the timer info from the non-primary threads, but if you want to merge all of them, you can do

@reduce to = merge!(TimerOutput(), to; tree_point=[t.name for t in DFTK.timer.timer_stack])

in the loop body and then

merge!(DFTK.timer, to; tree_point=[t.name for t in DFTK.timer.timer_stack])

outside the loop to merge all the timer outputs, provided that merge! on TimerOutput acts like the method on Dict.
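
For example, a minimal sketch of the whole merging pattern, assuming merge! on TimerOutput behaves that way (global_timer is a made-up stand-in for DFTK.timer, and the body just times throwaway work):

using FLoops, TimerOutputs

global_timer = TimerOutput()  # stand-in for DFTK.timer
@floop for i in 1:8
    local_to = TimerOutput()
    @timeit local_to "work" sum(rand(1000))
    @reduce to = merge!(TimerOutput(), local_to)
end
merge!(global_timer, to)  # all tasks' timings end up under global_timer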
