Skip to content

Commit 3c1c0d6

Browse files
committed
add getparams and setparams!!
1 parent 4442783 commit 3c1c0d6

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.8.3"
3+
version = "0.8.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -23,7 +23,7 @@ AdvancedMHMCMCChainsExt = "MCMCChains"
2323
AdvancedMHStructArraysExt = "StructArrays"
2424

2525
[compat]
26-
AbstractMCMC = "5"
26+
AbstractMCMC = "5.5"
2727
DiffResults = "1"
2828
Distributions = "0.25"
2929
FillArrays = "1"

src/AdvancedMH.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module AdvancedMH
22

33
# Import the relevant libraries.
44
using AbstractMCMC
5+
using AbstractMCMC: BangBang
56
using Distributions
67
using LinearAlgebra: I
78
using FillArrays: Zeros
@@ -140,6 +141,15 @@ function __init__()
140141
end
141142
end
142143

144+
# AbstractMCMC.jl interface
145+
function AbstractMCMC.getparams(t::Transition)
146+
return t.params
147+
end
148+
149+
function AbstractMCMC.setparams!!(t::Transition, params)
150+
return BangBang.setproperty!!(t, :params, params)
151+
end
152+
143153
# Include inference methods.
144154
include("proposal.jl")
145155
include("mh-core.jl")

test/runtests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using AdvancedMH
2+
using AbstractMCMC
23
using DiffResults
34
using Distributions
45
using ForwardDiff
@@ -33,6 +34,15 @@ include("util.jl")
3334
LogDensityProblems.logdensity(::typeof(density), θ) = density(θ)
3435
LogDensityProblems.dimension(::typeof(density)) = 2
3536

37+
@testset "getparams/setparams!! (AbstractMCMC interface)" begin
38+
test_spl = StaticMH([Normal(0, 1), Normal(0, 1)])
39+
t, _ = AbstractMCMC.step(Random.default_rng(), model, test_spl)
40+
@test AbstractMCMC.getparams(t) == t.params
41+
@test AbstractMCMC.setparams!!(t, AbstractMCMC.getparams(t)) == t
42+
t_replaced = AbstractMCMC.setparams!!(t, (μ=1.0, σ=2.0))
43+
@test t_replaced.params ===1.0, σ=2.0)
44+
end
45+
3646
@testset "StaticMH" begin
3747
# Set up our sampler with initial parameters.
3848
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])

0 commit comments

Comments
 (0)