Skip to content

Commit 70b8ae6

Browse files
authored
RENAME NONEUCLIDEAN TO ISOKINETIC (#632)
* RENAME NONEUCLIDEAN TO ISOKINETIC * REMOVE ACCIDENTALLY ADDED FILE
1 parent 7499bfd commit 70b8ae6

File tree

4 files changed

+23
-23
lines changed

4 files changed

+23
-23
lines changed

blackjax/mcmc/integrators.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
"velocity_verlet",
2727
"yoshida",
2828
"implicit_midpoint",
29-
"noneuclidean_leapfrog",
30-
"noneuclidean_mclachlan",
31-
"noneuclidean_yoshida",
29+
"isokinetic_leapfrog",
30+
"isokinetic_mclachlan",
31+
"isokinetic_yoshida",
3232
]
3333

3434

@@ -332,7 +332,7 @@ def esh_dynamics_momentum_update_one_step(
332332
return next_momentum, next_momentum, kinetic_energy_change
333333

334334

335-
def format_noneuclidean_state_output(
335+
def format_isokinetic_state_output(
336336
position,
337337
momentum,
338338
logdensity,
@@ -348,25 +348,25 @@ def format_noneuclidean_state_output(
348348
)
349349

350350

351-
def generate_noneuclidean_integrator(cofficients):
352-
def noneuclidean_integrator(
351+
def generate_isokinetic_integrator(cofficients):
352+
def isokinetic_integrator(
353353
logdensity_fn: Callable, *args, **kwargs
354354
) -> GeneralIntegrator:
355355
position_update_fn = euclidean_position_update_fn(logdensity_fn)
356356
one_step = generalized_two_stage_integrator(
357357
esh_dynamics_momentum_update_one_step,
358358
position_update_fn,
359359
cofficients,
360-
format_output_fn=format_noneuclidean_state_output,
360+
format_output_fn=format_isokinetic_state_output,
361361
)
362362
return one_step
363363

364-
return noneuclidean_integrator
364+
return isokinetic_integrator
365365

366366

367-
noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients)
368-
noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients)
369-
noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients)
367+
isokinetic_leapfrog = generate_isokinetic_integrator(velocity_verlet_cofficients)
368+
isokinetic_yoshida = generate_isokinetic_integrator(yoshida_cofficients)
369+
isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_cofficients)
370370

371371
FixedPointSolver = Callable[
372372
[Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree],

blackjax/mcmc/mclmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from jax.random import normal
2121

2222
from blackjax.base import SamplingAlgorithm
23-
from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan
23+
from blackjax.mcmc.integrators import IntegratorState, isokinetic_mclachlan
2424
from blackjax.types import ArrayLike, PRNGKey
2525
from blackjax.util import generate_unit_vector
2626

@@ -154,7 +154,7 @@ def __new__( # type: ignore[misc]
154154
logdensity_fn: Callable,
155155
L,
156156
step_size,
157-
integrator=noneuclidean_mclachlan,
157+
integrator=isokinetic_mclachlan,
158158
) -> SamplingAlgorithm:
159159
kernel = cls.build_kernel(logdensity_fn, integrator)
160160

tests/mcmc/test_integrators.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,9 @@ def kinetic_energy(p, position=None):
140140
"algorithm": integrators.implicit_midpoint,
141141
"precision": 1e-4,
142142
},
143-
"noneuclidean_leapfrog": {"algorithm": integrators.noneuclidean_leapfrog},
144-
"noneuclidean_mclachlan": {"algorithm": integrators.noneuclidean_mclachlan},
145-
"noneuclidean_yoshida": {"algorithm": integrators.noneuclidean_yoshida},
143+
"isokinetic_leapfrog": {"algorithm": integrators.isokinetic_leapfrog},
144+
"isokinetic_mclachlan": {"algorithm": integrators.isokinetic_mclachlan},
145+
"isokinetic_yoshida": {"algorithm": integrators.isokinetic_yoshida},
146146
}
147147

148148

@@ -239,13 +239,13 @@ def test_esh_momentum_update(self, dims):
239239
np.testing.assert_array_almost_equal(next_momentum, next_momentum1)
240240

241241
@chex.all_variants(with_pmap=False)
242-
def test_noneuclidean_leapfrog(self):
242+
def test_isokinetic_leapfrog(self):
243243
cov = jnp.asarray([[1.0, 0.5, 0.1], [0.5, 2.0, -0.1], [0.1, -0.1, 3.0]])
244244
logdensity_fn = lambda x: stats.multivariate_normal.logpdf(
245245
x, jnp.zeros([3]), cov
246246
)
247247

248-
step = self.variant(integrators.noneuclidean_leapfrog(logdensity_fn))
248+
step = self.variant(integrators.isokinetic_leapfrog(logdensity_fn))
249249

250250
rng = jax.random.key(4263456)
251251
key0, key1 = jax.random.split(rng, 2)
@@ -294,12 +294,12 @@ def test_noneuclidean_leapfrog(self):
294294
@chex.all_variants(with_pmap=False)
295295
@parameterized.parameters(
296296
[
297-
"noneuclidean_leapfrog",
298-
"noneuclidean_mclachlan",
299-
"noneuclidean_yoshida",
297+
"isokinetic_leapfrog",
298+
"isokinetic_mclachlan",
299+
"isokinetic_yoshida",
300300
],
301301
)
302-
def test_noneuclidean_integrator(self, integrator_name):
302+
def test_isokinetic_integrator(self, integrator_name):
303303
integrator = algorithms[integrator_name]
304304
cov = jnp.asarray([[1.0, 0.5], [0.5, 2.0]])
305305
logdensity_fn = lambda x: stats.multivariate_normal.logpdf(

tests/mcmc/test_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):
8383

8484
kernel = blackjax.mcmc.mclmc.build_kernel(
8585
logdensity_fn=logdensity_fn,
86-
integrator=blackjax.mcmc.integrators.noneuclidean_mclachlan,
86+
integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
8787
)
8888

8989
(

0 commit comments

Comments
 (0)