Skip to content

Commit 3f92393

Browse files
authored
Fix deprecated call to jnp.clip (#664)
1 parent 1bc6f93 commit 3f92393

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

blackjax/mcmc/proposal.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def progressive_biased_sampling(
153153
biases the transition away from the trajectory's initial state.
154154
155155
"""
156-
p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), a_max=1)
156+
p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), max=1)
157157
do_accept = jax.random.bernoulli(rng_key, p_accept)
158158
new_weight = jnp.logaddexp(proposal.weight, new_proposal.weight)
159159
new_sum_log_p_accept = jnp.logaddexp(
@@ -224,7 +224,7 @@ def static_binomial_sampling(
224224
then the new proposal is accepted with probability 1.
225225
226226
"""
227-
p_accept = jnp.clip(jnp.exp(log_p_accept), a_max=1)
227+
p_accept = jnp.clip(jnp.exp(log_p_accept), max=1)
228228
do_accept = jax.random.bernoulli(rng_key, p_accept)
229229
info = do_accept, p_accept, None
230230
return (
@@ -253,7 +253,7 @@ def nonreversible_slice_sampling(
253253
to the accept/reject step of a current state and new proposal.
254254
255255
"""
256-
p_accept = jnp.clip(jnp.exp(delta_energy), a_max=1)
256+
p_accept = jnp.clip(jnp.exp(delta_energy), max=1)
257257
do_accept = jnp.log(jnp.abs(slice)) <= delta_energy
258258
slice_next = slice * (jnp.exp(-delta_energy) * do_accept + (1 - do_accept))
259259
info = do_accept, p_accept, slice_next

0 commit comments

Comments
 (0)