Skip to content

Commit 955886e

Browse files
authored
Merge pull request #7 from ChrisRackauckas/staticarrayinterface
update to StaticArrayInterface
2 parents 173f580 + b20bc37 commit 955886e

File tree

5 files changed

+81
-78
lines changed

5 files changed

+81
-78
lines changed

Project.toml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,17 @@ authors = ["chriselrod <[email protected]> and contributors"]
44
version = "0.1.13"
55

66
[deps]
7-
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8-
ArrayInterfaceOffsetArrays = "015c0d05-e682-4f19-8f0a-679ce4c54826"
9-
ArrayInterfaceStaticArrays = "b0d46f97-bff5-4637-a19a-dd75974142cd"
107
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
118
ManualMemory = "d125e4d3-2237-4719-b19c-fa641b8a4667"
129
SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4"
1310
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
11+
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
1412

1513
[compat]
16-
ArrayInterface = "3.1.24, 4, 5, 6"
17-
ArrayInterfaceOffsetArrays = "0.1"
18-
ArrayInterfaceStaticArrays = "0.1"
1914
ManualMemory = "0.1.6"
2015
SIMDTypes = "0.1"
21-
Static = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
16+
Static = "0.8"
17+
StaticArrayInterface = "1"
2218
julia = "1.6"
2319

2420
[extras]

src/LayoutPointers.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
module LayoutPointers
2-
if isdefined(Base, :Experimental) &&
3-
isdefined(Base.Experimental, Symbol("@max_methods"))
4-
@eval Base.Experimental.@max_methods 1
2+
if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_methods"))
3+
@eval Base.Experimental.@max_methods 1
54
end
65

7-
using ArrayInterface, Static, LinearAlgebra
8-
using ArrayInterface: CPUPointer, StrideIndex, offsets
9-
using ArrayInterfaceOffsetArrays
10-
using ArrayInterfaceStaticArrays
6+
using Static, LinearAlgebra, StaticArrayInterface
117
using SIMDTypes: Bit, FloatingTypes, IntegerTypesHW
128
using Static: Zero, One
13-
using ArrayInterface:
9+
using StaticArrayInterface:
1410
contiguous_axis,
1511
contiguous_axis_indicator,
1612
contiguous_batch_size,
@@ -20,12 +16,15 @@ using ArrayInterface:
2016
CPUTuple,
2117
static_first,
2218
static_step,
23-
strides
19+
static_strides,
20+
CPUPointer,
21+
StrideIndex,
22+
offsets
2423
using ManualMemory: preserve_buffer, offsetsize
2524

2625
export stridedpointer
2726

28-
const IntegerTypes = Union{IntegerTypesHW, StaticInt}
27+
const IntegerTypes = Union{IntegerTypesHW,StaticInt}
2928

3029
@inline _map(f::F, x::Tuple{}) where {F} = ()
3130
@inline _map(f::F, x::Tuple{X1}) where {F,X1} = (f(getfield(x, 1, false)),)

src/grouped_strided_pointers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ end
3535
@inline val_dense_dims(::DensePointerWrapper{D}) where {D} = Val{D}()
3636

3737
@inline Base.pointer(A::DensePointerWrapper) = pointer(getfield(A, :p))
38-
@inline ArrayInterface.StrideIndex(sptr::DensePointerWrapper) =
38+
@inline StaticArrayInterface.StrideIndex(sptr::DensePointerWrapper) =
3939
StrideIndex(getfield(sptr, :p))
4040

4141

@@ -45,7 +45,7 @@ end
4545
DensePointerWrapper{D,T,N,C,B,R,X,O,P}(sp)
4646

4747
@inline _gp_strides(x::StrideIndex) = getfield(x, :strides)
48-
@inline _gp_strides(x) = strides(x)
48+
@inline _gp_strides(x) = static_strides(x)
4949
@inline _gp_strides(::NoStrides) = NoStrides()
5050
grouped_strided_pointer(::Tuple{}, ::Val{()}) = ((), ())
5151

src/stridedpointers.jl

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ end
1111
@inline mulsizeof(::Type{T}, x::Tuple{X1,X2,Vararg}) where {T,X1,X2} =
1212
(mulsizeof(T, getfield(x, 1, false)), mulsizeof(T, Base.tail(x))...)
1313

