@@ -145,13 +145,85 @@ def test_update_multi_sigmas(self):
145145 )
146146
147147
148+ def tuned_adaptive_tempered_inference_loop (kernel , rng_key , initial_state ):
149+ def cond (carry ):
150+ _ , state , * _ = carry
151+ return state .sampler_state .lmbda < 1
152+
153+ def body (carry ):
154+ i , state , curr_loglikelihood = carry
155+ subkey = jax .random .fold_in (rng_key , i )
156+ state , info = kernel (subkey , state )
157+ return i + 1 , state , curr_loglikelihood + info .log_likelihood_increment
158+
159+ total_iter , final_state , log_likelihood = jax .lax .while_loop (
160+ cond , body , (0 , initial_state , 0.0 )
161+ )
162+ return final_state
163+
164+
148165class PretuningSMCTest (SMCLinearRegressionTestCase ):
149166 def setUp (self ):
150167 super ().setUp ()
151168 self .key = jax .random .key (42 )
152169
153170 @chex .variants (with_jit = True )
154- def test_linear_regression (self ):
171+ def test_tempered (self ):
172+ step_provider = lambda logprior_fn , loglikelihood_fn , pretune : blackjax .smc .pretuning .build_kernel (
173+ blackjax .tempered_smc ,
174+ logprior_fn ,
175+ loglikelihood_fn ,
176+ blackjax .hmc .build_kernel (),
177+ blackjax .hmc .init ,
178+ resampling .systematic ,
179+ num_mcmc_steps = 10 ,
180+ pretune_fn = pretune ,
181+ )
182+
183+ def loop (smc_kernel , init_particles , initial_parameters ):
184+ initial_state = init (
185+ blackjax .tempered_smc .init , init_particles , initial_parameters
186+ )
187+
188+ def body_fn (carry , lmbda ):
189+ i , state = carry
190+ subkey = jax .random .fold_in (self .key , i )
191+ new_state , info = smc_kernel (subkey , state , lmbda = lmbda )
192+ return (i + 1 , new_state ), (new_state , info )
193+
194+ num_tempering_steps = 10
195+ lambda_schedule = np .logspace (- 5 , 0 , num_tempering_steps )
196+
197+ (_ , result ), _ = jax .lax .scan (body_fn , (0 , initial_state ), lambda_schedule )
198+ return result
199+
200+ self .linear_regression_test_case (step_provider , loop )
201+
202+ @chex .variants (with_jit = True )
203+ def test_adaptive_tempered (self ):
204+ step_provider = lambda logprior_fn , loglikelihood_fn , pretune : blackjax .smc .pretuning .build_kernel (
205+ blackjax .adaptive_tempered_smc ,
206+ logprior_fn ,
207+ loglikelihood_fn ,
208+ blackjax .hmc .build_kernel (),
209+ blackjax .hmc .init ,
210+ resampling .systematic ,
211+ num_mcmc_steps = 10 ,
212+ pretune_fn = pretune ,
213+ target_ess = 0.5 ,
214+ )
215+
216+ def loop (smc_kernel , init_particles , initial_parameters ):
217+ initial_state = init (
218+ blackjax .tempered_smc .init , init_particles , initial_parameters
219+ )
220+ return tuned_adaptive_tempered_inference_loop (
221+ smc_kernel , self .key , initial_state
222+ )
223+
224+ self .linear_regression_test_case (step_provider , loop )
225+
226+ def linear_regression_test_case (self , step_provider , loop ):
155227 (
156228 init_particles ,
157229 logprior_fn ,
@@ -191,32 +263,11 @@ def test_linear_regression(self):
191263 positive_parameters = ["step_size" ],
192264 )
193265
194- step = blackjax .smc .pretuning .build_kernel (
195- blackjax .tempered_smc ,
196- logprior_fn ,
197- loglikelihood_fn ,
198- blackjax .hmc .build_kernel (),
199- blackjax .hmc .init ,
200- resampling .systematic ,
201- num_mcmc_steps = 10 ,
202- pretune_fn = pretune ,
203- )
266+ step = step_provider (logprior_fn , loglikelihood_fn , pretune )
204267
205- initial_state = init (
206- blackjax .tempered_smc .init , init_particles , initial_parameters
207- )
208268 smc_kernel = self .variant (step )
209269
210- def body_fn (carry , lmbda ):
211- i , state = carry
212- subkey = jax .random .fold_in (self .key , i )
213- new_state , info = smc_kernel (subkey , state , lmbda = lmbda )
214- return (i + 1 , new_state ), (new_state , info )
215-
216- num_tempering_steps = 10
217- lambda_schedule = np .logspace (- 5 , 0 , num_tempering_steps )
218-
219- (_ , result ), _ = jax .lax .scan (body_fn , (0 , initial_state ), lambda_schedule )
270+ result = loop (smc_kernel , init_particles , initial_parameters )
220271 self .assert_linear_regression_test_case (result .sampler_state )
221272 assert set (result .parameter_override .keys ()) == {
222273 "step_size" ,
0 commit comments