Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ComponentArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export fastindices # Deprecated
include("lazyarray.jl")

include("axis.jl")
export AbstractAxis, Axis, PartitionedAxis, ShapedAxis, ViewAxis, FlatAxis
export AbstractAxis, Axis, PartitionedAxis, ShapedAxis, Shaped1DAxis, ViewAxis, FlatAxis

include("componentarray.jl")
export ComponentArray, ComponentVector, ComponentMatrix, getaxes, getdata, valkeys
Expand Down
19 changes: 15 additions & 4 deletions src/axis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,20 @@ example)
"""
struct ShapedAxis{Shape} <: AbstractAxis{nothing} end
@inline ShapedAxis(Shape) = ShapedAxis{Shape}()
ShapedAxis(::Tuple{<:Int}) = FlatAxis()
# ShapedAxis(::Tuple{<:Int}) = FlatAxis()

struct Shaped1DAxis{Shape} <: AbstractAxis{nothing} end
ShapedAxis(shape::Tuple{<:Int}) = Shaped1DAxis{shape}()
Shaped1DAxis(shape::Tuple{<:Int}) = Shaped1DAxis{shape}()

const Shape = ShapedAxis

unshape(ax) = ax
unshape(ax::ShapedAxis) = Axis(indexmap(ax))
unshape(ax::Shaped1DAxis) = Axis(indexmap(ax))

Base.size(::ShapedAxis{Shape}) where {Shape} = Shape
Base.size(::Shaped1DAxis{Shape}) where {Shape} = Shape



Expand Down Expand Up @@ -133,9 +139,9 @@ Axis(::Number) = NullAxis()
Axis(::NamedTuple{()}) = FlatAxis()
Axis(x) = FlatAxis()

const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}} where {Shape, IdxMap}
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, Shaped1DAxis} where {IdxMap}
const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}, Shaped1DAxis} where {Shape, IdxMap}
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, Shaped1DAxis} where {IdxMap}


Base.merge(axs::Vararg{Axis}) = Axis(merge(indexmap.(axs)...))
Expand All @@ -149,6 +155,10 @@ reindex(i, offset) = i .+ offset
reindex(ax::FlatAxis, _) = ax
reindex(ax::Axis, offset) = Axis(map(x->reindex(x, offset), indexmap(ax)))
reindex(ax::ViewAxis, offset) = ViewAxis(viewindex(ax) .+ offset, indexmap(ax))
function reindex(ax::ViewAxis{OldInds,IdxMap,Ax}, offset) where {OldInds,IdxMap,Ax<:Shaped1DAxis}
NewInds = viewindex(ax) .+ offset
return ViewAxis(NewInds, Ax())
end

# Get AbstractAxis index
@inline Base.getindex(::AbstractAxis, idx) = ComponentIndex(idx)
Expand All @@ -175,6 +185,7 @@ end

_maybe_view_axis(inds, ax::Axis) = ViewAxis(inds, ax)
_maybe_view_axis(inds, ::NullAxis) = inds[1]
_maybe_view_axis(inds, ax::Union{ShapedAxis,Shaped1DAxis}) = ViewAxis(inds, ax)

struct CombinedAxis{C,A} <: AbstractUnitRange{Int}
component_axis::C
Expand Down
3 changes: 2 additions & 1 deletion src/compat/static_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ end

_maybe_SArray(x::SubArray, ::Val{N}, ::FlatAxis) where {N} = SVector{N}(x)
_maybe_SArray(x::Base.ReshapedArray, ::Val, ::ShapedAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
_maybe_SArray(x, ::Val, ::Shaped1DAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
_maybe_SArray(x, vals...) = x

@generated function static_getproperty(ca::ComponentVector, ::Val{s}) where {s}
Expand Down Expand Up @@ -32,4 +33,4 @@ macro static_unpack(expr)
push!(out.args, :($esc_name = static_getproperty($parent_var_name, $(Val(name)))))
end
return out
end
end
9 changes: 7 additions & 2 deletions src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ ComponentArray{T}(::UndefInitializer, ax::Axes) where {T,Axes<:Tuple} =

# Entry from data array and AbstractAxis types dispatches to correct shapes and partitions
# then packs up axes into a tuple for inner constructor
ComponentArray(data, ::FlatAxis...) = data
# ComponentArray(data, ::FlatAxis...) = data
ComponentArray(data, ::Union{FlatAxis,Shaped1DAxis}...) = data
ComponentArray(data, ax::NotShapedOrPartitionedAxis...) = ComponentArray(data, ax)
ComponentArray(data, ax::NotPartitionedAxis...) = ComponentArray(maybe_reshape(data, ax...), unshape.(ax)...)
function ComponentArray(data, ax::AbstractAxis...)
Expand Down Expand Up @@ -166,6 +167,10 @@ function make_idx(data, nt::Union{NamedTuple, AbstractDict}, last_val)
)...)
return (data, ViewAxis(last_index(last_val) .+ (1:len), kvs))
end
function make_idx(data, nt::NamedTuple{(), Tuple{}}, last_val)
out = last_index(last_val) .+ (1:length(nt))
return (data, ViewAxis(out, ShapedAxis((length(nt),))))
end
function make_idx(data, pair::Pair, last_val)
data, ax = make_idx(data, pair.second, last_val)
len = recursive_length(data)
Expand Down Expand Up @@ -232,7 +237,7 @@ end
# Reshape ComponentArrays with ShapedAxis axes
maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data
function maybe_reshape(data, axs::AbstractAxis...)
shapes = filter_by_type(ShapedAxis, axs...) .|> size
shapes = filter_by_type(Union{ShapedAxis,Shaped1DAxis}, axs...) .|> size
shapes = reduce((tup, s) -> (tup..., s...), shapes)
return reshape(data, shapes)
end
Expand Down
4 changes: 3 additions & 1 deletion src/componentindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ struct ComponentIndex{Idx, Ax<:AbstractAxis}
ax::Ax
end
ComponentIndex(idx) = ComponentIndex(idx, FlatAxis())
ComponentIndex(idx::CartesianIndex) = ComponentIndex(idx, ShapedAxis((1,)))
ComponentIndex(idx::AbstractArray{<:Integer}) = ComponentIndex(idx, ShapedAxis(size(idx)))
ComponentIndex(idx::Int) = ComponentIndex(idx, NullAxis())
ComponentIndex(vax::ViewAxis{Inds,IdxMap,Ax}) where {Inds,IdxMap,Ax} = ComponentIndex(Inds, vax.ax)

Expand Down Expand Up @@ -44,4 +46,4 @@ function _getindex_keep(ax::AbstractAxis, sym::Symbol)
end
new_ax = reindex(new_ax, -first(idx)+1)
return ComponentIndex(idx, new_ax)
end
end
2 changes: 2 additions & 0 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Base.show(io::IO, ::PartitionedAxis{PartSz, IdxMap, Ax}) where {PartSz, IdxMap,

Base.show(io::IO, ::ShapedAxis{Shape}) where {Shape} =
print(io, "ShapedAxis($Shape)")
Base.show(io::IO, ::Shaped1DAxis{Shape}) where {Shape} =
print(io, "Shaped1DAxis($Shape)")

Base.show(io::IO, ::MIME"text/plain", ::ViewAxis{Inds, IdxMap, Ax}) where {Inds, IdxMap, Ax} =
print(io, "ViewAxis($Inds, $(Ax()))")
Expand Down
39 changes: 30 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@ using OffsetArrays
using Test
using Unitful

# Convert abstract unit range to a ViewAxis with ShapeAxis.
r2v(r::AbstractUnitRange) = ViewAxis(r, ShapedAxis(size(r)))

## Test setup
c = (a = (a = 1, b = [1.0, 4.4]), b = [0.4, 2, 1, 45])
nt = (a = 100, b = [4, 1.3], c = c)
nt2 = (a = 5, b = [(a = (a = 20, b = 1), b = 0), (a = (a = 33, b = 1), b = 0)], c = (a = (a = 2, b = [1, 2]), b = [1.0 2.0; 5 6]))

ax = Axis(a = 1, b = 2:3, c = ViewAxis(4:10, (a = ViewAxis(1:3, (a = 1, b = 2:3)), b = 4:7)))
ax_c = (a = ViewAxis(1:3, (a = 1, b = 2:3)), b = 4:7)
ax = Axis(a = 1, b = r2v(2:3), c = ViewAxis(4:10, (a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7))))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. I guess I don't understand why we'd need to wrap this in a new type. In the example you gave in the issue, b is a vector element, not an nx1 matrix. I don't think we should introduce a new ShapedAxis1d type if all vector elements are behaving incorrectly. We should just fix the issue of adjoint/transposition that's broken directly. So I guess specifically, the ShapedAxis needs to be created from the FlatAxis during the adjoint operation, not beforehand. And for that, we can use the normal ShapedAxis without having to introduce a new type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonniedie My bad, just saw this comment.
Hmm... not sure I follow, but fwiw I think everything is fine with the transpose—the issue comes up later on.
The problem with the test case in issue #249, aka

ni = 4
nj = 2
nk = 3

X = ComponentVector(a=0.0, b=zeros(Float64, nk), c=zeros(Float64, ni, nj))
Y = ComponentVector(d=zeros(Float64, ni, nj))
J = Y.*X'

appears to be due to the fact that the shape of a unit range (which is used for vector components like X[:b] in the example) isn't understood by ComponentArrays.maybe_reshape. The error I get:

julia> J[:d, :b]
ERROR: DimensionMismatch: new dimensions (4, 2) must be consistent with array size 24
Stacktrace:
 [1] (::Base.var"#throw_dmrsa#328")(dims::Tuple{Int64, Int64}, len::Int64)
   @ Base ./reshapedarray.jl:41
 [2] reshape
   @ ./reshapedarray.jl:45 [inlined]
 [3] maybe_reshape
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/componentarray.jl:237 [inlined]
 [4] ComponentArray
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/componentarray.jl:52 [inlined]
 [5] macro expansion
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:0 [inlined]
 [6] _getindex(::typeof(getindex), ::ComponentMatrix{Float64, Matrix{Float64}, Tuple{Axis{…}, Axis{…}}}, ::Val{:d}, ::Val{:b})
   @ ComponentArrays ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:119
 [7] getindex
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:103 [inlined]
 [8] getindex(::ComponentMatrix{Float64, Matrix{Float64}, Tuple{Axis{…}, Axis{…}}}, ::Symbol, ::Symbol)
   @ ComponentArrays ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:102
 [9] top-level scope
   @ REPL[16]:1
Some type information was truncated. Use `show(err)` to see complete types.

shell> 

In maybe_reshape:

# Reshape ComponentArrays with ShapedAxis axes
maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data
function maybe_reshape(data, axs::AbstractAxis...)
    shapes = filter_by_type(ShapedAxis, axs...) .|> size
    shapes = reduce((tup, s) -> (tup..., s...), shapes)
    return reshape(data, shapes)
end

during J[:d, :b], axs is

axs = (ShapedAxis((4, 2)), FlatAxis())

i.e. the unit range is converted to a FlatAxis, which is then ignored when defining shapes in maybe_reshape.

The axes of the transpose X' look fine to me:

julia> getaxes(X')
(FlatAxis(), Axis(a = 1, b = 2:4, c = ViewAxis(5:12, ShapedAxis((4, 2)))))

julia> 

The point of the Shaped1DAxis is to have a type that can be a part of the NotShapedOrPartitionedAxis Union defined in axis.jl.
This allows it to skip the problematic and unnecessary maybe_reshape for vector components.
If there was a way to dispatch on ShapedAxis{Shape} when Shape was a length-1 Tuple then I don't think we'd need Shaped1DAxis, but I don't know of a way to do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonniedie Another polite ping re: this PR. :-)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonniedie Another polite ping re: this PR.

ax_c = (a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7))

a = Float64[100, 4, 1.3, 1, 1, 4.4, 0.4, 2, 1, 45]
sq_mat = collect(reshape(1:9, 3, 3))
Expand All @@ -36,6 +38,9 @@ caa = ComponentArray(a = ca, b = sq_mat)

_a, _b, _c = Val.((:a, :b, :c))

ca3 = ComponentArray(a=1, b=[2, 3, 4, 5], c=reshape(6:11, 3, 2))
cmat3 = ca3 .* ca3'
cmat3check = (1:11) .* (1:11)'

## Tests
@testset "Allocations and Inference" begin
Expand Down Expand Up @@ -285,6 +290,17 @@ end
# OffsetArrays' type piracy without introducing type piracy
# ourselves because `() isa Tuple{N, <:CombinedAxis} where {N}`
# @test reshape(a, axes(ca)...) isa Vector{Float64}

# Issue #248: Indexing ComponentMatrix with FlatAxis components
@test cmat3[:a, :a] == cmat3check[1, 1]
@test cmat3[:a, :b] == cmat3check[1, 2:5]
@test cmat3[:a, :c] == reshape(cmat3check[1, 6:11], 3, 2)
@test cmat3[:b, :a] == cmat3check[2:5, 1]
@test cmat3[:b, :b] == cmat3check[2:5, 2:5]
@test cmat3[:b, :c] == reshape(cmat3check[2:5, 6:11], 4, 3, 2)
@test cmat3[:c, :a] == reshape(cmat3check[6:11, 1], 3, 2)
@test cmat3[:c, :b] == reshape(cmat3check[6:11, 2:5], 3, 2, 4)
@test cmat3[:c, :c] == reshape(cmat3check[6:11, 6:11], 3, 2, 3, 2)
end

@testset "Set" begin
Expand Down Expand Up @@ -327,7 +343,7 @@ end
temp = deepcopy(cmat)
@test all((temp[:c, :c][:a, :a] .= 0) .== 0)

A = ComponentArray(zeros(Int, 4, 4), Axis(x = 1:4), Axis(x = 1:4))
A = ComponentArray(zeros(Int, 4, 4), Axis(x = r2v(1:4)), Axis(x = r2v(1:4)))
A[1, :] .= 1
@test A[1, :] == ComponentVector(x = ones(Int, 4))
end
Expand All @@ -337,12 +353,17 @@ end
ca = ComponentArray(a = 1, b = 2, c = [3, 4], d = (a = [5, 6, 7], b = 8))
cmat = ca * ca'

cidx = reshape((1:(2*3)) .+ 2, 2, 3)
ca2 = ComponentArray(a = 1, b = 2, c = cidx, d = (a = [9, 10, 11], b = 12))

@testset "ComponentIndex" begin
ax = getaxes(ca)[1]
@test ax[:a] == ax[1] == ComponentArrays.ComponentIndex(1, ComponentArrays.NullAxis())
@test ax[:c] == ax[3:4] == ComponentArrays.ComponentIndex(3:4, FlatAxis())
@test ax[:d] == ComponentArrays.ComponentIndex(5:8, Axis(a = 1:3, b = 4))
@test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = 2:3))
@test ax[:c] == ax[3:4] == ComponentArrays.ComponentIndex(3:4, ShapedAxis(size(3:4)))
@test ax[:d] == ComponentArrays.ComponentIndex(5:8, Axis(a = r2v(1:3), b = 4))
@test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3)))
ax2 = getaxes(ca2)[1]
@test ax2[(:a, :c)] == ax2[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2,3)))))
end

@testset "KeepIndex" begin
Expand All @@ -353,14 +374,14 @@ end

@test ca[KeepIndex(1:2)] == ComponentArray(a = 1, b = 2)
@test ca[KeepIndex(1:3)] == ComponentArray([1, 2, 3], Axis(a = 1, b = 2)) # Drops c axis
@test ca[KeepIndex(2:5)] == ComponentArray([2, 3, 4, 5], Axis(b = 1, c = 2:3))
@test ca[KeepIndex(2:5)] == ComponentArray([2, 3, 4, 5], Axis(b = 1, c = r2v(2:3)))
@test ca[KeepIndex(3:end)] == ComponentArray(c = [3, 4], d = (a = [5, 6, 7], b = 8))

@test ca[KeepIndex(:)] == ca

@test cmat[KeepIndex(:a), KeepIndex(:b)] == ComponentArray(fill(2, 1, 1), Axis(a = 1), Axis(b = 1))
@test cmat[KeepIndex(:), KeepIndex(:c)] == ComponentArray((1:8) * (3:4)', getaxes(ca)[1], Axis(c = 1:2))
@test cmat[KeepIndex(2:5), 1:2] == ComponentArray((2:5) * (1:2)', Axis(b = 1, c = 2:3), FlatAxis())
@test cmat[KeepIndex(:), KeepIndex(:c)] == ComponentArray((1:8) * (3:4)', getaxes(ca)[1], Axis(c = r2v(1:2)))
@test cmat[KeepIndex(2:5), 1:2] == ComponentArray((2:5) * (1:2)', Axis(b = 1, c = r2v(2:3)), ShapedAxis(size(1:2)))
@test cmat[KeepIndex(2), KeepIndex(3)] == ComponentArray(fill(2 * 3, 1, 1), Axis(b = 1), FlatAxis())
@test cmat[KeepIndex(2), 3] == ComponentArray(b = 2 * 3)
end
Expand Down