Skip to content

Commit 148c028

Browse files
authored
remove labels (#716)
1 parent 27dfc9e commit 148c028

File tree

3 files changed

+26
-25
lines changed

3 files changed

+26
-25
lines changed

blackjax/adaptation/window_adaptation.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
dual_averaging_adaptation,
2929
)
3030
from blackjax.base import AdaptationAlgorithm
31-
from blackjax.progress_bar import progress_bar_scan
31+
from blackjax.progress_bar import gen_scan_fn
3232
from blackjax.types import Array, ArrayLikeTree, PRNGKey
3333
from blackjax.util import pytree_size
3434

@@ -333,23 +333,16 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
333333

334334
if progress_bar:
335335
print("Running window adaptation")
336-
one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step))
337-
start_state = ((init_state, init_adaptation_state), -1)
338-
else:
339-
one_step_ = jax.jit(one_step)
340-
start_state = (init_state, init_adaptation_state)
341-
336+
scan_fn = gen_scan_fn(num_steps, progress_bar=progress_bar)
337+
start_state = (init_state, init_adaptation_state)
342338
keys = jax.random.split(rng_key, num_steps)
343339
schedule = build_schedule(num_steps)
344-
last_state, info = jax.lax.scan(
345-
one_step_,
340+
last_state, info = scan_fn(
341+
one_step,
346342
start_state,
347343
(jnp.arange(num_steps), keys, schedule),
348344
)
349345

350-
if progress_bar:
351-
last_state, _ = last_state
352-
353346
last_chain_state, last_warmup_state, *_ = last_state
354347

355348
step_size, inverse_mass_matrix = adapt_final(last_warmup_state)

blackjax/progress_bar.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,17 @@ def wrapper_progress_bar(carry, x):
9494
return wrapper_progress_bar
9595

9696
return _progress_bar_scan
97+
98+
99+
def gen_scan_fn(num_samples, progress_bar, print_rate=None):
100+
if progress_bar:
101+
102+
def scan_wrap(f, init, *args, **kwargs):
103+
func = progress_bar_scan(num_samples, print_rate)(f)
104+
carry = (init, -1)
105+
(last_state, _), output = lax.scan(func, carry, *args, **kwargs)
106+
return last_state, output
107+
108+
return scan_wrap
109+
else:
110+
return lax.scan

blackjax/util.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from jax.tree_util import tree_leaves
1212

1313
from blackjax.base import SamplingAlgorithm, VIAlgorithm
14-
from blackjax.progress_bar import progress_bar_scan
14+
from blackjax.progress_bar import gen_scan_fn
1515
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
1616

1717

@@ -225,18 +225,12 @@ def one_step(average_and_state, xs, return_state):
225225
one_step = jax.jit(partial(one_step, return_state=return_state_history))
226226

227227
xs = (jnp.arange(num_steps), keys)
228-
if progress_bar:
229-
one_step = progress_bar_scan(num_steps)(one_step)
230-
(((_, average), final_state), _), history = lax.scan(
231-
one_step,
232-
(((0, expectation(transform(initial_state))), initial_state), -1),
233-
xs,
234-
)
235-
236-
else:
237-
((_, average), final_state), history = lax.scan(
238-
one_step, ((0, expectation(transform(initial_state))), initial_state), xs
239-
)
228+
scan_fn = gen_scan_fn(num_steps, progress_bar)
229+
((_, average), final_state), history = scan_fn(
230+
one_step,
231+
((0, expectation(transform(initial_state))), initial_state),
232+
xs,
233+
)
240234

241235
if not return_state_history:
242236
return average, transform(final_state)

0 commit comments

Comments
 (0)