Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 102 additions & 68 deletions src/array/copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,81 +24,103 @@ function allocate_copy_buffer(part::Blocks{N}, A::DArray{T,N}) where {T,N}
# FIXME: undef initializer
return zeros(part, T, size(A))
end
function Base.copyto!(B::DArray{T,N}, A::DArray{T,N}) where {T,N}
if size(B) != size(A)
throw(DimensionMismatch("Cannot copy from array of size $(size(A)) to array of size $(size(B))"))

function darray_copyto!(B::DArray{TB,NB}, A::DArray{TA,NA}, Binds=parentindices(B), Ainds=parentindices(A)) where {TB,NB,TA,NA}
Nmax = max(NA, NB)

pad1(x, i) = length(x) < i ? 1 : x[i]
pad1range(x, i) = length(x) < i ? (1:1) : x[i]
pad1range(x::ArrayDomain, i) = length(x.indexes) < i ? (1:1) : x.indexes[i]
padNmax(x) = ntuple(i->pad1range(x, i), Nmax)
padNmax(x::ArrayDomain) = padNmax(x.indexes)

to_range(x::UnitRange) = x
to_range(x::Integer) = x:x
to_range(x::Base.OneTo{Int}) = UnitRange(x)
to_range(x::Base.Slice{Base.OneTo{Int}}) = UnitRange(x)
to_range(::StepRange) = throw(ArgumentError("Non-continuous ranges are not yet supported for DArray copy"))
to_range(x) = throw(ArgumentError("Unsupported range type for DArray copy: $(typeof(x))"))

if any(x->x isa Vector, Binds) || any(x->x isa Vector, Ainds)
# Split the copy into multiple copies
dims_with_vector = findall(x->x[1] isa Vector || x[2] isa Vector, collect(zip(Binds, Ainds)))
Binds_set = Iterators.product(ntuple(i->i in dims_with_vector ? pad1range(Binds, i) : Ref(pad1range(Binds, i)), Nmax)...)
Ainds_set = Iterators.product(ntuple(i->i in dims_with_vector ? pad1range(Ainds, i) : Ref(pad1range(Ainds, i)), Nmax)...)
for (Binds_inner, Ainds_inner) in zip(Binds_set, Ainds_set)
darray_copyto!(B, A, Binds_inner, Ainds_inner)
end
return
end

Bc = B.chunks
Ac = A.chunks
Asd_all = A.subdomains::DomainBlocks{N}
if !all(ntuple(i->length(pad1range(Binds, i)) == length(pad1range(Ainds, i)), Nmax))
throw(DimensionMismatch("Cannot copy from array of size $(size(A)) (indices $Ainds) to array of size $(size(B)) (indices $Binds)"))
end

# Global element ranges
Binds_range = ntuple(i->to_range(pad1range(Binds, i)), Nmax)
Ainds_range = ntuple(i->to_range(pad1range(Ainds, i)), Nmax)

# Global element offsets
Binds_offset = ntuple(i->Binds_range[i].start-1, Nmax)
Ainds_offset = ntuple(i->Ainds_range[i].start-1, Nmax)

# Limited chunk ranges
Bblocksize = ntuple(i->pad1(B.partitioning.blocksize, i), Nmax)
Ablocksize = ntuple(i->pad1(A.partitioning.blocksize, i), Nmax)
Bidx_range = ntuple(i->UnitRange(fld1(Binds_range[i].start, Bblocksize[i]), fld1(Binds_range[i].stop, Bblocksize[i])), Nmax)
Aidx_range = ntuple(i->UnitRange(fld1(Ainds_range[i].start, Ablocksize[i]), fld1(Ainds_range[i].stop, Ablocksize[i])), Nmax)

# Limited chunk indices
Bci = CartesianIndices(Bidx_range)
Aci = CartesianIndices(Aidx_range)

