@@ -668,88 +668,50 @@ def check_selfconsistency_discrete_logcdf(
668668 )
669669
670670
671- def check_selfconsistency_continuous_icdf (
671+ def check_selfconsistency_icdf (
672672 distribution : Distribution ,
673- paramdomains : Dict [str , Domain ],
674- decimal : Optional [int ] = None ,
673+ paramdomains : dict [str , Domain ],
674+ * ,
675+ decimal : int | None = None ,
675676 n_samples : int = 100 ,
676677) -> None :
677- """
678- Check that the icdf and logcdf functions of the distribution are consistent for a sample of probability values.
679- """
680- if decimal is None :
681- decimal = select_by_precision (float64 = 6 , float32 = 3 )
682-
683- dist = create_dist_from_paramdomains (distribution , paramdomains )
684- value = dist .type ()
685- value .name = "value"
686-
687- dist_icdf = icdf (dist , value )
688- dist_icdf_fn = pytensor .function (list (inputvars (dist_icdf )), dist_icdf )
689-
690- dist_logcdf = logcdf (dist , value )
691- dist_logcdf_fn = compile_pymc (list (inputvars (dist_logcdf )), dist_logcdf )
692-
693- domains = paramdomains .copy ()
694- domains ["value" ] = Domain (np .linspace (0 , 1 , 10 ))
695-
696- for point in product (domains , n_samples = n_samples ):
697- point = dict (point )
698- value = point .pop ("value" )
699-
700- with pytensor .config .change_flags (mode = Mode ("py" )):
701- npt .assert_almost_equal (
702- value ,
703- np .exp (dist_logcdf_fn (** point , value = dist_icdf_fn (** point , value = value ))),
704- decimal = decimal ,
705- err_msg = f"point: { point } , value: { value } " ,
706- )
678+ """Check that the icdf and logcdf functions of the distribution are consistent.
707679
708-
709- def check_selfconsistency_discrete_icdf (
710- distribution : Distribution ,
711- domain : Domain ,
712- paramdomains : Dict [str , Domain ],
713- decimal : Optional [int ] = None ,
714- n_samples : int = 100 ,
715- ) -> None :
716- """
717- Check that the icdf and logcdf functions of the distribution are
718- consistent for a sample of values in the domain of the
719- distribution.
680+ Only works with continuous distributions.
720681 """
721-
722- def ftrunc (values , decimal = 0 ):
723- return np .trunc (values * 10 ** decimal ) / (10 ** decimal )
724-
725682 if decimal is None :
726683 decimal = select_by_precision (float64 = 6 , float32 = 3 )
727684
728685 dist = create_dist_from_paramdomains (distribution , paramdomains )
729-
730- value = pt .TensorType (dtype = "float64" , shape = [])("value" )
731-
686+ if dist .type .dtype .startswith ("int" ):
687+ raise NotImplementedError (
688+ "check_selfconsistency_icdf is not robust against discrete distributions."
689+ )
690+ value = dist .astype ("float64" ).type ("value" )
732691 dist_icdf = icdf (dist , value )
733- dist_icdf_fn = pytensor . function ( list ( inputvars ( dist_icdf )), dist_icdf )
692+ dist_cdf = pt . exp ( logcdf ( dist , value ) )
734693
735- dist_logcdf = logcdf (dist , value )
736- dist_logcdf_fn = compile_pymc (list (inputvars (dist_logcdf )), dist_logcdf )
694+ py_mode = Mode ("py" )
695+ dist_icdf_fn = pytensor .function (list (inputvars (dist_icdf )), dist_icdf , mode = py_mode )
696+ dist_cdf_fn = compile (list (inputvars (dist_cdf )), dist_cdf , mode = py_mode )
737697
738698 domains = paramdomains .copy ()
739- domains ["value" ] = domain
699+ domains ["value" ] = Domain ( np . linspace ( 0 , 1 , 10 ))
740700
741701 for point in product (domains , n_samples = n_samples ):
742702 point = dict (point )
743703 value = point .pop ("value" )
744-
745- with pytensor .config .change_flags (mode = Mode ("py" )):
746- expected_value = value
747- computed_value = dist_icdf_fn (
748- ** point , value = ftrunc (np .exp (dist_logcdf_fn (** point , value = value )), decimal = decimal )
749- )
750- assert (
751- expected_value == computed_value
752- ), f"expected_value = { expected_value } , computed_value = { computed_value } , { point } "
704+ icdf_value = dist_icdf_fn (** point , value = value )
705+ recovered_value = dist_cdf_fn (
706+ ** point ,
707+ value = icdf_value ,
708+ )
709+ np .testing .assert_almost_equal (
710+ value ,
711+ recovered_value ,
712+ decimal = decimal ,
713+ err_msg = f"point: { point } " ,
714+ )
753715
754716
755717def assert_support_point_is_expected (model , expected , check_finite_logp = True ):
0 commit comments