
Commit 1bc6f93

Removal of Algorithm classes. (#657)
* more
* removing export
* removal of classes, tests passing
* linter
* fix on test
* linter
* removing parametrization on test
* code review updates
* exporting as_top_level_api in dynamic_hmc
* linter
* code review update: replace imports
1 parent a5f7482 commit 1bc6f93

30 files changed: +893 / -899 lines changed

blackjax/__init__.py

Lines changed: 140 additions & 48 deletions
@@ -1,71 +1,163 @@
+import dataclasses
+from typing import Callable
+
 from blackjax._version import __version__
 
 from .adaptation.chees_adaptation import chees_adaptation
 from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
 from .adaptation.meads_adaptation import meads_adaptation
 from .adaptation.pathfinder_adaptation import pathfinder_adaptation
 from .adaptation.window_adaptation import window_adaptation
+from .base import SamplingAlgorithm, VIAlgorithm
 from .diagnostics import effective_sample_size as ess
 from .diagnostics import potential_scale_reduction as rhat
-from .mcmc.barker import barker_proposal
-from .mcmc.dynamic_hmc import dynamic_hmc
-from .mcmc.elliptical_slice import elliptical_slice
-from .mcmc.ghmc import ghmc
-from .mcmc.hmc import hmc
-from .mcmc.mala import mala
-from .mcmc.marginal_latent_gaussian import mgrad_gaussian
-from .mcmc.mclmc import mclmc
-from .mcmc.nuts import nuts
-from .mcmc.periodic_orbital import orbital_hmc
-from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
-from .mcmc.rmhmc import rmhmc
+from .mcmc import barker
+from .mcmc import dynamic_hmc as _dynamic_hmc
+from .mcmc import elliptical_slice as _elliptical_slice
+from .mcmc import ghmc as _ghmc
+from .mcmc import hmc as _hmc
+from .mcmc import mala as _mala
+from .mcmc import marginal_latent_gaussian
+from .mcmc import mclmc as _mclmc
+from .mcmc import nuts as _nuts
+from .mcmc import periodic_orbital, random_walk
+from .mcmc import rmhmc as _rmhmc
+from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk
+from .mcmc.random_walk import (
+    irmh_as_top_level_api,
+    normal_random_walk,
+    rmh_as_top_level_api,
+)
 from .optimizers import dual_averaging, lbfgs
-from .sgmcmc.csgld import csgld
-from .sgmcmc.sghmc import sghmc
-from .sgmcmc.sgld import sgld
-from .sgmcmc.sgnht import sgnht
-from .smc.adaptive_tempered import adaptive_tempered_smc
-from .smc.inner_kernel_tuning import inner_kernel_tuning
-from .smc.tempered import tempered_smc
-from .vi.meanfield_vi import meanfield_vi
-from .vi.pathfinder import pathfinder
-from .vi.schrodinger_follmer import schrodinger_follmer
-from .vi.svgd import svgd
+from .sgmcmc import csgld as _csgld
+from .sgmcmc import sghmc as _sghmc
+from .sgmcmc import sgld as _sgld
+from .sgmcmc import sgnht as _sgnht
+from .smc import adaptive_tempered
+from .smc import inner_kernel_tuning as _inner_kernel_tuning
+from .smc import tempered
+from .vi import meanfield_vi as _meanfield_vi
+from .vi import pathfinder as _pathfinder
+from .vi import schrodinger_follmer as _schrodinger_follmer
+from .vi import svgd as _svgd
+from .vi.pathfinder import PathFinderAlgorithm
+
+"""
+The above three classes exist as a backwards compatible way of exposing both the high level, differentiable
+factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower
+level to be mostly functional programming in nature and reducing boilerplate code.
+"""
+
+
+@dataclasses.dataclass
+class GenerateSamplingAPI:
+    differentiable: Callable
+    init: Callable
+    build_kernel: Callable
+
+    def __call__(self, *args, **kwargs) -> SamplingAlgorithm:
+        return self.differentiable(*args, **kwargs)
+
+    def register_factory(self, name, callable):
+        setattr(self, name, callable)
+
+
+@dataclasses.dataclass
+class GenerateVariationalAPI:
+    differentiable: Callable
+    init: Callable
+    step: Callable
+    sample: Callable
+
+    def __call__(self, *args, **kwargs) -> VIAlgorithm:
+        return self.differentiable(*args, **kwargs)
+
+
+@dataclasses.dataclass
+class GeneratePathfinderAPI:
+    differentiable: Callable
+    approximate: Callable
+    sample: Callable
+
+    def __call__(self, *args, **kwargs) -> PathFinderAlgorithm:
+        return self.differentiable(*args, **kwargs)
+
+
+def generate_top_level_api_from(module):
+    return GenerateSamplingAPI(
+        module.as_top_level_api, module.init, module.build_kernel
+    )
+
+
+# MCMC
+hmc = generate_top_level_api_from(_hmc)
+nuts = generate_top_level_api_from(_nuts)
+rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh)
+irmh = GenerateSamplingAPI(
+    irmh_as_top_level_api, random_walk.init, random_walk.build_irmh
+)
+dynamic_hmc = generate_top_level_api_from(_dynamic_hmc)
+rmhmc = generate_top_level_api_from(_rmhmc)
+mala = generate_top_level_api_from(_mala)
+mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian)
+orbital_hmc = generate_top_level_api_from(periodic_orbital)
+
+additive_step_random_walk = GenerateSamplingAPI(
+    _additive_step_random_walk, random_walk.init, random_walk.build_additive_step
+)
+
+additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)
+
+mclmc = generate_top_level_api_from(_mclmc)
+elliptical_slice = generate_top_level_api_from(_elliptical_slice)
+ghmc = generate_top_level_api_from(_ghmc)
+barker_proposal = generate_top_level_api_from(barker)
+
+hmc_family = [hmc, nuts]
+
+# SMC
+adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
+tempered_smc = generate_top_level_api_from(tempered)
+inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)
+
+smc_family = [tempered_smc, adaptive_tempered_smc]
+"Step_fn returning state has a .particles attribute"
+
+# stochastic gradient mcmc
+sgld = generate_top_level_api_from(_sgld)
+sghmc = generate_top_level_api_from(_sghmc)
+sgnht = generate_top_level_api_from(_sgnht)
+csgld = generate_top_level_api_from(_csgld)
+svgd = generate_top_level_api_from(_svgd)
+
+# variational inference
+meanfield_vi = GenerateVariationalAPI(
+    _meanfield_vi.as_top_level_api,
+    _meanfield_vi.init,
+    _meanfield_vi.step,
+    _meanfield_vi.sample,
+)
+schrodinger_follmer = GenerateVariationalAPI(
+    _schrodinger_follmer.as_top_level_api,
+    _schrodinger_follmer.init,
+    _schrodinger_follmer.step,
+    _schrodinger_follmer.sample,
+)
+
+pathfinder = GeneratePathfinderAPI(
+    _pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample
+)
+
 
 __all__ = [
     "__version__",
     "dual_averaging",  # optimizers
     "lbfgs",
-    "hmc",  # mcmc
-    "dynamic_hmc",
-    "rmhmc",
-    "mala",
-    "mgrad_gaussian",
-    "nuts",
-    "orbital_hmc",
-    "additive_step_random_walk",
-    "rmh",
-    "irmh",
-    "mclmc",
-    "elliptical_slice",
-    "ghmc",
-    "barker_proposal",
-    "sgld",  # stochastic gradient mcmc
-    "sghmc",
-    "sgnht",
-    "csgld",
     "window_adaptation",  # mcmc adaptation
     "meads_adaptation",
     "chees_adaptation",
     "pathfinder_adaptation",
     "mclmc_find_L_and_step_size",  # mclmc adaptation
-    "adaptive_tempered_smc",  # smc
-    "tempered_smc",
-    "inner_kernel_tuning",
-    "meanfield_vi",  # variational inference
-    "pathfinder",
-    "schrodinger_follmer",
-    "svgd",
     "ess",  # diagnostics
     "rhat",
 ]
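
A minimal usage sketch (illustrative, not part of the diff): every top-level name built by generate_top_level_api_from is a GenerateSamplingAPI instance, so it stays callable with the old constructor signature while also exposing the module's low-level init and build_kernel. The target density and parameter values below are placeholders, and the hmc module's signatures are assumed unchanged by this commit.

    import jax
    import jax.numpy as jnp

    import blackjax

    def logdensity_fn(x):  # placeholder target density, not from the commit
        return -0.5 * jnp.sum(x**2)

    # High-level, differentiable path: same call as the former `hmc` class.
    sampler = blackjax.hmc(
        logdensity_fn,
        step_size=1e-2,
        inverse_mass_matrix=jnp.ones(2),
        num_integration_steps=10,
    )
    state = sampler.init(jnp.zeros(2))
    state, info = sampler.step(jax.random.PRNGKey(0), state)

    # Low-level path: the same top-level name exposes the module's components.
    init_state = blackjax.hmc.init(jnp.zeros(2), logdensity_fn)
    kernel = blackjax.hmc.build_kernel()

The dataclass keeps both paths on one object: __call__ forwards to the differentiable as_top_level_api factory, while init and build_kernel remain plain functions.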

blackjax/adaptation/pathfinder_adaptation.py

Lines changed: 2 additions & 3 deletions
@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Implementation of the Pathinder warmup for the HMC family of sampling algorithms."""
-from typing import Callable, NamedTuple, Union
+from typing import Callable, NamedTuple
 
 import jax
 import jax.numpy as jnp
 
-import blackjax.mcmc as mcmc
 import blackjax.vi as vi
 from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
 from blackjax.adaptation.step_size import (
@@ -138,7 +137,7 @@ def final(warmup_state: PathfinderAdaptationState) -> tuple[float, Array]:
 
 
 def pathfinder_adaptation(
-    algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts],
+    algorithm,
     logdensity_fn: Callable,
     initial_step_size: float = 1.0,
     target_acceptance_rate: float = 0.80,
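
With the Union[mcmc.hmc.hmc, mcmc.nuts.nuts] annotation dropped, pathfinder_adaptation accepts the new top-level factory objects directly. A minimal sketch, assuming the usual run(rng_key, position, num_steps) adaptation interface and a placeholder target density:

    import jax
    import jax.numpy as jnp

    import blackjax

    def logdensity_fn(x):  # placeholder target density, not from the commit
        return -0.5 * jnp.sum(x**2)

    # Any member of blackjax.hmc_family can be passed as `algorithm`.
    warmup = blackjax.pathfinder_adaptation(blackjax.nuts, logdensity_fn)
    (last_state, parameters), _ = warmup.run(
        jax.random.PRNGKey(0), jnp.zeros(2), num_steps=400
    )

    # The adapted step size and inverse mass matrix feed back into the factory.
    nuts_sampler = blackjax.nuts(logdensity_fn, **parameters)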

blackjax/adaptation/window_adaptation.py

Lines changed: 3 additions & 4 deletions
@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Implementation of the Stan warmup for the HMC family of sampling algorithms."""
-from typing import Callable, NamedTuple, Union
+from typing import Callable, NamedTuple
 
 import jax
 import jax.numpy as jnp
 
-import blackjax.mcmc as mcmc
 from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
 from blackjax.adaptation.mass_matrix import (
     MassMatrixAdaptationState,
@@ -243,7 +242,7 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]:
 
 
 def window_adaptation(
-    algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts],
+    algorithm,
     logdensity_fn: Callable,
     is_mass_matrix_diagonal: bool = True,
     initial_step_size: float = 1.0,
@@ -252,7 +251,7 @@ def window_adaptation(
     **extra_parameters,
 ) -> AdaptationAlgorithm:
     """Adapt the value of the inverse mass matrix and step size parameters of
-    algorithms in the HMC fmaily.
+    algorithms in the HMC family. See Blackjax.hmc_family
 
     Algorithms in the HMC family on a euclidean manifold depend on the value of
     at least two parameters: the step size, related to the trajectory
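
The docstring now points to Blackjax.hmc_family: any of its members can be passed as algorithm, with algorithm-specific arguments forwarded through **extra_parameters. A minimal sketch under the same assumptions as above (placeholder target, run(rng_key, position, num_steps) interface):

    import jax
    import jax.numpy as jnp

    import blackjax

    def logdensity_fn(x):  # placeholder target density, not from the commit
        return -0.5 * jnp.sum(x**2)

    # hmc needs num_integration_steps; it travels through **extra_parameters.
    warmup = blackjax.window_adaptation(
        blackjax.hmc, logdensity_fn, num_integration_steps=10
    )
    (last_state, parameters), _ = warmup.run(
        jax.random.PRNGKey(1), jnp.zeros(2), num_steps=1000
    )
    hmc_sampler = blackjax.hmc(logdensity_fn, num_integration_steps=10, **parameters)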

blackjax/mcmc/barker.py

Lines changed: 12 additions & 17 deletions
@@ -24,7 +24,7 @@
 from blackjax.mcmc.proposal import static_binomial_sampling
 from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
 
-__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "barker_proposal"]
+__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"]
 
 
 class BarkerState(NamedTuple):
@@ -128,7 +128,10 @@ def kernel(
     return kernel
 
 
-class barker_proposal:
+def as_top_level_api(
+    logdensity_fn: Callable,
+    step_size: float,
+) -> SamplingAlgorithm:
     """Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a
     Gaussian base kernel.
 
@@ -179,24 +182,16 @@ class barker_proposal:
 
     """
 
-    init = staticmethod(init)
-    build_kernel = staticmethod(build_kernel)
+    kernel = build_kernel()
 
-    def __new__(  # type: ignore[misc]
-        cls,
-        logdensity_fn: Callable,
-        step_size: float,
-    ) -> SamplingAlgorithm:
-        kernel = cls.build_kernel()
+    def init_fn(position: ArrayLikeTree, rng_key=None):
+        del rng_key
+        return init(position, logdensity_fn)
 
-        def init_fn(position: ArrayLikeTree, rng_key=None):
-            del rng_key
-            return cls.init(position, logdensity_fn)
+    def step_fn(rng_key: PRNGKey, state):
+        return kernel(rng_key, state, logdensity_fn, step_size)
 
-        def step_fn(rng_key: PRNGKey, state):
-            return kernel(rng_key, state, logdensity_fn, step_size)
-
-        return SamplingAlgorithm(init_fn, step_fn)
+    return SamplingAlgorithm(init_fn, step_fn)
 
 
 def _barker_sample_nd(key, mean, a, scale):
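
The same pattern as the other kernels: the former barker_proposal class becomes a module-level as_top_level_api function, and the top-level blackjax.barker_proposal name is rebuilt from it by generate_top_level_api_from(barker). A usage sketch with a placeholder target density:

    import jax
    import jax.numpy as jnp

    import blackjax
    from blackjax.mcmc import barker

    def logdensity_fn(x):  # placeholder target density, not from the commit
        return -0.5 * jnp.sum(x**2)

    # Module-level factory introduced by this commit ...
    algo = barker.as_top_level_api(logdensity_fn, step_size=0.5)
    # ... and the unchanged top-level entry point built from it in blackjax/__init__.py.
    algo = blackjax.barker_proposal(logdensity_fn, step_size=0.5)

    state = algo.init(jnp.ones(2))
    state, info = algo.step(jax.random.PRNGKey(2), state)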
