Skip to content

Commit 2c1d779

Browse files
authored
PREVENT 1D DISTRIBUTION (#633)
* PREVENT 1D DISTRIBUTION * PREVENT 1D DISTRIBUTION
1 parent 70b8ae6 commit 2c1d779

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

blackjax/mcmc/mclmc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from blackjax.base import SamplingAlgorithm
2323
from blackjax.mcmc.integrators import IntegratorState, isokinetic_mclachlan
2424
from blackjax.types import ArrayLike, PRNGKey
25-
from blackjax.util import generate_unit_vector
25+
from blackjax.util import generate_unit_vector, pytree_size
2626

2727
__all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"]
2828

@@ -45,6 +45,10 @@ class MCLMCInfo(NamedTuple):
4545

4646

4747
def init(position: ArrayLike, logdensity_fn, rng_key):
48+
if pytree_size(position) < 2:
49+
raise ValueError(
50+
"The target distribution must have more than 1 dimension for MCLMC."
51+
)
4852
l, g = jax.value_and_grad(logdensity_fn)(position)
4953

5054
return IntegratorState(

0 commit comments

Comments
 (0)