@@ -80,6 +80,48 @@ def args_maker():
8080 tol = 1e-3 )
8181 self ._CompileAndCheck (lax_fun , args_maker )
8282
83+ @genNamedParametersNArgs (2 )
84+ def testWrappedCauchyPdf (self , shapes , dtypes ):
85+ rng = jtu .rand_default (self .rng ())
86+ rng_uniform = jtu .rand_uniform (self .rng (), low = 1e-3 , high = 1 - 1e-3 )
87+ scipy_fun = osp_stats .wrapcauchy .pdf
88+ lax_fun = lsp_stats .wrapcauchy .pdf
89+
90+ def args_maker ():
91+ x = rng (shapes [0 ], dtypes [0 ])
92+ c = rng_uniform (shapes [1 ], dtypes [1 ])
93+ return [x , c ]
94+
95+ tol = {
96+ np .float32 : 1e-4 if jtu .test_device_matches (["tpu" ]) else 1e-5 ,
97+ np .float64 : 1e-11 ,
98+ }
99+ with jtu .strict_promotion_if_dtypes_match (dtypes ):
100+ self ._CheckAgainstNumpy (scipy_fun , lax_fun , args_maker ,
101+ check_dtypes = False , tol = tol )
102+ self ._CompileAndCheck (lax_fun , args_maker , tol = tol )
103+
104+ @genNamedParametersNArgs (2 )
105+ def testWrappedCauchyLogPdf (self , shapes , dtypes ):
106+ rng = jtu .rand_default (self .rng ())
107+ rng_uniform = jtu .rand_uniform (self .rng (), low = 1e-3 , high = 1 - 1e-3 )
108+ scipy_fun = osp_stats .wrapcauchy .logpdf
109+ lax_fun = lsp_stats .wrapcauchy .logpdf
110+
111+ def args_maker ():
112+ x = rng (shapes [0 ], dtypes [0 ])
113+ c = rng_uniform (shapes [1 ], dtypes [1 ])
114+ return [x , c ]
115+
116+ tol = {
117+ np .float32 : 1e-4 if jtu .test_device_matches (["tpu" ]) else 1e-5 ,
118+ np .float64 : 1e-11 ,
119+ }
120+ with jtu .strict_promotion_if_dtypes_match (dtypes ):
121+ self ._CheckAgainstNumpy (scipy_fun , lax_fun , args_maker ,
122+ check_dtypes = False , tol = tol )
123+ self ._CompileAndCheck (lax_fun , args_maker , tol = tol )
124+
83125 @genNamedParametersNArgs (3 )
84126 def testPoissonLogPmf (self , shapes , dtypes ):
85127 rng = jtu .rand_default (self .rng ())
0 commit comments