Skip to content

Commit c12a83f

Browse files
Improve diag_At_A and diag_At_B efficiency (#288)
* Improve diag_At_A and diag_At_B for vectors * Use dot * Use sum for diag_At_A * Add allocated tests * Fix function signature * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update tests to work with julia 1.3 * Patch release Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent db0269b commit c12a83f

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AbstractGPs"
22
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
33
authors = ["JuliaGaussianProcesses Team"]
4-
version = "0.5.7"
4+
version = "0.5.8"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/util/common_covmat_ops.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,20 @@ Xt_invA_Y(X::AbstractVecOrMat, A::Cholesky, Y::AbstractVecOrMat) = (A.U' \ X)' *
6060

6161
At_A(A::AbstractVecOrMat) = A'A
6262

63-
diag_At_A(A::AbstractVecOrMat) = vec(sum(abs2.(A); dims=1))
63+
diag_At_A(A::AbstractMatrix) = vec(sum(abs2, A; dims=1))
64+
diag_At_A(A::AbstractVector) = [sum(abs2, A)]
6465

6566
tr_At_A(A::AbstractVecOrMat) = sum(abs2, A)
6667

67-
function diag_At_B(A::AbstractVecOrMat, B::AbstractVecOrMat)
68+
function diag_At_B(A::AbstractMatrix, B::AbstractMatrix)
6869
size(A) == size(B) || throw(
6970
DimensionMismatch(
7071
"A ($(size(A))) and B ($(size(B))) do not have the same dimensions "
7172
),
7273
)
7374
return vec(sum(A .* B; dims=1))
7475
end
76+
diag_At_B(x::AbstractVector, y::AbstractVector) = [dot(x, y)]
7577

7678
diag_Xt_A_X(A::Cholesky, X::AbstractVecOrMat) = diag_At_A(A.U * X)
7779

test/util/common_covmat_ops.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@
7676
@test Xt_invA_Y(X, A, Y) X' * (A \ Y)
7777

7878
@test diag_At_A(x) [x'x]
79+
@test (@allocated diag_At_A(x)) <= 96
7980
@test diag_At_A(X) diag(X'X)
8081

8182
@test diag_At_B(x, z) [x'z]
83+
@test (@allocated diag_At_B(x, z)) <= 96
8284
@test diag_At_B(X, Z) diag(X'Z)
8385

8486
@test diag_Xt_A_X(A, x) [Xt_A_X(A, x)]

0 commit comments

Comments
 (0)