@@ -1382,10 +1382,6 @@ def _rebuild_group_mappings(group, model):
13821382class TraceSpec :
13831383 sample_vars : list
13841384 test_point : collections .OrderedDict
1385- computed_var_names : list [str ]
1386- input_vars : list
1387- compute_fn : Any
1388- value_var_names : list [str ]
13891385
13901386
13911387class Approximation (WithMemoization ):
@@ -1545,7 +1541,7 @@ def _collect_sample_vars(self, model, sample_names):
15451541
15461542 def _compute_missing_trace_values (self , model , samples , missing_vars ):
15471543 if not missing_vars :
1548- return {}, [], [], None
1544+ return {}
15491545 input_vars = model .value_vars
15501546 base_point = model .initial_point ()
15511547 point = {
@@ -1565,7 +1561,7 @@ def _compute_missing_trace_values(self, model, samples, missing_vars):
15651561 if not isinstance (raw_values , list | tuple ):
15661562 raw_values = [raw_values ]
15671563 values = {var .name : np .asarray (value ) for var , value in zip (missing_vars , raw_values )}
1568- return values , [ var . name for var in missing_vars ], list ( input_vars ), compute_fn
1564+ return values
15691565
15701566 def _build_trace_spec (self , model , samples ):
15711567 sample_names = sorted (samples .keys ())
@@ -1586,43 +1582,15 @@ def _build_trace_spec(self, model, samples):
15861582 continue
15871583 missing_vars .append (var )
15881584
1589- values , computed_var_names , input_vars , compute_fn = self ._compute_missing_trace_values (
1590- model , samples , missing_vars
1591- )
1585+ values = self ._compute_missing_trace_values (model , samples , missing_vars )
15921586 for name , value in values .items ():
15931587 test_point [name ] = value
15941588
15951589 return TraceSpec (
15961590 sample_vars = sample_vars ,
15971591 test_point = test_point ,
1598- computed_var_names = computed_var_names ,
1599- input_vars = input_vars ,
1600- compute_fn = compute_fn ,
1601- value_var_names = [var .name for var in model .value_vars ],
16021592 )
16031593
1604- def _augment_samples_with_computed (self , model , samples , spec , draws ):
1605- if not spec .computed_var_names :
1606- return
1607-
1608- computed = {name : [] for name in spec .computed_var_names }
1609- input_names = [var .name for var in spec .input_vars ]
1610- for i in range (draws ):
1611- inputs = {}
1612- for name in input_names :
1613- if name in samples :
1614- inputs [name ] = samples [name ][i ]
1615- else :
1616- inputs [name ] = spec .test_point [name ]
1617- outputs = spec .compute_fn (inputs )
1618- if not isinstance (outputs , list | tuple ):
1619- outputs = [outputs ]
1620- for name , value in zip (spec .computed_var_names , outputs ):
1621- computed [name ].append (np .asarray (value ))
1622-
1623- for name , values in computed .items ():
1624- samples [name ] = np .stack (values )
1625-
16261594 inputs = property (lambda self : self .collect ("input" ))
16271595 symbolic_randoms = property (lambda self : self .collect ("symbolic_random" ))
16281596
@@ -1958,13 +1926,11 @@ def sample(
19581926 (random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
19591927 samples : dict = self .sample_dict_fn (draws , model = model , random_seed = random_seed )
19601928 spec = self ._build_trace_spec (model , samples )
1961- self ._augment_samples_with_computed (model , samples , spec , draws )
1962- if spec .computed_var_names :
1963- spec = self ._build_trace_spec (model , samples )
19641929
19651930 from collections import OrderedDict
19661931
19671932 default_point = model .initial_point ()
1933+ value_var_names = [var .name for var in model .value_vars ]
19681934 points = (
19691935 OrderedDict (
19701936 (
@@ -1973,7 +1939,7 @@ def sample(
19731939 if name in samples and len (samples [name ]) > i
19741940 else np .asarray (spec .test_point .get (name , default_point [name ])),
19751941 )
1976- for name in spec . value_var_names
1942+ for name in value_var_names
19771943 )
19781944 for i in range (draws )
19791945 )
0 commit comments