An UpperTriangular(::AbstractDimMatrix) has DimUnitRange axes. However, when it is multiplied by a DimVector, the inner dimensions are not checked for compatibility and the outer dimension is discarded, e.g.
julia> using DimensionalData, LinearAlgebra, Test
julia> A = UpperTriangular(DimMatrix(randn(5, 5), (X, Y)));
julia> B = DimVector(randn(5), Y);
julia> axes(A)
(X(Base.OneTo(5)), Y(Base.OneTo(5)))
julia> A * B # should this have dim X?
┌ 5-element DimArray{Float64, 1} ┐
├────────────────────────── dims ┤
↓ AnonDim Base.OneTo(1)
└────────────────────────────────┘
1 -0.192035
#undef 3.22561
#undef 0.765636
#undef 2.12322
#undef 0.239097
julia> A' * B # should this error?
┌ 5-element DimArray{Float64, 1} ┐
├────────────────────────── dims ┤
↓ AnonDim Base.OneTo(1)
└────────────────────────────────┘
1 0.00633928
#undef 0.147744
#undef 0.0511726
#undef 1.29501
#undef 1.95363
I wonder if some special-casing can be performed on the following lines to check if the AbstractMatrix has DimUnitRange axes, and if so, to use those:
DimensionalData.jl/src/array/matmul.jl
Lines 114 to 124 in 6db30de
function _rebuildmul(A::AbstractMatrix, B::AbstractDimVector)
    newdata = A * parent(B)
    if newdata isa AbstractArray
        rebuild(B, A * parent(B), (AnonDim(Base.OneTo(1)),))
    else
        newdata
    end
end
function _rebuildmul(A::AbstractMatrix, B::AbstractDimMatrix)
    rebuild(B, A * parent(B), (AnonDim(Base.OneTo(size(A, 1))), last(dims(B))))
end
Also, for some reason the vector method performs the multiplication twice (once for newdata and again inside rebuild), and its AnonDim is given a lookup of length 1 regardless of the length of the result. I wonder if the following is more sensible:
function DimensionalData._rebuildmul(A::AbstractMatrix, B::AbstractDimVector)
    ax1, ax2 = axes(A)
    if ax2 isa Dimensions.DimUnitRange
        isstrict = DimensionalData.strict_matmul()
        Dimensions.comparedims(Dimensions.dims(ax2), first(Dimensions.dims(B));
            order=isstrict, val=isstrict, length=false
        )
    end
    newdata = A * parent(B)
    if newdata isa AbstractArray
        out_dim = if ax1 isa Dimensions.DimUnitRange
            Dimensions.dims(ax1)
        else
            Dimensions.AnonDim(ax1)
        end
        DimensionalData.rebuild(B, newdata, (out_dim,))
    else
        newdata
    end
end
With this method, we get
julia> @inferred A * B
┌ 5-element DimArray{Float64, 1} ┐
├────────────────────────── dims ┤
↓ X
└────────────────────────────────┘
-0.192035
3.22561
0.765636
2.12322
0.239097
julia> A' * B
ERROR: DimensionMismatch: X and Y dims on the same axis.
...
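The same idea could presumably be applied to the AbstractDimMatrix method quoted from matmul.jl. The following is an untested sketch, not a proposed patch: it mirrors the vector method, reuses only calls that already appear in this issue (axes, Dimensions.DimUnitRange, Dimensions.comparedims, DimensionalData.strict_matmul, rebuild), and assumes that UpperTriangular * AbstractDimMatrix actually dispatches to _rebuildmul, which I have not verified. The dimension names X, Y and Z in the usage lines are only for illustration.

```julia
using DimensionalData, LinearAlgebra

# Sketch only: matrix-matrix counterpart of the vector method proposed above.
# It overrides an internal DimensionalData function, so treat it as an experiment.
function DimensionalData._rebuildmul(A::AbstractMatrix, B::AbstractDimMatrix)
    ax1, ax2 = axes(A)
    # Validate the inner dimension only when A actually carries dimension information.
    if ax2 isa Dimensions.DimUnitRange
        isstrict = DimensionalData.strict_matmul()
        Dimensions.comparedims(Dimensions.dims(ax2), first(Dimensions.dims(B));
            order=isstrict, val=isstrict, length=false
        )
    end
    # Keep A's outer dimension if it has one; otherwise fall back to an AnonDim
    # built from the axis itself, so the lookup length matches the axis length.
    out_dim = ax1 isa Dimensions.DimUnitRange ? Dimensions.dims(ax1) : Dimensions.AnonDim(ax1)
    # Multiply once and reuse the result.
    newdata = A * parent(B)
    DimensionalData.rebuild(B, newdata, (out_dim, last(Dimensions.dims(B))))
end

# Illustrative usage (assumption: this call reaches the method defined above).
A = UpperTriangular(DimArray(randn(5, 5), (X, Y)))
C = DimArray(randn(5, 3), (Y, Z))
R = A * C
```

If the call does reach this method, R should carry X as its first dimension and Z as its second, and A' * C should raise the same DimensionMismatch as in the vector case.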