@@ -107,14 +107,16 @@ def test_data(inline_views):
107107 with pm .Model (coords_mutable = {"test_dim" : range (3 )}) as m_old :
108108 x = pm .MutableData ("x" , [0.0 , 1.0 , 2.0 ], dims = ("test_dim" ,))
109109 y = pm .MutableData ("y" , [10.0 , 11.0 , 12.0 ], dims = ("test_dim" ,))
110+ sigma = pm .MutableData ("sigma" , [1.0 ], shape = (1 ,))
110111 b0 = pm .ConstantData ("b0" , np .zeros ((1 ,)))
111112 b1 = pm .DiracDelta ("b1" , 1.0 )
112113 mu = pm .Deterministic ("mu" , b0 + b1 * x , dims = ("test_dim" ,))
113- obs = pm .Normal ("obs" , mu , sigma = 1e-5 , observed = y , dims = ("test_dim" ,))
114+ obs = pm .Normal ("obs" , mu = mu , sigma = sigma , observed = y , dims = ("test_dim" ,))
114115
115116 m_fgraph , memo = fgraph_from_model (m_old , inlined_views = inline_views )
116117 assert isinstance (memo [x ].owner .op , ModelNamed )
117118 assert isinstance (memo [y ].owner .op , ModelNamed )
119+ assert isinstance (memo [sigma ].owner .op , ModelNamed )
118120 assert isinstance (memo [b0 ].owner .op , ModelNamed )
119121 mu_inp = memo [mu ].owner .inputs [0 ]
120122 obs = memo [obs ]
@@ -124,10 +126,13 @@ def test_data(inline_views):
124126 assert mu_inp .owner .inputs [1 ].owner .inputs [1 ] is memo [x ].owner .inputs [0 ]
125127 # ObservedRV(obs, y, *dims) not ObservedRV(obs, Named(y), *dims)
126128 assert obs .owner .inputs [1 ] is memo [y ].owner .inputs [0 ]
129+ # ObservedRV(Normal(..., sigma), ...) not ObservedRV(Normal(..., Named(sigma)), ...)
130+ assert obs .owner .inputs [0 ].owner .inputs [4 ] is memo [sigma ].owner .inputs [0 ]
127131 else :
128132 assert mu_inp .owner .inputs [0 ] is memo [b0 ]
129133 assert mu_inp .owner .inputs [1 ].owner .inputs [1 ] is memo [x ]
130134 assert obs .owner .inputs [1 ] is memo [y ]
135+ assert obs .owner .inputs [0 ].owner .inputs [4 ] is memo [sigma ]
131136
132137 m_new = model_from_fgraph (m_fgraph )
133138
@@ -140,9 +145,17 @@ def test_data(inline_views):
140145 # Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory
141146 assert not same_storage (m_new ["x" ], x )
142147 assert not same_storage (m_new ["y" ], y )
148+ assert not same_storage (m_new ["sigma" ], sigma )
143149 assert not same_storage (m_new ["b1" ].owner .inputs [0 ], b1 .owner .inputs [0 ])
144150 assert not same_storage (m_new .dim_lengths ["test_dim" ], m_old .dim_lengths ["test_dim" ])
145151
152+ # Check they have the same type
153+ assert m_new ["x" ].type == x .type
154+ assert m_new ["y" ].type == y .type
155+ assert m_new ["sigma" ].type == sigma .type
156+ assert m_new ["b1" ].owner .inputs [0 ].type == b1 .owner .inputs [0 ].type
157+ assert m_new .dim_lengths ["test_dim" ].type == m_old .dim_lengths ["test_dim" ].type
158+
146159 # Updating model shared variables in new model, doesn't affect old one
147160 with m_new :
148161 pm .set_data ({"x" : [100.0 , 200.0 ]}, coords = {"test_dim" : range (2 )})
@@ -155,22 +168,31 @@ def test_data(inline_views):
155168@config .change_flags (floatX = "float64" ) # Avoid downcasting Ops in the graph
156169def test_shared_variable ():
157170 """Test that user defined shared variables (other than RNGs) aren't copied."""
158- x = shared (np .array ([1 , 2 , 3.0 ]), name = "x" )
159- y = shared (np .array ([1 , 2 , 3.0 ]), name = "y" )
171+ mu = shared (np .array ([1 , 2 , 3.0 ]), shape = (None ,), name = "mu" )
172+ sigma = shared (np .array ([1.0 ]), shape = (1 ,), name = "sigma" )
173+ obs = shared (np .array ([1 , 2 , 3.0 ]), shape = (3 ,), name = "obs" )
160174
161175 with pm .Model () as m_old :
162- test = pm .Normal ("test" , mu = x , observed = y )
176+ test = pm .Normal ("test" , mu = mu , sigma = sigma , observed = obs )
163177
164- assert test .owner .inputs [3 ] is x
165- assert m_old .rvs_to_values [test ] is y
178+ assert test .owner .inputs [3 ] is mu
179+ assert test .owner .inputs [4 ] is sigma
180+ assert m_old .rvs_to_values [test ] is obs
166181
167182 m_new = clone_model (m_old )
168183 test_new = m_new ["test" ]
169184 # Shared Variables are cloned but still point to the same memory
170- assert test_new .owner .inputs [3 ] is not x
171- assert m_new .rvs_to_values [test_new ] is not y
172- assert same_storage (test_new .owner .inputs [3 ], x )
173- assert same_storage (m_new .rvs_to_values [test_new ], y )
185+ mu_new , sigma_new = test_new .owner .inputs [3 :5 ]
186+ obs_new = m_new .rvs_to_values [test_new ]
187+ assert mu_new is not mu
188+ assert sigma_new is not sigma
189+ assert obs_new is not obs
190+ assert mu_new .type == mu .type
191+ assert sigma_new .type == sigma .type
192+ assert obs_new .type == obs .type
193+ assert same_storage (mu , mu_new )
194+ assert same_storage (sigma , sigma_new )
195+ assert same_storage (obs , obs_new )
174196
175197
176198@pytest .mark .parametrize ("inline_views" , (False , True ))
0 commit comments