14-
@inline bytestrides(A::AbstractArray{T}) where {T} = mulsizeof(T, ArrayInterface.strides(A))
14+
@inline bytestrides(A::AbstractArray{T}) where {T} =
15+
mulsizeof(T, StaticArrayInterface.static_strides(A))
1516

16-
@inline memory_reference(A::NTuple) = memory_reference(ArrayInterface.device(A), A)
17-
@inline memory_reference(A::AbstractArray) = memory_reference(ArrayInterface.device(A), A)
17+
@inline memory_reference(A::NTuple) = memory_reference(StaticArrayInterface.device(A), A)
18+
@inline memory_reference(A::AbstractArray) =
19+
memory_reference(StaticArrayInterface.device(A), A)
1820
@inline memory_reference(A::BitArray) = Base.unsafe_convert(Ptr{Bit}, A.chunks), A.chunks
1921
@inline memory_reference(::CPUPointer, A) = pointer(A), preserve_buffer(A)
2022
@inline memory_reference(
@@ -39,34 +41,34 @@ end
3941
@inline function memory_reference_subarray(::PT, A::SubArray) where {PT}
4042
p, m = memory_reference(PT(), parent(A))
4143
pA = parent(A)
42-
offset = ArrayInterface.reduce_tup(
44+
offset = StaticArrayInterface.reduce_tup(
4345
+,
44-
_map(*, _map(ind_diff, A.indices, offsets(pA)), strides(pA)),
46+
_map(*, _map(ind_diff, A.indices, offsets(pA)), static_strides(pA)),
4547
)
4648
p + sizeof(eltype(A)) * offset, m
4749
end
4850
@inline function memory_reference(::CPUTuple, A)
4951
r = Ref(A)
5052
Base.unsafe_convert(Ptr{eltype(A)}, Base.pointer_from_objref(r)), r
5153
end
52-
@inline function memory_reference(::ArrayInterface.CheckParent, A)
54+
@inline function memory_reference(::StaticArrayInterface.CheckParent, A)
5355
P = parent(A)
5456
if P === A
55-
memory_reference(ArrayInterface.CPUIndex(), A)
57+
memory_reference(StaticArrayInterface.CPUIndex(), A)
5658
else
57-
memory_reference(ArrayInterface.device(P), P)
59+
memory_reference(StaticArrayInterface.device(P), P)
5860
end
5961
end
60-
@inline memory_reference(::ArrayInterface.CPUIndex, A) =
62+
@inline memory_reference(::StaticArrayInterface.CPUIndex, A) =
6163
throw("Memory access for $(typeof(A)) not implemented yet.")
6264

63-
@inline ArrayInterface.contiguous_axis(
65+
@inline StaticArrayInterface.contiguous_axis(
6466
::Type{A},
6567
) where {T,N,C,A<:AbstractStridedPointer{T,N,C}} = StaticInt{C}()
66-
@inline ArrayInterface.contiguous_batch_size(
68+
@inline StaticArrayInterface.contiguous_batch_size(
6769
::Type{A},
6870
) where {T,N,C,B,A<:AbstractStridedPointer{T,N,C,B}} = StaticInt{B}()
69-
@inline ArrayInterface.stride_rank(
71+
@inline StaticArrayInterface.stride_rank(
7072
::Type{A},
7173
) where {T,N,C,B,R,A<:AbstractStridedPointer{T,N,C,B,R}} = _map(StaticInt, R)
7274
@inline memory_reference(A::AbstractStridedPointer) = pointer(A), nothing
@@ -92,21 +94,26 @@ end
9294
end
9395
@inline function stridedpointer(A::AbstractArray)
9496
p, r = memory_reference(A)
95-
stridedpointer(p, bytestrideindex(A), ArrayInterface.contiguous_batch_size(A))
97+
stridedpointer(p, bytestrideindex(A), StaticArrayInterface.contiguous_batch_size(A))
9698
end
9799
@inline function stridedpointer_preserve(A::AbstractArray)
98100
p, r = memory_reference(A)
99-
stridedpointer(p, bytestrideindex(A), ArrayInterface.contiguous_batch_size(A)), r
101+
stridedpointer(p, bytestrideindex(A), StaticArrayInterface.contiguous_batch_size(A)), r
100102
end
101103
@inline function stridedpointer_preserve(t::NTuple)
102104
p, r = memory_reference(t)
103-
stridedpointer(p, ArrayInterface.StrideIndex{1,(1,),1}((static(sizeof(eltype(t))),), (static(1),)), static(0)), r
105+
stridedpointer(
106+
p,
107+
StaticArrayInterface.StrideIndex{1,(1,),1}((static(sizeof(eltype(t))),), (static(1),)),
108+
static(0),
109+
),
110+
r
104111
end
105112
@inline val_stride_rank(::AbstractStridedPointer{T,N,C,B,R}) where {T,N,C,B,R} = Val{R}()
106113
@generated val_dense_dims(::AbstractStridedPointer{T,N}) where {T,N} =
107114
Val{ntuple(==(0), Val(N))}()
108115
@inline val_stride_rank(A) = Val(known(stride_rank(A)))
109-
@inline val_dense_dims(A) = Val(known(ArrayInterface.dense_dims(A)))
116+
@inline val_dense_dims(A) = Val(known(StaticArrayInterface.dense_dims(A)))
110117

111118
function zerotupleexpr(N::Int)
112119
t = Expr(:tuple)
@@ -126,7 +133,7 @@ end
126133
)
127134
@inline zstridedpointer(A) = zero_offsets(stridedpointer(A))
128135
@inline function zstridedpointer_preserve(A::AbstractArray{T,N}) where {T,N}
129-
strd = mulsizeof(T, ArrayInterface.strides(A))
136+
strd = mulsizeof(T, StaticArrayInterface.static_strides(A))
130137
si = StrideIndex{N,known(stride_rank(A)),Int(contiguous_axis(A))}(
131138
strd,
132139
zerotuple(Val(N)),
@@ -148,7 +155,7 @@ Base.unsafe_convert(::Type{Ptr{T}}, ptr::AbstractStridedPointer{T}) where {T} =
148155
end
149156

150157
@inline dynamic_offsets(si::StrideIndex{N,R,C}) where {N,R,C} =
151-
StrideIndex{N,R,C}(strides(si), _map(Int, offsets(si)))
158+
StrideIndex{N,R,C}(static_strides(si), _map(Int, offsets(si)))
152159
struct StridedBitPointer{N,C,B,R,X,O} <: AbstractStridedPointer{Bit,N,C,B,R,X,O}
153160
p::Ptr{Bit}
154161
si::StrideIndex{N,R,C,X,O}
@@ -162,18 +169,19 @@ end
162169
StridedBitPointer{N,C,0,R,X,NTuple{N,Int}}(p, dynamic_offsets(si))
163170

164171
@inline Base.pointer(p::Union{StridedPointer,StridedBitPointer}) = getfield(p, :p)
165-
@inline ArrayInterface.StrideIndex(sptr::Union{StridedPointer,StridedBitPointer}) =
172+
@inline StaticArrayInterface.StrideIndex(sptr::Union{StridedPointer,StridedBitPointer}) =
166173
getfield(sptr, :si)
167174
@inline bytestrideindex(sptr::AbstractStridedPointer) = StrideIndex(sptr)
168175

169176
@inline bytestrides(si::StrideIndex) =
170177
_map(Base.Fix2(*, StaticInt{8}()), getfield(si, :strides))
171178
@inline bytestrides(ptr::AbstractStridedPointer) = getfield(StrideIndex(ptr), :strides)
172179
@inline Base.strides(ptr::AbstractStridedPointer) = getfield(StrideIndex(ptr), :strides)
173-
@inline ArrayInterface.strides(ptr::AbstractStridedPointer) =
180+
@inline StaticArrayInterface.static_strides(ptr::AbstractStridedPointer) =
174181
getfield(StrideIndex(ptr), :strides)
175-
@inline ArrayInterface.offsets(ptr::AbstractStridedPointer) = offsets(StrideIndex(ptr))
176-
@inline ArrayInterface.contiguous_axis_indicator(
182+
@inline StaticArrayInterface.offsets(ptr::AbstractStridedPointer) =
183+
offsets(StrideIndex(ptr))
184+
@inline StaticArrayInterface.contiguous_axis_indicator(
177185
ptr::AbstractStridedPointer{T,N,C},
178186
) where {T,N,C} = contiguous_axis_indicator(StaticInt{C}(), Val{N}())
179187

@@ -183,14 +191,14 @@ end
183191
ptr::Ptr,
184192
offset::Tuple,
185193
) where {T,N,C,B,R}
186-
si = StrideIndex{N,R,C}(strides(sptr), offset)
194+
si = StrideIndex{N,R,C}(static_strides(sptr), offset)
187195
stridedpointer(ptr, si, contiguous_batch_size(sptr))
188196
end
189197
@inline function similar_no_offset(
190198
sptr::AbstractStridedPointer{T,N,C,B,R},
191199
ptr::Ptr,
192200
) where {T,N,C,B,R}
193-
si = StrideIndex{N,R,C}(strides(sptr), zerotuple(Val(N)))
201+
si = StrideIndex{N,R,C}(static_strides(sptr), zerotuple(Val(N)))
194202
stridedpointer(ptr, si, contiguous_batch_size(sptr))
195203
end
196204

@@ -203,7 +211,7 @@ end
203211
# s += A[i,i]
204212
# end
205213
# first access is at zero-based index
206-
# (first(6:16) - ArrayInterface.offsets(a)[1]) * ArrayInterface.strides(A)[1] + (first(6:16) - ArrayInterface.offsets(a)[2]) * ArrayInterface.strides(A)[2]
214+
# (first(6:16) - StaticArrayInterface.offsets(a)[1]) * StaticArrayInterface.static_strides(A)[1] + (first(6:16) - StaticArrayInterface.offsets(a)[2]) * StaticArrayInterface.static_strides(A)[2]
207215
# equal to
208216
# (6 - 6)*1 + (6 - 5)*10 = 10
209217
# i.e., the 1-based index 11.
@@ -238,21 +246,21 @@ FastRange{T}(f::F, s::S) where {T<:FloatingTypes,F,S} = FastRange{T,F,S,Int}(f,
238246
# FastRange{T}(f::F,s::S,::False) where {T<:FloatingTypes,F,S} = FastRange{T,F,S,Int32}(f,s,zero(Int32))
239247

240248
@inline function memory_reference(r::AbstractRange{T}) where {T}
241-
s = ArrayInterface.static_step(r)
242-
FastRange{T}(ArrayInterface.static_first(r) - s, s), nothing
249+
s = StaticArrayInterface.static_step(r)
250+
FastRange{T}(StaticArrayInterface.static_first(r) - s, s), nothing
243251
end
244252
@inline memory_reference(r::FastRange) = (r, nothing)
245253
@inline bytestrides(::FastRange{T}) where {T} = (StaticInt(sizeof(T)),)
246-
@inline ArrayInterface.offsets(::FastRange) = (One(),)
254+
@inline StaticArrayInterface.offsets(::FastRange) = (One(),)
247255
@inline val_stride_rank(::FastRange) = Val{(1,)}()
248256
@inline val_dense_dims(::FastRange) = Val{(true,)}()
249-
@inline ArrayInterface.contiguous_axis(::FastRange) = One()
250-
@inline ArrayInterface.contiguous_batch_size(::FastRange) = Zero()
257+
@inline StaticArrayInterface.contiguous_axis(::FastRange) = One()
258+
@inline StaticArrayInterface.contiguous_batch_size(::FastRange) = Zero()
251259

252260
@inline stridedpointer(fr::FastRange, ::StrideIndex, ::StaticInt{0}) = fr
253261
struct NoStrides end
254262
@inline bytestrideindex(::FastRange) = NoStrides()
255-
@inline ArrayInterface.offsets(::NoStrides) = NoStrides()
263+
@inline StaticArrayInterface.offsets(::NoStrides) = NoStrides()
256264
@inline reconstruct_ptr(r::FastRange{T}, o) where {T} =
257265
FastRange{T}(getfield(r, :f), getfield(r, :s), o)
258266

test/runtests.jl

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,31 @@
1-
using LayoutPointers, ArrayInterface, ArrayInterfaceOffsetArrays, Aqua, Test
1+
using LayoutPointers, StaticArrayInterface, Aqua, Test
2+
struct SizedWrapper{M,N,T,AT<:AbstractMatrix{T}} <: AbstractMatrix{T}
3+
A::AT
4+
end
5+
StaticArrayInterface.is_forwarding_wrapper(::Type{<:SizedWrapper}) = true
6+
SizedWrapper{M,N}(A::AT) where {M,N,T,AT<:AbstractMatrix{T}} = SizedWrapper{M,N,T,AT}(A)
7+
Base.size(::SizedWrapper{M,N}) where {M,N} = (M, N)
8+
Base.getindex(A::SizedWrapper, i...) = getindex(parent(A), i...)
9+
Base.parent(dw::SizedWrapper) = dw.A
10+
StaticArrayInterface.parent_type(::Type{SizedWrapper{M,N,T,AT}}) where {M,N,T,AT} = AT
11+
LayoutPointers.memory_reference(dw::SizedWrapper) =
12+
LayoutPointers.memory_reference(parent(dw))
13+
StaticArrayInterface.contiguous_axis(dw::SizedWrapper) =
14+
LayoutPointers.contiguous_axis(parent(dw))
15+
StaticArrayInterface.contiguous_batch_size(dw::SizedWrapper) =
16+
LayoutPointers.contiguous_batch_size(parent(dw))
17+
LayoutPointers.val_stride_rank(dw::SizedWrapper) =
18+
LayoutPointers.val_stride_rank(parent(dw))
19+
function StaticArrayInterface.static_strides(dw::SizedWrapper{M,N,T}) where {M,N,T}
20+
if LayoutPointers.val_stride_rank(dw) === Val((1, 2))
21+
return LayoutPointers.One(), LayoutPointers.StaticInt{M}()
22+
else#if LayoutPointers.val_stride_rank(dw) === Val((2,1))
23+
return LayoutPointers.StaticInt{N}(), LayoutPointers.One()
24+
end
25+
end
26+
StaticArrayInterface.offsets(dw::SizedWrapper) = LayoutPointers.offsets(parent(dw))
27+
LayoutPointers.val_dense_dims(dw::SizedWrapper{T,N}) where {T,N} =
28+
LayoutPointers.val_dense_dims(parent(dw))
229

330
@testset "LayoutPointers.jl" begin
431
Aqua.test_all(LayoutPointers)
@@ -9,33 +36,6 @@ using LayoutPointers, ArrayInterface, ArrayInterfaceOffsetArrays, Aqua, Test
936
A = Matrix{Float64}(undef, M, K)
1037
B = Matrix{Float64}(undef, K, N)
1138
C = Matrix{Float64}(undef, M, N)
12-
struct SizedWrapper{M,N,T,AT<:AbstractMatrix{T}} <: AbstractMatrix{T}
13-
A::AT
14-
end
15-
ArrayInterface.is_forwarding_wrapper(::Type{<:SizedWrapper}) = true
16-
SizedWrapper{M,N}(A::AT) where {M,N,T,AT<:AbstractMatrix{T}} = SizedWrapper{M,N,T,AT}(A)
17-
Base.size(::SizedWrapper{M,N}) where {M,N} = (M, N)
18-
Base.getindex(A::SizedWrapper, i...) = getindex(parent(A), i...)
19-
Base.parent(dw::SizedWrapper) = dw.A
20-
ArrayInterface.parent_type(::Type{SizedWrapper{M,N,T,AT}}) where {M,N,T,AT} = AT
21-
LayoutPointers.memory_reference(dw::SizedWrapper) =
22-
LayoutPointers.memory_reference(parent(dw))
23-
ArrayInterface.contiguous_axis(dw::SizedWrapper) =
24-
LayoutPointers.contiguous_axis(parent(dw))
25-
ArrayInterface.contiguous_batch_size(dw::SizedWrapper) =
26-
LayoutPointers.contiguous_batch_size(parent(dw))
27-
LayoutPointers.val_stride_rank(dw::SizedWrapper) =
28-
LayoutPointers.val_stride_rank(parent(dw))
29-
function ArrayInterface.strides(dw::SizedWrapper{M,N,T}) where {M,N,T}
30-
if LayoutPointers.val_stride_rank(dw) === Val((1, 2))
31-
return LayoutPointers.One(), LayoutPointers.StaticInt{M}()
32-
else#if LayoutPointers.val_stride_rank(dw) === Val((2,1))
33-
return LayoutPointers.StaticInt{N}(), LayoutPointers.One()
34-
end
35-
end
36-
ArrayInterface.offsets(dw::SizedWrapper) = LayoutPointers.offsets(parent(dw))
37-
LayoutPointers.val_dense_dims(dw::SizedWrapper{T,N}) where {T,N} =
38-
LayoutPointers.val_dense_dims(parent(dw))
3939

4040
GC.@preserve A B C begin
4141
fs = (false, true)#[identity, adjoint]

0 commit comments

Comments
 (0)