# Per-chunk ranges
Bsd = B.subdomains::DomainBlocks{NB}
Asd = A.subdomains::DomainBlocks{NA}
Bsd_all = collect(reshape(Bsd, ntuple(i->pad1(size(Bsd), i), Nmax)))
Asd_all = collect(reshape(Asd, ntuple(i->pad1(size(Asd), i), Nmax)))

shift_ranges(x::NTuple{N1,UnitRange}, offset::NTuple{N2,Int}) where {N1,N2} =
ntuple(i->UnitRange(x[i].start-offset[i], x[i].stop-offset[i]), Nmax)

Dagger.spawn_datadeps() do
for Bidx in CartesianIndices(Bc)
Bpart = Bc[Bidx]
Bsd = B.subdomains[Bidx]

# Find the first overlapping subdomain of A
if A.partitioning isa Blocks
Aidx = CartesianIndex(ntuple(i->fld1(Bsd.indexes[i].start, A.partitioning.blocksize[i]), N))
else
# Fallback just in case of non-dense partitioning
Aidx = first(CartesianIndices(Ac))
Asd = first(Asd_all)
for dim in 1:N
while Asd.indexes[dim].stop < Bsd.indexes[dim].start
Aidx += CartesianIndex(ntuple(i->i==dim, N))
Asd = Asd_all[Aidx]
end
end
end
Aidx_start = Aidx

# Find the last overlapping subdomain of A
for dim in 1:N
while true
Aidx_next = Aidx + CartesianIndex(ntuple(i->i==dim, N))
if !(Aidx_next in CartesianIndices(Ac))
break
end
Asd_next = Asd_all[Aidx_next]
if Asd_next.indexes[dim].start <= Bsd.indexes[dim].stop
Aidx = Aidx_next
else
break
end
end
end
Aidx_end = Aidx
for Bidx in Bci
Bpart = B.chunks[Bidx]
Bsd_global_raw = padNmax(Bsd_all[Bidx])
Bsd_global_shifted = shift_ranges(Bsd_global_raw, Binds_offset)

# Find the span and set of subdomains of A overlapping Bpart
Aidx_span = Aidx_start:Aidx_end
Asd_view = view(A.subdomains, Aidx_span)
## Find the overlapping subdomains of A
# Calculate start indices based on overlap with Bsd
Asd_global_target = shift_ranges(Bsd_global_shifted, map(-, Ainds_offset))
Aidx_start_vals = ntuple(i->clamp(fld1(Asd_global_target[i].start, Ablocksize[i]), Aidx_range[i].start, Aidx_range[i].stop), Nmax)
Aidx_start = CartesianIndex(Aidx_start_vals)
# Calculate end indices based on overlap with Bsd
Aidx_end_vals = ntuple(i->clamp(fld1(Asd_global_target[i].stop, Ablocksize[i]), Aidx_range[i].start, Aidx_range[i].stop), Nmax)
Aidx_end = CartesianIndex(Aidx_end_vals)

# Copy all overlapping subdomains of A
for Aidx in Aidx_span
Asd = Asd_all[Aidx]
Apart = Ac[Aidx]

# Compute the true range
range_start = CartesianIndex(ntuple(i->max(Bsd.indexes[i].start, Asd.indexes[i].start), N))
range_end = CartesianIndex(ntuple(i->min(Bsd.indexes[i].stop, Asd.indexes[i].stop), N))
range_diff = range_end - range_start

# Compute the offset range into Apart
Asd_start = ntuple(i->Asd.indexes[i].start, N)
Asd_end = ntuple(i->Asd.indexes[i].stop, N)
Arange = range(range_start - CartesianIndex(Asd_start) + CartesianIndex{N}(1),
range_start - CartesianIndex(Asd_start) + CartesianIndex{N}(1) + range_diff)

# Compute the offset range into Bpart
Bsd_start = ntuple(i->Bsd.indexes[i].start, N)
Bsd_end = ntuple(i->Bsd.indexes[i].stop, N)
Brange = range(range_start - CartesianIndex(Bsd_start) + CartesianIndex{N}(1),
range_start - CartesianIndex(Bsd_start) + CartesianIndex{N}(1) + range_diff)

