Skip to content

Commit 0d95dba

Browse files
committed
DArray: Always return Array from collect
1 parent 4f91eed commit 0d95dba

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

src/array/darray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N}
194194
if tree
195195
collect(fetch(treereduce_nd(map(x -> ((args...,) -> Dagger.@spawn x(args...)) , dimcatfuncs), a.chunks)))
196196
else
197-
treereduce_nd(dimcatfuncs, asyncmap(fetch, a.chunks))
197+
collect(treereduce_nd(dimcatfuncs, asyncmap(fetch, a.chunks)))
198198
end
199199
end
200200
Array{T,N}(A::DArray{S,N}) where {T,N,S} = convert(Array{T,N}, collect(A))

src/array/map-reduce.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,18 @@ function stage(ctx::Context, r::MapReduce{T,N}) where {T,N}
6161

6262
# Tree-reduce intermediate reductions
6363
dims_materialized = dims === Colon() ? ntuple(identity, ndims(inp)) : dims
64-
treered_f(op, x, y) = op.(x, y)
64+
function to_array(x, N)
65+
A = Array{typeof(x),N}(undef, ntuple(i->1, N))
66+
A[1] = x
67+
return A
68+
end
69+
to_array(x::Array, N) = x
70+
function treered_f(op, x, y, N)
71+
value = op.(x, y)
72+
return to_array(value, N)
73+
end
6574
thunks = treereducedim(reduced_parts, dims_materialized) do x, y
66-
Dagger.@spawn treered_f(r.op_outer, x, y)
75+
Dagger.@spawn treered_f(r.op_outer, x, y, length(dims_materialized))
6776
end
6877

6978
c = domainchunks(inp)
@@ -86,7 +95,7 @@ _mapreduce_maybesync(f, op_inner, op_outer, x, ::Colon, init) =
8695
_mapreduce_maybesync(f, op_inner, op_outer, x, nothing, init)
8796
function _mapreduce_maybesync(f, op_inner, op_outer, x::DArray{T,N}, dims::Nothing, init) where {T,N}
8897
Dx = _to_darray(MapReduce(f, op_inner, op_outer, x, dims, init))
89-
return collect(Dx)
98+
return only(collect(Dx))
9099
end
91100

92101
function Base.size(r::MapReduce)

0 commit comments

Comments
 (0)