Skip to content

Commit cd65401

Browse files
committed
Fix and improve tests
1 parent b65a516 commit cd65401

File tree

9 files changed

+78
-101
lines changed

9 files changed

+78
-101
lines changed

test/dual.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Random.seed!(100)
2424

2525
# create random cost matrix
2626
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
27-
ε = 0.01
27+
ε = 0.1
2828
K = exp.(-C / ε)
2929

3030
@testset "semidual_grad" begin

test/entropic/sinkhorn_barycenter.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Random.seed!(100)
2121
C = pairwise(SqEuclidean(), support'; dims=2)
2222

2323
# regularisation parameter
24-
ε = 0.05
24+
ε = 0.1
2525

2626
# weights
2727
w = ones(N) / N
@@ -30,9 +30,8 @@ Random.seed!(100)
3030
α = sinkhorn_barycenter(μ, C, ε, w, SinkhornGibbs())
3131

3232
# compare with POT
33-
# need to use a larger tolerance here because of a quirk with the POT solver
34-
α_pot = POT.barycenter(μ, C, ε; weights=w, stopThr=1e-9)
35-
@test α α_pot rtol = 1e-6
33+
α_pot = POT.barycenter(μ, C, ε; weights=w, stopThr=1e-16)
34+
@test α α_pot
3635
end
3736

3837
# different element type
@@ -43,7 +42,7 @@ Random.seed!(100)
4342
w32 = map(Float32, w)
4443
α = sinkhorn_barycenter(μ32, C32, ε32, w32, SinkhornGibbs())
4544

46-
α_pot = POT.barycenter(μ32, C32, ε32; weights=w32, stopThr=1e-9)
47-
@test α α_pot rtol = 1e-5
45+
α_pot = POT.barycenter(μ32, C32, ε32; weights=w32, stopThr=1e-16)
46+
@test α α_pot
4847
end
4948
end

test/entropic/sinkhorn_divergence.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Random.seed!(100)
2222
C = pairwise(SqEuclidean(), x)
2323
f(x; μ, σ) = exp(-((x - μ) / σ)^2)
2424
# regularization parameter
25-
ε = 0.05
25+
ε = 0.1
2626
@testset "basic" begin
2727
μ = normalize!(f.(x; μ=0, σ=0.5), 1)
2828
M = 100
@@ -41,7 +41,7 @@ Random.seed!(100)
4141
ν_all,
4242
)
4343

44-
@test loss loss_ rtol = 1e-6
44+
@test loss loss_
4545
@test all(loss .≥ 0)
4646
@test sinkhorn_divergence(μ, μ, C, ε) 0 atol = 1e-9
4747
end
@@ -69,7 +69,7 @@ Random.seed!(100)
6969
end
7070
end
7171
@testset "AD" begin
72-
ε = 0.05
72+
ε = 0.1
7373
μ = normalize!(f.(x; μ=-0.5, σ=0.5), 1)
7474
ν = normalize!(f.(x; μ=0.5, σ=0.5), 1)
7575
for Diff in [ForwardDiff, ReverseDiff]

test/entropic/sinkhorn_epsscaling.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,30 +26,30 @@ Random.seed!(100)
2626
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
2727

2828
# regularization parameter
29-
ε = 0.01
29+
ε = 0.1
3030

3131
@testset "example" begin
3232
# compute optimal transport plan and cost with POT
33-
γ_pot = POT.sinkhorn(μ, ν, C, ε; method="sinkhorn_stabilized", numItermax=5_000)
34-
c_pot = POT.sinkhorn2(μ, ν, C, ε; method="sinkhorn_stabilized", numItermax=5_000)[1]
33+
γ_pot = POT.sinkhorn(μ, ν, C, ε; method="sinkhorn_stabilized", stopThr=1e-16)
34+
c_pot = POT.sinkhorn2(μ, ν, C, ε; method="sinkhorn_stabilized", stopThr=1e-16)[1]
3535

3636
for alg in (SinkhornGibbs(), SinkhornStabilized())
3737
# compute optimal transport plan and cost
38-
γ = sinkhorn(μ, ν, C, ε, SinkhornEpsilonScaling(alg); maxiter=5_000)
39-
c = sinkhorn2(μ, ν, C, ε, SinkhornEpsilonScaling(alg); maxiter=5_000)
38+
γ = sinkhorn(μ, ν, C, ε, SinkhornEpsilonScaling(alg))
39+
c = sinkhorn2(μ, ν, C, ε, SinkhornEpsilonScaling(alg))
4040