# Perform view copy
Dagger.@spawn copyto_view!(Out(Bpart), Brange, In(Apart), Arange)
for Aidx in Aidx_start:Aidx_end
Apart = A.chunks[Aidx]
Asd_global_raw = padNmax(Asd_all[Aidx])
Asd_global_shifted = shift_ranges(Asd_global_raw, Ainds_offset)

# Compute the global ranges
range_overlap = intersect(CartesianIndices(Bsd_global_shifted), CartesianIndices(Asd_global_shifted))
Brange_start = ntuple(i->Bsd_global_raw[i].start, Nmax)
Arange_start = ntuple(i->Asd_global_raw[i].start, Nmax)
Brange_global = range_overlap .+ CartesianIndex(Binds_offset)
Arange_global = range_overlap .+ CartesianIndex(Ainds_offset)

# Clamp to the selected indices
Brange_global_clamped = intersect(Brange_global, CartesianIndices(Binds_range))
Arange_global_clamped = intersect(Arange_global, CartesianIndices(Ainds_range))

# Compute the local ranges
Brange_local = Brange_global_clamped .- CartesianIndex(Brange_start) .+ CartesianIndex{Nmax}(1)
Arange_local = Arange_global_clamped .- CartesianIndex(Arange_start) .+ CartesianIndex{Nmax}(1)

# Perform local view copy
Dagger.@spawn copyto_view!(Out(Bpart), Brange_local, In(Apart), Arange_local)
end
end
end
Expand All @@ -109,3 +131,15 @@ function copyto_view!(Bpart, Brange, Apart, Arange)
copyto!(view(Bpart, Brange), view(Apart, Arange))
return
end

# Copy the distributed array `A` into the distributed array `B` (same element type
# and dimensionality). All of the work is delegated to `darray_copyto!`, whose
# result is returned unchanged.
function Base.copyto!(B::DArray{T,N}, A::DArray{T,N}) where {T,N}
    return darray_copyto!(B, A)
end
# Copy the in-memory array `A` into the distributed array `B`. `A` is first wrapped
# with `view(A, B.partitioning)` so that it can be handed to `darray_copyto!` as the
# source; the result of `darray_copyto!` is returned unchanged.
function Base.copyto!(B::DArray{T,N}, A::Array{T,N}) where {T,N}
    src = view(A, B.partitioning)
    return darray_copyto!(B, src)
end
# Copy the distributed array `A` into the in-memory array `B`.
#
# `B` is wrapped with `view(B, A.partitioning)` so that `darray_copyto!` can write
# into it. The wrapper is only a vehicle for the copy: following the `Base.copyto!`
# convention, the destination `B` itself is returned, not the temporary wrapper
# (or whatever `darray_copyto!` happens to return).
function Base.copyto!(B::Array{T,N}, A::DArray{T,N}) where {T,N}
    darray_copyto!(view(B, A.partitioning), A)
    return B
end

# Either a `DArray{T,N}` itself, or an `N`-dimensional `SubArray` with element type
# `T` whose parent is a `DArray` of any dimensionality `NP`.
StridedDArray{T,N} = Union{<:DArray{T,N}, SubArray{T,N,<:DArray{T,NP}} where NP}

# Copy between `DArray`s and/or `SubArray`s of `DArray`s.
#
# Both arguments are unwrapped to their parent arrays, and the selected index
# ranges (`parentindices`) are forwarded to `darray_copyto!`, which performs the
# copy on the parents. Because `darray_copyto!` only ever sees `parent(B)`, its
# return value cannot be the destination when `B` is a `SubArray`; the destination
# `B` is therefore returned explicitly, as `Base.copyto!` callers expect.
function Base.copyto!(B::StridedDArray, A::StridedDArray)
    darray_copyto!(parent(B), parent(A), parentindices(B), parentindices(A))
    return B
