Skip to content

Commit e34e5ab

Browse files
authored
allow selecting dimensions with predicates (#618)
1 parent f4d51b4 commit e34e5ab

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

src/Dimensions/primitives.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,10 @@ julia> dimnum(A, Y)
206206
all(hasdim(x, q1, query...)) || _extradimserror()
207207
_call_primitive(_dimnum, MaybeFirst(), x, q1, query...)
208208
end
209+
@inline dimnum(x, query::Function) =
210+
_call_primitive(_dimnum, MaybeFirst(), x, query)
209211

210-
@inline function _dimnum(f::Function, ds::Tuple, query::Tuple{Vararg{Int}})
211-
query
212-
end
212+
@inline _dimnum(f::Function, ds::Tuple, query::Tuple{Vararg{Int}}) = query
213213
@inline function _dimnum(f::Function, ds::Tuple, query::Tuple)
214214
numbered = map(ds, ntuple(identity, length(ds))) do d, i
215215
rebuild(d, i)
@@ -659,6 +659,12 @@ struct AlwaysTuple end
659659
@inline _call_primitive1(f, t, op::Function, x, query) = _call_primitive1(f, t, op, dims(x), query)
660660
@inline _call_primitive1(f, t, op::Function, x::Nothing) = _dimsnotdefinederror()
661661
@inline _call_primitive1(f, t, op::Function, x::Nothing, query) = _dimsnotdefinederror()
662+
@inline function _call_primitive1(f, t, op::Function, ds::Tuple, query::Function)
663+
selection = foldl(ds; init=()) do acc, d
664+
query(d) ? (acc..., d) : acc
665+
end
666+
_call_primitive1(f, t, op, ds, selection)
667+
end
662668
@inline function _call_primitive1(f, t, op::Function, d::Tuple, query)
663669
ds = dims(query)
664670
isnothing(ds) && _dims_are_not_dims()

test/primitives.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ end
139139
@test (@ballocated $f1($dimz)) == 0
140140

141141
@test dims(da, X()) isa X
142+
@test dims(da, isforward) isa Tuple{<:X,<:Y}
143+
@test dims(da, !isforward) isa Tuple{}
142144
@test dims(da, Z()) isa Nothing
143145
@test (@inferred dims(da, XDim, YDim)) isa Tuple{<:X,<:Y}
144146
@test (@ballocated dims($da, XDim, YDim)) == 0
@@ -171,6 +173,7 @@ end
171173

172174
@testset "commondims" begin
173175
@test commondims(da, X) == (dims(da, X),)
176+
@test commondims(da, x -> x isa X) == (dims(da, X),)
174177
# Dims are always in the base order
175178
@test (@inferred commondims(da, (Y(), X()))) == dims(da, (X, Y))
176179
@test (@ballocated commondims($da, (Y(), X()))) == 0
@@ -205,6 +208,7 @@ end
205208
@testset "dimnum" begin
206209
dims(da)
207210
@test dimnum(da, Y()) == dimnum(da, 2) == 2
211+
@test dimnum(da, Base.Fix2(isa,Y)) == (2,)
208212
@test (@ballocated dimnum($da, Y())) == 0
209213
@test dimnum(da, X) == 1
210214
@test (@ballocated dimnum($da, X)) == 0
@@ -228,6 +232,7 @@ end
228232

229233
@testset "hasdim" begin
230234
@test hasdim(da, X()) == true
235+
@test hasdim(da, isforward) == (true, true)
231236
@test (@ballocated hasdim($da, X())) == 0
232237
@test hasdim(da, Ti) == false
233238
@test (@ballocated hasdim($da, Ti)) == 0
@@ -264,6 +269,7 @@ end
264269
@testset "otherdims" begin
265270
A = DimArray(ones(5, 10, 15), (X, Y, Z));
266271
@test otherdims(A, X()) == dims(A, (Y, Z))
272+
@test otherdims(A, x -> x isa X) == dims(A, (Y, Z))
267273
@test (@ballocated otherdims($A, X())) == 0
268274
@test otherdims(A, Y) == dims(A, (X, Z))
269275
@test otherdims(A, Z) == dims(A, (X, Y))

0 commit comments

Comments
 (0)