4141
# check that plan and cost are consistent
4242
@test c dot(γ, C)
4343

4444
# compare with Sinkhorn algorithm without ε-scaling
45-
γ_wo_epsscaling = sinkhorn(μ, ν, C, ε, alg; maxiter=5_000)
46-
c_wo_epsscaling = sinkhorn2(μ, ν, C, ε, alg; maxiter=5_000)
47-
@test γ γ_wo_epsscaling rtol = 1e-6
45+
γ_wo_epsscaling = sinkhorn(μ, ν, C, ε, alg)
46+
c_wo_epsscaling = sinkhorn2(μ, ν, C, ε, alg)
47+
@test γ γ_wo_epsscaling
4848
@test c c_wo_epsscaling
4949

5050
# compare with POT
51-
@test γ γ_pot rtol = 1e-6
52-
@test c c_pot rtol = 1e-6
51+
@test γ γ_pot
52+
@test c c_pot
5353
end
5454
end
5555

test/entropic/sinkhorn_gibbs.jl

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,35 +27,32 @@ Random.seed!(100)
2727
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
2828

2929
# regularization parameter
30-
ε = 0.01
30+
ε = 0.1
3131

3232
@testset "example" begin
3333
# compute optimal transport plan and optimal transport cost
34-
γ = sinkhorn(μ, ν, C, ε, SinkhornGibbs(); maxiter=5_000, rtol=1e-9)
35-
c = sinkhorn2(μ, ν, C, ε, SinkhornGibbs(); maxiter=5_000, rtol=1e-9)
34+
γ = sinkhorn(μ, ν, C, ε, SinkhornGibbs())
35+
c = sinkhorn2(μ, ν, C, ε, SinkhornGibbs())
3636

3737
# check that plan and cost are consistent
3838
@test c dot(γ, C)
3939

4040
# compare with default algorithm
41-
γ_default = sinkhorn(μ, ν, C, ε; maxiter=5_000, rtol=1e-9)
42-
c_default = sinkhorn2(μ, ν, C, ε; maxiter=5_000, rtol=1e-9)
41+
γ_default = sinkhorn(μ, ν, C, ε)
42+
c_default = sinkhorn2(μ, ν, C, ε)
4343
@test γ_default == γ
4444
@test c_default == c
4545

4646
# compare with POT
47-
γ_pot = POT.sinkhorn(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)
48-
c_pot = POT.sinkhorn2(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)[1]
47+
γ_pot = POT.sinkhorn(μ, ν, C, ε; stopThr=1e-16)
48+
c_pot = POT.sinkhorn2(μ, ν, C, ε; stopThr=1e-16)[1]
4949
@test γ_pot γ rtol = 1e-6
5050
@test c_pot c rtol = 1e-7
5151

5252
# compute optimal transport cost with regularization term
53-
c_w_regularization = sinkhorn2(
54-
μ, ν, C, ε, SinkhornGibbs(); maxiter=5_000, regularization=true
55-
)
53+
c_w_regularization = sinkhorn2(μ, ν, C, ε, SinkhornGibbs(); regularization=true)
5654
@test c_w_regularization c + ε * sum(x -> iszero(x) ? x : x * log(x), γ)
57-
@test c_w_regularization ==
58-
sinkhorn2(μ, ν, C, ε; maxiter=5_000, regularization=true)
55+
@test c_w_regularization == sinkhorn2(μ, ν, C, ε; regularization=true)
5956

6057
# ensure that provided plan is used and correct
6158
c2 = sinkhorn2(similar(μ), similar(ν), C, rand(), SinkhornGibbs(); plan=γ)
@@ -78,21 +75,17 @@ Random.seed!(100)
7875

7976
# compute optimal transport plan and check that it is consistent with the
8077
# plan for individual histograms
81-
γ_all = sinkhorn(
82-
μ_batch, ν_batch, C, ε, SinkhornGibbs(); maxiter=5_000, rtol=1e-9
83-
)
78+
γ_all = sinkhorn(μ_batch, ν_batch, C, ε, SinkhornGibbs())
8479
@test size(γ_all) == (M, N, d)
8580
@test all(view(γ_all, :, :, i) γ for i in axes(γ_all, 3))
86-
@test γ_all == sinkhorn(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9)
81+
@test γ_all == sinkhorn(μ_batch, ν_batch, C, ε)
8782

