Skip to content

Commit a053bed

Browse files
authored
test in place (#772)
1 parent 4d4eae0 commit a053bed

File tree

1 file changed

+75
-24
lines changed

1 file changed

+75
-24
lines changed

tests/smc/test_pretuning.py

Lines changed: 75 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,85 @@ def test_update_multi_sigmas(self):
145145
)
146146

147147

148+
def tuned_adaptive_tempered_inference_loop(kernel, rng_key, initial_state):
149+
def cond(carry):
150+
_, state, *_ = carry
151+
return state.sampler_state.lmbda < 1
152+
153+
def body(carry):
154+
i, state, curr_loglikelihood = carry
155+
subkey = jax.random.fold_in(rng_key, i)
156+
state, info = kernel(subkey, state)
157+
return i + 1, state, curr_loglikelihood + info.log_likelihood_increment
158+
159+
total_iter, final_state, log_likelihood = jax.lax.while_loop(
160+
cond, body, (0, initial_state, 0.0)
161+
)
162+
return final_state
163+
164+
148165
class PretuningSMCTest(SMCLinearRegressionTestCase):
149166
def setUp(self):
150167
super().setUp()
151168
self.key = jax.random.key(42)
152169

153170
@chex.variants(with_jit=True)
154-
def test_linear_regression(self):
171+
def test_tempered(self):
172+
step_provider = lambda logprior_fn, loglikelihood_fn, pretune: blackjax.smc.pretuning.build_kernel(
173+
blackjax.tempered_smc,
174+
logprior_fn,
175+
loglikelihood_fn,
176+
blackjax.hmc.build_kernel(),
177+
blackjax.hmc.init,
178+
resampling.systematic,
179+
num_mcmc_steps=10,
180+
pretune_fn=pretune,
181+
)
182+
183+
def loop(smc_kernel, init_particles, initial_parameters):
184+
initial_state = init(
185+
blackjax.tempered_smc.init, init_particles, initial_parameters
186+
)
187+
188+
def body_fn(carry, lmbda):
189+
i, state = carry
190+
subkey = jax.random.fold_in(self.key, i)
191+
new_state, info = smc_kernel(subkey, state, lmbda=lmbda)
192+
return (i + 1, new_state), (new_state, info)
193+
194+
num_tempering_steps = 10
195+
lambda_schedule = np.logspace(-5, 0, num_tempering_steps)
196+
197+
(_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule)
198+
return result
199+
200+
self.linear_regression_test_case(step_provider, loop)
201+
202+
@chex.variants(with_jit=True)
203+
def test_adaptive_tempered(self):
204+
step_provider = lambda logprior_fn, loglikelihood_fn, pretune: blackjax.smc.pretuning.build_kernel(
205+
blackjax.adaptive_tempered_smc,
206+
logprior_fn,
207+
loglikelihood_fn,
208+
blackjax.hmc.build_kernel(),
209+
blackjax.hmc.init,
210+
resampling.systematic,
211+
num_mcmc_steps=10,
212+
pretune_fn=pretune,
213+
target_ess=0.5,
214+
)
215+
216+
def loop(smc_kernel, init_particles, initial_parameters):
217+
initial_state = init(
218+
blackjax.tempered_smc.init, init_particles, initial_parameters
219+
)
220+
return tuned_adaptive_tempered_inference_loop(
221+
smc_kernel, self.key, initial_state
222+
)
223+
224+
self.linear_regression_test_case(step_provider, loop)
225+
226+
def linear_regression_test_case(self, step_provider, loop):
155227
(
156228
init_particles,
157229
logprior_fn,
@@ -191,32 +263,11 @@ def test_linear_regression(self):
191263
positive_parameters=["step_size"],
192264
)
193265

194-
step = blackjax.smc.pretuning.build_kernel(
195-
blackjax.tempered_smc,
196-
logprior_fn,
197-
loglikelihood_fn,
198-
blackjax.hmc.build_kernel(),
199-
blackjax.hmc.init,
200-
resampling.systematic,
201-
num_mcmc_steps=10,
202-
pretune_fn=pretune,
203-
)
266+
step = step_provider(logprior_fn, loglikelihood_fn, pretune)
204267

205-
initial_state = init(
206-
blackjax.tempered_smc.init, init_particles, initial_parameters
207-
)
208268
smc_kernel = self.variant(step)
209269

210-
def body_fn(carry, lmbda):
211-
i, state = carry
212-
subkey = jax.random.fold_in(self.key, i)
213-
new_state, info = smc_kernel(subkey, state, lmbda=lmbda)
214-
return (i + 1, new_state), (new_state, info)
215-
216-
num_tempering_steps = 10
217-
lambda_schedule = np.logspace(-5, 0, num_tempering_steps)
218-
219-
(_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule)
270+
result = loop(smc_kernel, init_particles, initial_parameters)
220271
self.assert_linear_regression_test_case(result.sampler_state)
221272
assert set(result.parameter_override.keys()) == {
222273
"step_size",

0 commit comments

Comments
 (0)