end
12 changes: 1 addition & 11 deletions src/array/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N}
treereduce_nd(dimcatfuncs, asyncmap(fetch, a.chunks))
end
end
# Build an `Array{T,N}` from a `DArray`: gather the contents with `collect`, then
# convert the gathered array to the requested element type `T`.
function Array{T,N}(A::DArray{S,N}) where {T,N,S}
    gathered = collect(A)
    return convert(Array{T,N}, gathered)
end

# Call `wait` on every entry of `A.chunks`, in iteration order. Returns `nothing`.
function Base.wait(A::DArray)
    for chunk in A.chunks
        wait(chunk)
    end
    return nothing
end

Expand Down Expand Up @@ -331,17 +332,6 @@ Base.copy(x::DArray{T,N,B,F}) where {T,N,B,F} =
# Divide a real-valued `DArray` by a real scalar via broadcasting. The result is
# asserted to be a `DArray` with the promoted element type and the same `N`, `B`
# and `F` type parameters as `x`.
function Base.:(/)(x::DArray{T,N,B,F}, y::U) where {T<:Real,U<:Real,N,B,F}
    R = Base.promote_op(/, T, U)
    quotient = x ./ y
    return quotient::DArray{R,N,B,F}
end

"""
    view(c::DArray, d)

Return a new `DArray` made of the parts of `c` selected by the domain `d`.

The chunks and subdomains covering `d` are obtained from `lookup_parts`, and the
resulting array's domain is `alignfirst(d)`. The element type, partitioning and
`concat` field are taken from `c`.
"""
function Base.view(c::DArray, d)
    selected_chunks, selected_domains = lookup_parts(c, chunks(c), domainchunks(c), d)
    aligned = alignfirst(d)
    return DArray(eltype(c), aligned, selected_domains, selected_chunks,
                  c.partitioning, c.concat)
end

function group_indices(cumlength, idxs,at=1, acc=Any[])
at > length(idxs) && return acc
f = idxs[at]
Expand Down
56 changes: 23 additions & 33 deletions src/array/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,5 @@
### getindex

# Array operation describing an indexing request on another `ArrayOp`.
# It is resolved into a view of the staged input by the `stage` method for `GetIndex`.
struct GetIndex{T,N} <: ArrayOp{T,N}
# The operation being indexed.
input::ArrayOp
# One index per indexed dimension; an entry may be a `Colon`.
idx::Tuple
end

# Convenience constructor: take the type parameters from `eltype(input)` and
# `ndims(input)`.
function GetIndex(input::ArrayOp, idx::Tuple)
    T = eltype(input)
    N = ndims(input)
    return GetIndex{T,N}(input, idx)
end

# Stage a `GetIndex` operation: stage its input, replace every `Colon` in the
# requested indices with the corresponding index range of the staged input's
# domain, and return a view of the staged input over the resulting `ArrayDomain`.
function stage(ctx::Context, gidx::GetIndex)
    staged = stage(ctx, gidx.input)
    dom = domain(staged)

    # A `:` selects the full extent of the staged input along that dimension;
    # any other index is passed through untouched.
    resolved = [gidx.idx[k] isa Colon ? indexes(dom)[k] : gidx.idx[k]
                for k in 1:length(gidx.idx)]

    return view(staged, ArrayDomain(resolved))
end

# Size of the indexing result: for each requested index, a `Colon` contributes the
# input's size along that dimension, anything else contributes its `length`.
function size(x::GetIndex)
    extents = map(enumerate(x.idx)) do (dim, ix)
        ix isa Colon ? size(x.input, dim) : length(ix)
    end
    return Tuple(extents)
end

# Index an `ArrayOp` with an `ArrayDomain`: the domain's indexes become the index
# tuple of a `GetIndex` operation, which is then passed to `_to_darray`.
function Base.getindex(c::ArrayOp, idx::ArrayDomain)
    op = GetIndex(c, indexes(idx))
    return _to_darray(op)