8883
# compute optimal transport cost and check that it is consistent with the
8984
# cost for individual histograms
90-
c_all = sinkhorn2(
91-
μ_batch, ν_batch, C, ε, SinkhornGibbs(); maxiter=5_000, rtol=1e-9
92-
)
85+
c_all = sinkhorn2(μ_batch, ν_batch, C, ε, SinkhornGibbs())
9386
@test size(c_all) == (d,)
9487
@test all(x c for x in c_all)
95-
@test c_all == sinkhorn2(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9)
88+
@test c_all == sinkhorn2(μ_batch, ν_batch, C, ε)
9689
end
9790
end
9891

@@ -105,23 +98,23 @@ Random.seed!(100)
10598
ε32 = Float32(ε)
10699

107100
# compute optimal transport plan and optimal transport cost
108-
γ = sinkhorn(μ32, ν32, C32, ε32, SinkhornGibbs(); maxiter=5_000, rtol=1e-6)
109-
c = sinkhorn2(μ32, ν32, C32, ε32, SinkhornGibbs(); maxiter=5_000, rtol=1e-6)
101+
γ = sinkhorn(μ32, ν32, C32, ε32, SinkhornGibbs())
102+
c = sinkhorn2(μ32, ν32, C32, ε32, SinkhornGibbs())
110103
@test eltype(γ) === Float32
111104
@test typeof(c) === Float32
112105

113106
# check that plan and cost are consistent
114107
@test c dot(γ, C32)
115108

116109
# compare with default algorithm
117-
γ_default = sinkhorn(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6)
118-
c_default = sinkhorn2(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6)
110+
γ_default = sinkhorn(μ32, ν32, C32, ε32)
111+
c_default = sinkhorn2(μ32, ν32, C32, ε32)
119112
@test γ_default == γ
120113
@test c_default == c
121114

122115
# compare with POT
123-
γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)
124-
c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)[1]
116+
γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32)
117+
c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32)[1]
125118
@test map(Float32, γ_pot) γ rtol = 1e-3
126119
@test Float32(c_pot) c rtol = 1e-3
127120

@@ -135,23 +128,17 @@ Random.seed!(100)
135128

136129
# compute optimal transport plan and check that it is consistent with the
137130
# plan for individual histograms
138-
γ_all = sinkhorn(
139-
μ32_batch, ν32_batch, C32, ε32, SinkhornGibbs(); maxiter=5_000, rtol=1e-6
140-
)
131+
γ_all = sinkhorn(μ32_batch, ν32_batch, C32, ε32, SinkhornGibbs())
141132
@test size(γ_all) == (M, N, d)
142133
@test all(view(γ_all, :, :, i) γ for i in axes(γ_all, 3))
143-
@test γ_all ==
144-
sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6)
134+
@test γ_all == sinkhorn(μ32_batch, ν32_batch, C32, ε32)
145135

146136
# compute optimal transport cost and check that it is consistent with the
147137
# cost for individual histograms
148-
c_all = sinkhorn2(
149-
μ32_batch, ν32_batch, C32, ε32, SinkhornGibbs(); maxiter=5_000, rtol=1e-6
150-
)
138+
c_all = sinkhorn2(μ32_batch, ν32_batch, C32, ε32, SinkhornGibbs())
151139
@test size(c_all) == (d,)
152140
@test all(x c for x in c_all)
153-
@test c_all ==
154-
sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6)
141+
@test c_all == sinkhorn2(μ32_batch, ν32_batch, C32, ε32)
155142
end
156143
end
157144

@@ -161,7 +148,7 @@ Random.seed!(100)
161148
# together. test against gradient computed using analytic formula of Proposition 2.3 of
162149
# Cuturi, Marco, and Gabriel Peyré. "A smoothed dual approach for variational Wasserstein problems." SIAM Journal on Imaging Sciences 9.1 (2016): 320-343.
163150
#
164-
ε = 0.05 # use a larger ε to avoid having to do many iterations
151+
ε = 0.1 # use a larger ε to avoid having to do many iterations
165152
# target marginal
166153
for Diff in [ReverseDiff, ForwardDiff]
167154
= Diff.gradient(log.(ν)) do xs
@@ -187,7 +174,7 @@ Random.seed!(100)
187174
end
188175
∇analytic_target = J_softmax * ∇_ot
189176
# check that gradient obtained by AD matches the analytic formula
190-
@test ∇analytic_target rtol = 1e-6
177+
@test ∇analytic_target
191178

