@@ -302,7 +302,19 @@ def logp(value, mu, cov):
302302 )
303303
304304
305- class PrecisionMvNormalRV (SymbolicRandomVariable ):
305+ class SymbolicMVNormalUsedInternally (SymbolicRandomVariable ):
306+ """Helper subclass that handles the forwarding / caching of method to `MvNormal` used internally."""
307+
308+ def __init__ (self , * args , method : str , ** kwargs ):
309+ super ().__init__ (* args , ** kwargs )
310+ self .method = method
311+
312+ def rebuild_rv (self , * args , ** kwargs ):
313+ # rv_op is a classmethod, so it doesn't have access to the instance method
314+ return self .rv_op (* args , method = self .method , ** kwargs )
315+
316+
317+ class PrecisionMvNormalRV (SymbolicMVNormalUsedInternally ):
306318 r"""A specialized multivariate normal random variable defined in terms of precision.
307319
308320 This class is introduced during specialization logprob rewrites, and not meant to be used directly.
@@ -313,14 +325,17 @@ class PrecisionMvNormalRV(SymbolicRandomVariable):
313325 _print_name = ("PrecisionMultivariateNormal" , "\\ operatorname{PrecisionMultivariateNormal}" )
314326
315327 @classmethod
316- def rv_op (cls , mean , tau , * , rng = None , size = None ):
328+ def rv_op (cls , mean , tau , * , method : str = "cholesky" , rng = None , size = None ):
317329 rng = normalize_rng_param (rng )
318330 size = normalize_size_param (size )
319331 cov = pt .linalg .inv (tau )
320- next_rng , draws = multivariate_normal (mean , cov , size = size , rng = rng ).owner .outputs
332+ next_rng , draws = multivariate_normal (
333+ mean , cov , size = size , rng = rng , method = method
334+ ).owner .outputs
321335 return cls (
322336 inputs = [rng , size , mean , tau ],
323337 outputs = [next_rng , draws ],
338+ method = method ,
324339 )(rng , size , mean , tau )
325340
326341
@@ -354,7 +369,9 @@ def mv_normal_to_precision_mv_normal(fgraph, node):
354369 rng , size , mu , cov = node .inputs
355370 if cov .owner and cov .owner .op == matrix_inverse :
356371 tau = cov .owner .inputs [0 ]
357- return PrecisionMvNormalRV .rv_op (mu , tau , size = size , rng = rng ).owner .outputs
372+ return PrecisionMvNormalRV .rv_op (
373+ mu , tau , size = size , rng = rng , method = node .op .method
374+ ).owner .outputs
358375 return None
359376
360377
@@ -365,7 +382,7 @@ def mv_normal_to_precision_mv_normal(fgraph, node):
365382)
366383
367384
368- class MvStudentTRV (SymbolicRandomVariable ):
385+ class MvStudentTRV (SymbolicMVNormalUsedInternally ):
369386 r"""A specialized multivariate normal random variable defined in terms of precision.
370387
371388 This class is introduced during specialization logprob rewrites, and not meant to be used directly.
@@ -376,7 +393,7 @@ class MvStudentTRV(SymbolicRandomVariable):
376393 _print_name = ("MvStudentT" , "\\ operatorname{MvStudentT}" )
377394
378395 @classmethod
379- def rv_op (cls , nu , mean , scale , * , rng = None , size = None ):
396+ def rv_op (cls , nu , mean , scale , * , method : str = "cholesky" , rng = None , size = None ):
380397 nu = pt .as_tensor (nu )
381398 mean = pt .as_tensor (mean )
382399 scale = pt .as_tensor (scale )
@@ -387,14 +404,15 @@ def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
387404 size = implicit_size_from_params (nu , mean , scale , ndims_params = cls .ndims_params )
388405
389406 next_rng , mv_draws = multivariate_normal (
390- mean .zeros_like (), scale , size = size , rng = rng
407+ mean .zeros_like (), scale , size = size , rng = rng , method = method
391408 ).owner .outputs
392409 next_rng , chi2_draws = chisquare (nu , size = size , rng = next_rng ).owner .outputs
393410 draws = mean + (mv_draws / pt .sqrt (chi2_draws / nu )[..., None ])
394411
395412 return cls (
396413 inputs = [rng , size , nu , mean , scale ],
397414 outputs = [next_rng , draws ],
415+ method = method ,
398416 )(rng , size , nu , mean , scale )
399417
400418
@@ -1923,12 +1941,12 @@ def logp(value, mu, rowchol, colchol):
19231941 return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
19241942
19251943
1926- class KroneckerNormalRV (SymbolicRandomVariable ):
1944+ class KroneckerNormalRV (SymbolicMVNormalUsedInternally ):
19271945 ndim_supp = 1
19281946 _print_name = ("KroneckerNormal" , "\\ operatorname{KroneckerNormal}" )
19291947
19301948 @classmethod
1931- def rv_op (cls , mu , sigma , * covs , size = None , rng = None ):
1949+ def rv_op (cls , mu , sigma , * covs , method : str = "cholesky" , size = None , rng = None ):
19321950 mu = pt .as_tensor (mu )
19331951 sigma = pt .as_tensor (sigma )
19341952 covs = [pt .as_tensor (cov ) for cov in covs ]
@@ -1937,7 +1955,9 @@ def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
19371955
19381956 cov = reduce (pt .linalg .kron , covs )
19391957 cov = cov + sigma ** 2 * pt .eye (cov .shape [- 2 ])
1940- next_rng , draws = multivariate_normal (mean = mu , cov = cov , size = size , rng = rng ).owner .outputs
1958+ next_rng , draws = multivariate_normal (
1959+ mean = mu , cov = cov , size = size , rng = rng , method = method
1960+ ).owner .outputs
19411961
19421962 covs_sig = "," .join (f"(a{ i } ,b{ i } )" for i in range (len (covs )))
19431963 extended_signature = f"[rng],[size],(m),(),{ covs_sig } ->[rng],(m)"
@@ -1946,6 +1966,7 @@ def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
19461966 inputs = [rng , size , mu , sigma , * covs ],
19471967 outputs = [next_rng , draws ],
19481968 extended_signature = extended_signature ,
1969+ method = method ,
19491970 )(rng , size , mu , sigma , * covs )
19501971
19511972
0 commit comments