Skip to content

Commit 1873529

Browse files
Merge pull request #19 from mkitti/mkitti/unsafe_wrap
Implement Base.unsafe_wrap(MallocArray, ...)
2 parents 202812a + c5c8974 commit 1873529

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

src/mallocarray.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,24 @@
165165
return c
166166
end
167167

168+
# Avoid Base.unsafe_wrap stack overflow
169+
@inline _unsafe_wrap(::Type{MallocArray{T,N}}, ptr::Ptr, dims::Dims{N}, own::Bool = false) where {T,N} =
170+
own ?
171+
error("A MallocArray cannot be owned via `unsafe_wrap`. Set the keyword `own` to be false.") :
172+
MallocArray{T,N}(ptr, prod(dims), dims)
173+
174+
# Overload Base.unsafe_wrap
175+
"""
176+
```julia
177+
unsafe_wrap(MallocArray, ptr::Ptr{T}, dims)
178+
```
179+
Create a `MallocArray{T}` wrapping around `ptr`
180+
"""
181+
@inline Base.unsafe_wrap(::Union{Type{MallocArray},Type{MallocArray{T}},Type{MallocArray{T,N}}}, ptr::Ptr{T}, dims::Dims{N}; own = false) where {T,N} =
182+
_unsafe_wrap(MallocArray{T,N}, ptr, dims, own)
183+
@inline Base.unsafe_wrap(::Union{Type{MallocArray},Type{MallocArray{T}},Type{MallocArray{T,1}}}, ptr::Ptr{T}, dim::Integer; own = false) where {T} =
184+
_unsafe_wrap(MallocArray{T,1}, ptr, (Int(dim),), own)
185+
168186
# Other custom constructors
169187
"""
170188
```julia

test/testmallocarray.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,28 @@
178178
free(C)
179179
free(D)
180180

181+
## ---
182+
183+
## -- test Base.unsafe_wrap
184+
185+
ptr = StaticTools.malloc(64)
186+
A = unsafe_wrap(MallocArray, Ptr{Int8}(ptr), 64)
187+
A .= 1:64
188+
@test typeof(A) == MallocArray{Int8, 1}
189+
@test A == 1:64
190+
191+
B = unsafe_wrap(MallocArray{Int64}, Ptr{Int64}(ptr), (2,4))
192+
B[:] .= 1:8
193+
@test typeof(B) == MallocArray{Int64,2}
194+
@test B == reshape(1:8, (2,4))
195+
196+
C = unsafe_wrap(MallocArray{Float64,3}, Ptr{Float64}(ptr), (2,2,2))
197+
C[:] .= 1.0:1.0:8.0
198+
@test typeof(C) == MallocArray{Float64,3}
199+
@test C == reshape(1.0:1.0:8.0, (2,2,2))
200+
201+
free(A)
202+
181203
## ---
182204

183205
A = mones(11,10)

0 commit comments

Comments
 (0)