192179
# source marginal
193180
= Diff.gradient(log.(μ)) do xs
@@ -206,7 +193,7 @@ Random.seed!(100)
206193
end
207194
∇_ot = dualvar_to_grad(solver.cache.u, ε)
208195
∇analytic_source = J_softmax * ∇_ot
209-
@test ∇analytic_source rtol = 1e-6
196+
@test ∇analytic_source
210197

211198
# both marginals
212199
= Diff.gradient(log.(vcat(μ, ν))) do xs
@@ -224,7 +211,7 @@ Random.seed!(100)
224211
end
225212
@test== ∇default
226213
∇analytic = vcat(∇analytic_source, ∇analytic_target)
227-
@test ∇analytic rtol = 1e-6
214+
@test ∇analytic
228215
end
229216
end
230217

test/entropic/sinkhorn_stabilized.jl

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,25 @@ Random.seed!(100)
2626
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
2727

2828
# regularization parameter
29-
ε = 0.01
29+
ε = 0.1
3030

3131
@testset "example" begin
3232
# compute optimal transport plan and optimal transport cost
33-
γ = sinkhorn(μ, ν, C, ε, SinkhornStabilized(); maxiter=5_000, rtol=1e-9)
34-
c = sinkhorn2(μ, ν, C, ε, SinkhornStabilized(); maxiter=5_000, rtol=1e-9)
33+
γ = sinkhorn(μ, ν, C, ε, SinkhornStabilized())
34+
c = sinkhorn2(μ, ν, C, ε, SinkhornStabilized())
3535

3636
# check that plan and cost are consistent
3737
@test c dot(γ, C)
3838

3939
# compare with POT
40-
γ_pot = POT.sinkhorn(
41-
μ, ν, C, ε; method="sinkhorn_stabilized", numItermax=5_000, stopThr=1e-9
42-
)
43-
c_pot = POT.sinkhorn2(
44-
μ, ν, C, ε; method="sinkhorn_stabilized", numItermax=5_000, stopThr=1e-9
45-
)[1]
46-
@test γ_pot γ rtol = 1e-6
47-
@test c_pot c rtol = 1e-7
40+
γ_pot = POT.sinkhorn(μ, ν, C, ε; method="sinkhorn_stabilized", stopThr=1e-16)
41+
c_pot = POT.sinkhorn2(μ, ν, C, ε; method="sinkhorn_stabilized", stopThr=1e-16)[1]
42+
@test γ_pot γ
43+
@test c_pot c
4844

4945
# compute optimal transport cost with regularization term
5046
c_w_regularization = sinkhorn2(
51-
μ, ν, C, ε, SinkhornStabilized(); maxiter=5_000, regularization=true
47+
μ, ν, C, ε, SinkhornStabilized(); regularization=true
5248
)
5349
@test c_w_regularization c + ε * sum(x -> iszero(x) ? x : x * log(x), γ)
5450

@@ -70,17 +66,13 @@ Random.seed!(100)
7066

7167
# compute optimal transport plan and check that it is consistent with the
7268
# plan for individual histograms
73-
γ_all = sinkhorn(
74-
μ_batch, ν_batch, C, ε, SinkhornStabilized(); maxiter=5_000, rtol=1e-9
75-
)
69+
γ_all = sinkhorn(μ_batch, ν_batch, C, ε, SinkhornStabilized())
7670
@test size(γ_all) == (M, N, d)
7771
@test all(view(γ_all, :, :, i) γ for i in axes(γ_all, 3))
7872

7973
# compute optimal transport cost and check that it is consistent with the
8074
# cost for individual histograms
81-
c_all = sinkhorn2(
82-
μ_batch, ν_batch, C, ε, SinkhornStabilized(); maxiter=5_000, rtol=1e-9
83-
)
75+
c_all = sinkhorn2(μ_batch, ν_batch, C, ε, SinkhornStabilized())
8476
@test size(c_all) == (d,)
8577
@test all(x c for x in c_all)
8678
end

0 commit comments

Comments
 (0)