55import os
66from enum import Enum , auto
77from time import time
8- from typing import Any , Mapping , Optional , Union
8+ from typing import Any , Mapping
99
1010import numpy as np
1111from numpy .random import default_rng
1212
1313from cmdstanpy .utils import cmdstan_path , cmdstan_version_before , get_logger
1414
15- OptionalPath = Union [ str , os .PathLike , None ]
15+ OptionalPath = str | os .PathLike | None
1616
1717
1818class Method (Enum ):
@@ -52,19 +52,19 @@ class SamplerArgs:
5252
5353 def __init__ (
5454 self ,
55- iter_warmup : Optional [ int ] = None ,
56- iter_sampling : Optional [ int ] = None ,
55+ iter_warmup : int | None = None ,
56+ iter_sampling : int | None = None ,
5757 save_warmup : bool = False ,
58- thin : Optional [ int ] = None ,
59- max_treedepth : Optional [ int ] = None ,
60- metric_type : Optional [ str ] = None ,
61- metric_file : Union [ str , list [str ], None ] = None ,
62- step_size : Union [ float , list [float ], None ] = None ,
58+ thin : int | None = None ,
59+ max_treedepth : int | None = None ,
60+ metric_type : str | None = None ,
61+ metric_file : str | list [str ] | None = None ,
62+ step_size : float | list [float ] | None = None ,
6363 adapt_engaged : bool = True ,
64- adapt_delta : Optional [ float ] = None ,
65- adapt_init_phase : Optional [ int ] = None ,
66- adapt_metric_window : Optional [ int ] = None ,
67- adapt_step_size : Optional [ int ] = None ,
64+ adapt_delta : float | None = None ,
65+ adapt_init_phase : int | None = None ,
66+ adapt_metric_window : int | None = None ,
67+ adapt_step_size : int | None = None ,
6868 fixed_param : bool = False ,
6969 num_chains : int = 1 ,
7070 ) -> None :
@@ -74,8 +74,8 @@ def __init__(
7474 self .save_warmup = save_warmup
7575 self .thin = thin
7676 self .max_treedepth = max_treedepth
77- self .metric_type : Optional [ str ] = metric_type
78- self .metric_file : Union [ str , list [str ], None ] = metric_file
77+ self .metric_type : str | None = metric_type
78+ self .metric_file : str | list [str ] | None = metric_file
7979 self .step_size = step_size
8080 self .adapt_engaged = adapt_engaged
8181 self .adapt_delta = adapt_delta
@@ -86,7 +86,7 @@ def __init__(
8686 self .diagnostic_file = None
8787 self .num_chains = num_chains
8888
89- def validate (self , chains : Optional [ int ] ) -> None :
89+ def validate (self , chains : int | None ) -> None :
9090 """
9191 Check arguments correctness and consistency.
9292
@@ -295,16 +295,16 @@ class OptimizeArgs:
295295
296296 def __init__ (
297297 self ,
298- algorithm : Optional [ str ] = None ,
299- init_alpha : Optional [ float ] = None ,
300- iter : Optional [ int ] = None ,
298+ algorithm : str | None = None ,
299+ init_alpha : float | None = None ,
300+ iter : int | None = None ,
301301 save_iterations : bool = False ,
302- tol_obj : Optional [ float ] = None ,
303- tol_rel_obj : Optional [ float ] = None ,
304- tol_grad : Optional [ float ] = None ,
305- tol_rel_grad : Optional [ float ] = None ,
306- tol_param : Optional [ float ] = None ,
307- history_size : Optional [ int ] = None ,
302+ tol_obj : float | None = None ,
303+ tol_rel_obj : float | None = None ,
304+ tol_grad : float | None = None ,
305+ tol_rel_grad : float | None = None ,
306+ tol_param : float | None = None ,
307+ history_size : int | None = None ,
308308 jacobian : bool = False ,
309309 ) -> None :
310310 self .algorithm = algorithm or ""
@@ -319,7 +319,7 @@ def __init__(
319319 self .history_size = history_size
320320 self .jacobian = jacobian
321321
322- def validate (self , _chains : Optional [ int ] = None ) -> None :
322+ def validate (self , _chains : int | None = None ) -> None :
323323 """
324324 Check arguments correctness and consistency.
325325 """
@@ -383,13 +383,13 @@ class LaplaceArgs:
383383 """Arguments needed for laplace method."""
384384
385385 def __init__ (
386- self , mode : str , draws : Optional [ int ] = None , jacobian : bool = True
386+ self , mode : str , draws : int | None = None , jacobian : bool = True
387387 ) -> None :
388388 self .mode = mode
389389 self .jacobian = jacobian
390390 self .draws = draws
391391
392- def validate (self , _chains : Optional [ int ] = None ) -> None :
392+ def validate (self , _chains : int | None = None ) -> None :
393393 """Check arguments correctness and consistency."""
394394 if not os .path .exists (self .mode ):
395395 raise ValueError (f'Invalid path for mode file: { self .mode } ' )
@@ -411,18 +411,18 @@ class PathfinderArgs:
411411
412412 def __init__ (
413413 self ,
414- init_alpha : Optional [ float ] = None ,
415- tol_obj : Optional [ float ] = None ,
416- tol_rel_obj : Optional [ float ] = None ,
417- tol_grad : Optional [ float ] = None ,
418- tol_rel_grad : Optional [ float ] = None ,
419- tol_param : Optional [ float ] = None ,
420- history_size : Optional [ int ] = None ,
421- num_psis_draws : Optional [ int ] = None ,
422- num_paths : Optional [ int ] = None ,
423- max_lbfgs_iters : Optional [ int ] = None ,
424- num_draws : Optional [ int ] = None ,
425- num_elbo_draws : Optional [ int ] = None ,
414+ init_alpha : float | None = None ,
415+ tol_obj : float | None = None ,
416+ tol_rel_obj : float | None = None ,
417+ tol_grad : float | None = None ,
418+ tol_rel_grad : float | None = None ,
419+ tol_param : float | None = None ,
420+ history_size : int | None = None ,
421+ num_psis_draws : int | None = None ,
422+ num_paths : int | None = None ,
423+ max_lbfgs_iters : int | None = None ,
424+ num_draws : int | None = None ,
425+ num_elbo_draws : int | None = None ,
426426 save_single_paths : bool = False ,
427427 psis_resample : bool = True ,
428428 calculate_lp : bool = True ,
@@ -445,7 +445,7 @@ def __init__(
445445 self .psis_resample = psis_resample
446446 self .calculate_lp = calculate_lp
447447
448- def validate (self , _chains : Optional [ int ] = None ) -> None :
448+ def validate (self , _chains : int | None = None ) -> None :
449449 """
450450 Check arguments correctness and consistency.
451451 """
@@ -514,7 +514,7 @@ def __init__(self, csv_files: list[str]) -> None:
514514
515515 def validate (
516516 self ,
517- chains : Optional [ int ] = None , # pylint: disable=unused-argument
517+ chains : int | None = None , # pylint: disable=unused-argument
518518 ) -> None :
519519 """
520520 Check arguments correctness and consistency.
@@ -543,16 +543,16 @@ class VariationalArgs:
543543
544544 def __init__ (
545545 self ,
546- algorithm : Optional [ str ] = None ,
547- iter : Optional [ int ] = None ,
548- grad_samples : Optional [ int ] = None ,
549- elbo_samples : Optional [ int ] = None ,
550- eta : Optional [ float ] = None ,
551- adapt_iter : Optional [ int ] = None ,
546+ algorithm : str | None = None ,
547+ iter : int | None = None ,
548+ grad_samples : int | None = None ,
549+ elbo_samples : int | None = None ,
550+ eta : float | None = None ,
551+ adapt_iter : int | None = None ,
552552 adapt_engaged : bool = True ,
553- tol_rel_obj : Optional [ float ] = None ,
554- eval_elbo : Optional [ int ] = None ,
555- output_samples : Optional [ int ] = None ,
553+ tol_rel_obj : float | None = None ,
554+ eval_elbo : int | None = None ,
555+ output_samples : int | None = None ,
556556 ) -> None :
557557 self .algorithm = algorithm
558558 self .iter = iter
@@ -567,7 +567,7 @@ def __init__(
567567
568568 def validate (
569569 self ,
570- chains : Optional [ int ] = None , # pylint: disable=unused-argument
570+ chains : int | None = None , # pylint: disable=unused-argument
571571 ) -> None :
572572 """
573573 Check arguments correctness and consistency.
@@ -633,23 +633,23 @@ def __init__(
633633 self ,
634634 model_name : str ,
635635 model_exe : str ,
636- chain_ids : Optional [ list [int ]] ,
637- method_args : Union [
638- SamplerArgs ,
639- OptimizeArgs ,
640- GenerateQuantitiesArgs ,
641- VariationalArgs ,
642- LaplaceArgs ,
643- PathfinderArgs ,
644- ] ,
645- data : Union [ Mapping [str , Any ], str , None ] = None ,
646- seed : Union [ int , np .integer , list [int ], list [np .integer ], None ] = None ,
647- inits : Union [ int , float , str , list [str ], None ] = None ,
636+ chain_ids : list [int ] | None ,
637+ method_args : (
638+ SamplerArgs
639+ | OptimizeArgs
640+ | GenerateQuantitiesArgs
641+ | VariationalArgs
642+ | LaplaceArgs
643+ | PathfinderArgs
644+ ) ,
645+ data : Mapping [str , Any ] | str | None = None ,
646+ seed : int | np .integer | list [int ] | list [np .integer ] | None = None ,
647+ inits : int | float | str | list [str ] | None = None ,
648648 output_dir : OptionalPath = None ,
649- sig_figs : Optional [ int ] = None ,
649+ sig_figs : int | None = None ,
650650 save_latent_dynamics : bool = False ,
651651 save_profile : bool = False ,
652- refresh : Optional [ int ] = None ,
652+ refresh : int | None = None ,
653653 ) -> None :
654654 """Initialize object."""
655655 self .model_name = model_name
@@ -839,8 +839,8 @@ def compose_command(
839839 idx : int ,
840840 csv_file : str ,
841841 * ,
842- diagnostic_file : Optional [ str ] = None ,
843- profile_file : Optional [ str ] = None ,
842+ diagnostic_file : str | None = None ,
843+ profile_file : str | None = None ,
844844 ) -> list [str ]:
845845 """
846846 Compose CmdStan command for non-default arguments.
0 commit comments