We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 70b8ae6 commit 2c1d779Copy full SHA for 2c1d779
blackjax/mcmc/mclmc.py
@@ -22,7 +22,7 @@
22
from blackjax.base import SamplingAlgorithm
23
from blackjax.mcmc.integrators import IntegratorState, isokinetic_mclachlan
24
from blackjax.types import ArrayLike, PRNGKey
25
-from blackjax.util import generate_unit_vector
+from blackjax.util import generate_unit_vector, pytree_size
26
27
__all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"]
28
@@ -45,6 +45,10 @@ class MCLMCInfo(NamedTuple):
45
46
47
def init(position: ArrayLike, logdensity_fn, rng_key):
48
+ if pytree_size(position) < 2:
49
+ raise ValueError(
50
+ "The target distribution must have more than 1 dimension for MCLMC."
51
+ )
52
l, g = jax.value_and_grad(logdensity_fn)(position)
53
54
return IntegratorState(
0 commit comments