end
# Index an `ArrayOp` with arbitrary indices: the varargs tuple is wrapped in a
# `GetIndex` operation, which is then passed to `_to_darray`.
function Base.getindex(c::ArrayOp, idx...)
    op = GetIndex(c, idx)
    return _to_darray(op)
end

# Task-local dictionary mapping `Tuple` keys to arbitrary values; the closure builds
# a fresh, empty dictionary when one is needed.
# NOTE(review): presumably caches data looked up during `getindex` — confirm at the use site.
const GETINDEX_CACHE = TaskLocalValue{Dict{Tuple,Any}}(()->Dict{Tuple,Any}())
# Scoped `Int` setting with a default of 0.
# NOTE(review): presumably the number of entries the cache may hold, with 0 disabling caching — confirm at the use site.
const GETINDEX_CACHE_SIZE = ScopedValue{Int}(0)
# Run `f` with `GETINDEX_CACHE_SIZE` bound to `size` (default 1) and return `with`'s result.
with_index_caching(f, size::Integer=1) = with(f, GETINDEX_CACHE_SIZE=>size)
Expand Down Expand Up @@ -105,6 +72,23 @@ function Base.getindex(A::DArray{T,N}, idxs::Dims{S}) where {T,N,S}
end
error()
end
# General (non-scalar) indexing of a `DArray`.
#
# The indices are normalised with `to_indices`, a view of `A` over them is taken,
# and the selected elements are copied into a freshly allocated `DArray`, which is
# returned. `A`'s own partitioning is reused when the number of indices matches the
# number of block dimensions; otherwise `auto_blocks` chooses one for the new size.
function Base.getindex(A::DArray, idx...)
    inds = to_indices(A, idx)
    src = view(A, inds...)

    # One extent per index; `length` of a scalar index is 1, so scalar-indexed
    # dimensions are kept here with extent 1.
    out_size = map(length, inds)

    # TODO: Pad out to same number of dims?
    part = if length(inds) == length(A.partitioning.blocksize)
        A.partitioning
    else
        auto_blocks(out_size)
    end

    dest = zeros(part, eltype(A), out_size) # FIXME: Use undef initializer
    copyto!(dest, src)

    size(src) == out_size && return dest

    # N.B. Base automatically transposes a row vector to a column vector, so the
    # view's shape can differ from `out_size`; reshape the result to match the view.
    return DArray(reshape(dest, size(src)))
end
# Index a `DArray` with an `ArrayDomain` by splatting the domain's indexes into the
# general `getindex` method.
function Base.getindex(A::DArray, idx::ArrayDomain)
    domain_inds = indexes(idx)
    return getindex(A, domain_inds...)
end

### setindex!

Expand Down Expand Up @@ -148,6 +132,12 @@ function Base.setindex!(A::DArray{T,N}, value, idxs::Dims{S}) where {T,N,S}
end
error()
end
# General (non-scalar) assignment into a `DArray`.
#
# The indices are normalised with `to_indices`, a view of `A` over them is taken,
# and `value` is copied into that view with `copyto!`. Returns `value`.
function Base.setindex!(A::DArray, value, idx...)
    dest = view(A, to_indices(A, idx)...)
    copyto!(dest, value)
    return value
end

### Allow/disallow scalar indexing

