Skip to content

Commit 3778265

Browse files
author
jax authors
committed
Merge pull request #18126 from niqodea:wrapcauchy
PiperOrigin-RevId: 574572631
2 parents 88fe0da + 890b762 commit 3778265

File tree

5 files changed

+112
-0
lines changed

5 files changed

+112
-0
lines changed

docs/jax.scipy.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,12 @@ jax.scipy.stats.vonmises
446446

447447
logpdf
448448
pdf
449+
450+
jax.scipy.stats.wrapcauchy
451+
~~~~~~~~~~~~~~~~~~~~~~~~~~
452+
.. automodule:: jax.scipy.stats.wrapcauchy
453+
.. autosummary::
454+
:toctree: _autosummary
455+
456+
logpdf
457+
pdf

jax/_src/scipy/stats/wrapcauchy.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import scipy.stats as osp_stats
17+
from jax import lax
18+
import jax.numpy as jnp
19+
from jax._src.lax.lax import _const as _lax_const
20+
from jax._src.numpy.util import _wraps, promote_args_inexact
21+
from jax._src.typing import Array, ArrayLike
22+
23+
24+
@_wraps(osp_stats.wrapcauchy.logpdf, update_doc=False)
25+
def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
26+
x, c = promote_args_inexact('wrapcauchy.logpdf', x, c)
27+
return jnp.where(
28+
lax.gt(c, _lax_const(c, 0)) & lax.lt(c, _lax_const(c, 1)),
29+
jnp.where(
30+
lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, jnp.pi * 2)),
31+
jnp.log(1 - c * c) - jnp.log(2 * jnp.pi) - jnp.log(1 + c * c - 2 * c * jnp.cos(x)),
32+
-jnp.inf,
33+
),
34+
jnp.nan,
35+
)
36+
37+
@_wraps(osp_stats.wrapcauchy.pdf, update_doc=False)
38+
def pdf(x: ArrayLike, c: ArrayLike) -> Array:
39+
return lax.exp(logpdf(x, c))

jax/scipy/stats/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@
4040
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
4141
from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata
4242
from jax.scipy.stats import vonmises as vonmises
43+
from jax.scipy.stats import wrapcauchy as wrapcauchy

jax/scipy/stats/wrapcauchy.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Note: import <name> as <name> is required for names to be exported.
16+
# See PEP 484 & https://github.com/google/jax/issues/7570
17+
18+
from jax._src.scipy.stats.wrapcauchy import (
19+
logpdf as logpdf,
20+
pdf as pdf,
21+
)

tests/scipy_stats_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,48 @@ def args_maker():
8080
tol=1e-3)
8181
self._CompileAndCheck(lax_fun, args_maker)
8282

83+
@genNamedParametersNArgs(2)
84+
def testWrappedCauchyPdf(self, shapes, dtypes):
85+
rng = jtu.rand_default(self.rng())
86+
rng_uniform = jtu.rand_uniform(self.rng(), low=1e-3, high=1 - 1e-3)
87+
scipy_fun = osp_stats.wrapcauchy.pdf
88+
lax_fun = lsp_stats.wrapcauchy.pdf
89+
90+
def args_maker():
91+
x = rng(shapes[0], dtypes[0])
92+
c = rng_uniform(shapes[1], dtypes[1])
93+
return [x, c]
94+
95+
tol = {
96+
np.float32: 1e-4 if jtu.test_device_matches(["tpu"]) else 1e-5,
97+
np.float64: 1e-11,
98+
}
99+
with jtu.strict_promotion_if_dtypes_match(dtypes):
100+
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
101+
check_dtypes=False, tol=tol)
102+
self._CompileAndCheck(lax_fun, args_maker, tol=tol)
103+
104+
@genNamedParametersNArgs(2)
105+
def testWrappedCauchyLogPdf(self, shapes, dtypes):
106+
rng = jtu.rand_default(self.rng())
107+
rng_uniform = jtu.rand_uniform(self.rng(), low=1e-3, high=1 - 1e-3)
108+
scipy_fun = osp_stats.wrapcauchy.logpdf
109+
lax_fun = lsp_stats.wrapcauchy.logpdf
110+
111+
def args_maker():
112+
x = rng(shapes[0], dtypes[0])
113+
c = rng_uniform(shapes[1], dtypes[1])
114+
return [x, c]
115+
116+
tol = {
117+
np.float32: 1e-4 if jtu.test_device_matches(["tpu"]) else 1e-5,
118+
np.float64: 1e-11,
119+
}
120+
with jtu.strict_promotion_if_dtypes_match(dtypes):
121+
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
122+
check_dtypes=False, tol=tol)
123+
self._CompileAndCheck(lax_fun, args_maker, tol=tol)
124+
83125
@genNamedParametersNArgs(3)
84126
def testPoissonLogPmf(self, shapes, dtypes):
85127
rng = jtu.rand_default(self.rng())

0 commit comments

Comments
 (0)