@@ -27,35 +27,32 @@ Random.seed!(100)
 C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)

 # regularization parameter
-ε = 0.01
+ε = 0.1

 @testset "example" begin
     # compute optimal transport plan and optimal transport cost
-    γ = sinkhorn(μ, ν, C, ε, SinkhornGibbs(); maxiter=5_000, rtol=1e-9)
-    c = sinkhorn2(μ, ν, C, ε, SinkhornGibbs(); maxiter=5_000, rtol=1e-9)
+    γ = sinkhorn(μ, ν, C, ε, SinkhornGibbs())
+    c = sinkhorn2(μ, ν, C, ε, SinkhornGibbs())

     # check that plan and cost are consistent
     @test c ≈ dot(γ, C)

     # compare with default algorithm
-    γ_default = sinkhorn(μ, ν, C, ε; maxiter=5_000, rtol=1e-9)
-    c_default = sinkhorn2(μ, ν, C, ε; maxiter=5_000, rtol=1e-9)
+    γ_default = sinkhorn(μ, ν, C, ε)
+    c_default = sinkhorn2(μ, ν, C, ε)
     @test γ_default == γ
     @test c_default == c

     # compare with POT
-    γ_pot = POT.sinkhorn(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)
-    c_pot = POT.sinkhorn2(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)[1]
+    γ_pot = POT.sinkhorn(μ, ν, C, ε; stopThr=1e-16)
+    c_pot = POT.sinkhorn2(μ, ν, C, ε; stopThr=1e-16)[1]
     @test γ_pot ≈ γ rtol = 1e-6
     @test c_pot ≈ c rtol = 1e-7

     # compute optimal transport cost with regularization term
-    c_w_regularization = sinkhorn2(
-        μ, ν, C, ε, SinkhornGibbs(); maxiter=5_000, regularization=true
-    )
+    c_w_regularization = sinkhorn2(μ, ν, C, ε, SinkhornGibbs(); regularization=true)
     @test c_w_regularization ≈ c + ε * sum(x -> iszero(x) ? x : x * log(x), γ)
-    @test c_w_regularization ==
-        sinkhorn2(μ, ν, C, ε; maxiter=5_000, regularization=true)
+    @test c_w_regularization == sinkhorn2(μ, ν, C, ε; regularization=true)

     # ensure that provided plan is used and correct
     c2 = sinkhorn2(similar(μ), similar(ν), C, rand(), SinkhornGibbs(); plan=γ)
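Context for the change above: the explicit `maxiter=5_000, rtol=1e-9` arguments are dropped in favour of the solver defaults, and the POT reference run now uses a tighter `stopThr=1e-16`, so both implementations are compared essentially at convergence. Below is a minimal, self-contained sketch of the pattern these tests exercise; the uniform histograms and problem sizes are illustrative assumptions, not taken from this commit.

```julia
using OptimalTransport          # sinkhorn, sinkhorn2, SinkhornGibbs
using Distances: SqEuclidean, pairwise
using LinearAlgebra: dot

# illustrative problem: uniform histograms and a random squared-Euclidean cost
M, N = 250, 200
μ = fill(1 / M, M)
ν = fill(1 / N, N)
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
ε = 0.1

# default convergence settings, no explicit maxiter/rtol
γ = sinkhorn(μ, ν, C, ε, SinkhornGibbs())    # optimal transport plan
c = sinkhorn2(μ, ν, C, ε, SinkhornGibbs())   # optimal transport cost

@assert c ≈ dot(γ, C)   # without `regularization=true`, the cost is ⟨γ, C⟩
```

Raising ε from 0.01 to 0.1 also keeps the number of Sinkhorn iterations small, in line with the comment in the gradient testset further down.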
@@ -78,21 +75,17 @@ Random.seed!(100)

         # compute optimal transport plan and check that it is consistent with the
         # plan for individual histograms
-        γ_all = sinkhorn(
-            μ_batch, ν_batch, C, ε, SinkhornGibbs(); maxiter=5_000, rtol=1e-9
-        )
+        γ_all = sinkhorn(μ_batch, ν_batch, C, ε, SinkhornGibbs())
         @test size(γ_all) == (M, N, d)
         @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3))
-        @test γ_all == sinkhorn(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9)
+        @test γ_all == sinkhorn(μ_batch, ν_batch, C, ε)

         # compute optimal transport cost and check that it is consistent with the
         # cost for individual histograms
-        c_all = sinkhorn2(
-            μ_batch, ν_batch, C, ε, SinkhornGibbs(); maxiter=5_000, rtol=1e-9
-        )
+        c_all = sinkhorn2(μ_batch, ν_batch, C, ε, SinkhornGibbs())
         @test size(c_all) == (d,)
         @test all(x ≈ c for x in c_all)
-        @test c_all == sinkhorn2(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9)
+        @test c_all == sinkhorn2(μ_batch, ν_batch, C, ε)
     end
 end

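The batch hunks check that `sinkhorn` and `sinkhorn2` also accept matrices whose columns are histograms, returning an `(M, N, d)` stack of plans and a length-`d` vector of costs. A hedged continuation of the sketch above, assuming (purely for illustration) that `μ_batch` and `ν_batch` repeat the single histograms column-wise; the actual construction lies outside this hunk.

```julia
# hypothetical batch: d copies of the single histograms, one per column
d = 10
μ_batch = repeat(μ, 1, d)    # M × d
ν_batch = repeat(ν, 1, d)    # N × d

γ_all = sinkhorn(μ_batch, ν_batch, C, ε, SinkhornGibbs())    # (M, N, d) stack of plans
c_all = sinkhorn2(μ_batch, ν_batch, C, ε, SinkhornGibbs())   # length-d vector of costs

@assert size(γ_all) == (M, N, d)
@assert all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3))
@assert all(x ≈ c for x in c_all)
```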
@@ -105,23 +98,23 @@ Random.seed!(100)
     ε32 = Float32(ε)

     # compute optimal transport plan and optimal transport cost
-    γ = sinkhorn(μ32, ν32, C32, ε32, SinkhornGibbs(); maxiter=5_000, rtol=1e-6)
-    c = sinkhorn2(μ32, ν32, C32, ε32, SinkhornGibbs(); maxiter=5_000, rtol=1e-6)
+    γ = sinkhorn(μ32, ν32, C32, ε32, SinkhornGibbs())
+    c = sinkhorn2(μ32, ν32, C32, ε32, SinkhornGibbs())
     @test eltype(γ) === Float32
     @test typeof(c) === Float32

     # check that plan and cost are consistent
     @test c ≈ dot(γ, C32)

     # compare with default algorithm
-    γ_default = sinkhorn(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6)
-    c_default = sinkhorn2(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6)
+    γ_default = sinkhorn(μ32, ν32, C32, ε32)
+    c_default = sinkhorn2(μ32, ν32, C32, ε32)
     @test γ_default == γ
     @test c_default == c

     # compare with POT
-    γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)
-    c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)[1]
+    γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32)
+    c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32)[1]
     @test map(Float32, γ_pot) ≈ γ rtol = 1e-3
     @test Float32(c_pot) ≈ c rtol = 1e-3

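The Float32 testset additionally asserts type stability: with 32-bit inputs, the returned plan and cost stay `Float32`. A short continuation of the sketch, assuming the 32-bit data are plain conversions of the Float64 problem above.

```julia
# hypothetical Float32 versions of the problem data
μ32 = map(Float32, μ)
ν32 = map(Float32, ν)
C32 = map(Float32, C)
ε32 = Float32(ε)

γ32 = sinkhorn(μ32, ν32, C32, ε32, SinkhornGibbs())
c32 = sinkhorn2(μ32, ν32, C32, ε32, SinkhornGibbs())

@assert eltype(γ32) === Float32
@assert typeof(c32) === Float32
```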
@@ -135,23 +128,17 @@ Random.seed!(100)

         # compute optimal transport plan and check that it is consistent with the
         # plan for individual histograms
-        γ_all = sinkhorn(
-            μ32_batch, ν32_batch, C32, ε32, SinkhornGibbs(); maxiter=5_000, rtol=1e-6
-        )
+        γ_all = sinkhorn(μ32_batch, ν32_batch, C32, ε32, SinkhornGibbs())
         @test size(γ_all) == (M, N, d)
         @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3))
-        @test γ_all ==
-            sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6)
+        @test γ_all == sinkhorn(μ32_batch, ν32_batch, C32, ε32)

         # compute optimal transport cost and check that it is consistent with the
         # cost for individual histograms
-        c_all = sinkhorn2(
-            μ32_batch, ν32_batch, C32, ε32, SinkhornGibbs(); maxiter=5_000, rtol=1e-6
-        )
+        c_all = sinkhorn2(μ32_batch, ν32_batch, C32, ε32, SinkhornGibbs())
         @test size(c_all) == (d,)
         @test all(x ≈ c for x in c_all)
-        @test c_all ==
-            sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6)
+        @test c_all == sinkhorn2(μ32_batch, ν32_batch, C32, ε32)
     end
 end

@@ -161,7 +148,7 @@ Random.seed!(100)
     # together. test against gradient computed using analytic formula of Proposition 2.3 of
     # Cuturi, Marco, and Gabriel Peyré. "A smoothed dual approach for variational Wasserstein problems." SIAM Journal on Imaging Sciences 9.1 (2016): 320-343.
     #
-    ε = 0.05 # use a larger ε to avoid having to do many iterations
+    ε = 0.1 # use a larger ε to avoid having to do many iterations
     # target marginal
     for Diff in [ReverseDiff, ForwardDiff]
         ∇ = Diff.gradient(log.(ν)) do xs
@@ -187,7 +174,7 @@ Random.seed!(100)
         end
         ∇analytic_target = J_softmax * ∇_ot
         # check that gradient obtained by AD matches the analytic formula
-        @test ∇ ≈ ∇analytic_target rtol = 1e-6
+        @test ∇ ≈ ∇analytic_target

         # source marginal
         ∇ = Diff.gradient(log.(μ)) do xs
@@ -206,7 +193,7 @@ Random.seed!(100)
         end
         ∇_ot = dualvar_to_grad(solver.cache.u, ε)
         ∇analytic_source = J_softmax * ∇_ot
-        @test ∇ ≈ ∇analytic_source rtol = 1e-6
+        @test ∇ ≈ ∇analytic_source

         # both marginals
         ∇ = Diff.gradient(log.(vcat(μ, ν))) do xs
@@ -224,7 +211,7 @@ Random.seed!(100)
         end
         @test ∇ == ∇default
         ∇analytic = vcat(∇analytic_source, ∇analytic_target)
-        @test ∇ ≈ ∇analytic rtol = 1e-6
+        @test ∇ ≈ ∇analytic
     end
 end

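For reference, the analytic gradient that these tests compare against comes from Proposition 2.3 of Cuturi and Peyré (2016): the gradient of the entropic OT cost with respect to a marginal is the corresponding optimal dual potential (up to an additive constant), recovered here from the solver's scaling variable via `dualvar_to_grad`. Because the tests differentiate with respect to logits, with the marginal entering through a softmax, that dual-potential gradient is pushed through the softmax Jacobian. The helper below is an illustrative sketch, not part of the test file.

```julia
using LinearAlgebra: Diagonal

# Jacobian of x ↦ softmax(x), written at the point ν = softmax(x).
# Its rows sum to zero, so the additive-constant ambiguity of the dual
# potential disappears after the multiplication below.
softmax_jacobian(ν::AbstractVector) = Diagonal(ν) - ν * ν'

# chain rule used in the tests (∇_ot is the dual-potential gradient of the
# entropic OT cost with respect to the marginal ν):
#   ∇_x OT_ε(μ, softmax(x)) = softmax_jacobian(ν) * ∇_ot
```

Since this Jacobian is symmetric, applying it directly rather than its transpose is equivalent, which matches the `J_softmax * ∇_ot` products in the hunks above.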