Expand Down
23 changes: 23 additions & 0 deletions test/array/copyto.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,32 @@
# Exercise `copyto!` between two distributed arrays of size `sz` and element type
# `T`, where the source is partitioned by `partA` and the destination by `partB`.
# Covers whole-array copies, contiguous sub-array copies, and the two expected
# failure modes (strided views, mismatched extents).
function test_copyto(sz, partA, partB, T)
    # Fresh random source (plain and distributed) plus a zeroed destination.
    function fresh()
        src = rand(T, sz...)
        return src, distribute(src, partA), zeros(partB, T, sz...)
    end

    # Leading (rounded-up) half of each of the first two dimensions.
    half = (1:fld1(sz[1], 2), 1:fld1(sz[2], 2))
    # Every other element along each of the first two dimensions.
    strided = (1:2:sz[1], 1:2:sz[2])

    # Full arrays
    A, DA, DB = fresh()
    copyto!(DB, DA)
    @test collect(DB) == collect(DA) == A

    # Contiguous SubArrays
    A, DA, DB = fresh()
    copyto!(view(DB, half...), view(DA, half...))
    @test collect(DB)[half...] == collect(DA)[half...]
    # Everything outside the copied region must still be zero.
    untouched = setdiff(CartesianIndices(DB), CartesianIndices(half))
    @test all(collect(DB)[untouched] .== 0)

    # Non-contiguous SubArrays (currently unsupported)
    A, DA, DB = fresh()
    @test_throws ArgumentError copyto!(view(DB, strided...), view(DA, strided...))

    # Dimension mismatch
    A, DA, DB = fresh()
    @test_throws DimensionMismatch copyto!(DB, view(DA, strided...))
end

@testset "T=$T" for T in (Int8, Int32, Int64, Float64, ComplexF64)
Expand Down
45 changes: 43 additions & 2 deletions test/array/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@
end
end

@testset "Getindex" begin
@testset "getindex" begin
function test_getindex(x)
X = distribute(x, Blocks(3,3))
@test collect(X[3:8, 2:7]) == x[3:8, 2:7]
ragged_idx = [1,2,9,7,6,2,4,5]
@test collect(X[ragged_idx, 2:7]) == x[ragged_idx, 2:7]
@test collect(X[ragged_idx, reverse(ragged_idx)]) == x[ragged_idx, reverse(ragged_idx)]
ragged_idx = [1,2,9,7,6,2,4,5]
@test collect(X[[2,7,10], :]) == x[[2,7,10], :]
@test collect(X[[], ragged_idx]) == x[[], ragged_idx]
@test collect(X[[], []]) == x[[], []]
Expand Down Expand Up @@ -59,3 +58,45 @@ end
@test collect(setindex(X,1.0, 3:8, 2:7)) == y
@test collect(X) == x
end

@testset "slicing" begin
    # Range indexing yields a distributed array.
    Z = zeros(Blocks(5, 3), 10, 10)
    @test Z[1:2, 1:2] isa DArray

    # Matrix - Vector: write a distributed vector into one column
    for col in 1:10
        A = zeros(Blocks(5, 3), 10, 10)
        b = rand(Blocks(2), 10)
        Dagger.allowscalar(false) do
            A[:, col] = b
        end
        Ac = collect(A)
        @test all(Ac[:, col] .== b)
        # All other columns must be untouched.
        @test all(Ac[:, 1:col-1] .== 0)
        @test all(Ac[:, col+1:end] .== 0)
    end

    # Matrix - Vector (transposed): write a distributed row vector into one row
    for row in 1:10
        A = zeros(Blocks(5, 3), 10, 10)
        b = rand(Blocks(2), 10)
        bT = DArray(b')
        Dagger.allowscalar(false) do
            A[row, :] = bT
        end
        Ac = collect(A)
        @test all(Ac[row, :] .== b)
        # All other rows must be untouched.
        @test all(Ac[1:row-1, :] .== 0)
        @test all(Ac[row+1:end, :] .== 0)
    end

    # Matrix - Matrix: write a 2x2 window of B into the same window of A
    for lo in 1:9
        win = lo:(lo+1)
        A = zeros(Blocks(5, 3), 10, 10)
        B = rand(Blocks(2, 2), 10, 10)
        Dagger.allowscalar(false) do
            A[win, win] = view(B, win, win)
        end
        outside = setdiff(CartesianIndices(A), CartesianIndices((win, win)))
        Ac = collect(A)
        @test all(Ac[outside] .== 0)
        @test all(Ac[win, win] .== collect(B)[win, win])
    end
end