Skip to content

*(::AbstractTriangular, ::DimArray) discards dimensions of triangular matrix #1122

@sethaxen

Description

@sethaxen

An UpperTriangular(::AbstractDimMatrix) has DimUnitRange axes. However, multiplying it by a DimVector neither validates that the inner dimensions are compatible nor keeps the outer dimension, which is discarded, e.g.

julia> using DimensionalData, LinearAlgebra, Test

julia> A = UpperTriangular(DimMatrix(randn(5, 5), (X, Y)));

julia> B = DimVector(randn(5), Y);

julia> axes(A)
(X(Base.OneTo(5)), Y(Base.OneTo(5)))

julia> A * B  # should this have dim X?
┌ 5-element DimArray{Float64, 1} ┐
├────────────────────────── dims ┤
   AnonDim Base.OneTo(1)
└────────────────────────────────┘
   1     -0.192035
 #undef   3.22561
 #undef   0.765636
 #undef   2.12322
 #undef   0.239097

julia> A' * B  # should this error?
┌ 5-element DimArray{Float64, 1} ┐
├────────────────────────── dims ┤
   AnonDim Base.OneTo(1)
└────────────────────────────────┘
   1     0.00633928
 #undef  0.147744
 #undef  0.0511726
 #undef  1.29501
 #undef  1.95363

I wonder whether the following lines could be special-cased to check whether the AbstractMatrix has DimUnitRange axes and, if so, to use them:

# Current upstream method for `AbstractMatrix * AbstractDimVector`, quoted for reference.
# NOTE(review): `A * parent(B)` is evaluated twice (once into `newdata`, then again
# as the data argument of `rebuild`), the inner dimension of `A` is never compared
# against `B`'s dimension, and the output dimension is always an `AnonDim` whose
# lookup is `Base.OneTo(1)` regardless of the product's length or of any
# `DimUnitRange` axes carried by `A`.
function _rebuildmul(A::AbstractMatrix, B::AbstractDimVector)
newdata = A * parent(B)
# Only array-valued products are rewrapped; any other result is returned as-is.
if newdata isa AbstractArray
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(1)),))
else
newdata
end
end
# Current upstream method for `AbstractMatrix * AbstractDimMatrix`, quoted for reference.
# The first output dimension is an `AnonDim` sized by `size(A, 1)`; the second is
# `B`'s last dimension.
# NOTE(review): `A`'s axes are never consulted, so any `DimUnitRange` dimensions on
# `A` are discarded, and the inner dimension is not checked against `B`.
function _rebuildmul(A::AbstractMatrix, B::AbstractDimMatrix)
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(size(A, 1))), last(dims(B))))
end

Also, the multiplication A * parent(B) is needlessly performed twice here. In addition, the AnonDim is given a lookup (Base.OneTo(1)) whose length differs from the length of the resulting dimension. I wonder if the following is more sensible:

# Inner-dimension check. It is only possible when `A`'s column axis is a
# `DimUnitRange`. The `strict_matmul()` setting is forwarded as the `order` and `val`
# flags of `comparedims`. `length=false` presumably because `*` already enforces
# matching lengths.
function _matvec_checkinner(colax::Dimensions.DimUnitRange, B)
    strict = DimensionalData.strict_matmul()
    Dimensions.comparedims(Dimensions.dims(colax), first(Dimensions.dims(B));
        order=strict, val=strict, length=false
    )
    return nothing
end
_matvec_checkinner(colax, B) = nothing

# Output dimension: reuse `A`'s row dimension when it has one, otherwise an
# anonymous dimension over the plain row axis.
_matvec_outdim(rowax::Dimensions.DimUnitRange) = Dimensions.dims(rowax)
_matvec_outdim(rowax) = Dimensions.AnonDim(rowax)

# Rewrap array-valued products as dimensional vectors; pass anything else through.
_matvec_wrap(B, product::AbstractArray, rowax) =
    DimensionalData.rebuild(B, product, (_matvec_outdim(rowax),))
_matvec_wrap(B, product, rowax) = product

# Proposed `AbstractMatrix * AbstractDimVector` product.
# Validates `A`'s column dimension against `B`'s dimension when `A` carries one.
# `comparedims` throws on a mismatch, e.g. the DimensionMismatch for `A' * B` shown
# below. It then multiplies once and labels the result with `A`'s row dimension
# (or an `AnonDim`).
function DimensionalData._rebuildmul(A::AbstractMatrix, B::AbstractDimVector)
    rowax, colax = axes(A)
    _matvec_checkinner(colax, B)
    return _matvec_wrap(B, A * parent(B), rowax)
end

With this method, we get

julia> @inferred A * B
┌ 5-element DimArray{Float64, 1} ┐
├────────────────────────── dims ┤
   X
└────────────────────────────────┘
 -0.192035
  3.22561
  0.765636
  2.12322
  0.239097

julia> A' * B
ERROR: DimensionMismatch: X and Y dims on the same axis.
...

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions