Skip to content

Commit 74ac780

Browse files
committed
fix mypy errors
1 parent 786a4cc commit 74ac780

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

src/pyhf/tensor/numpy_backend.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from scipy.stats import norm, poisson
1919

2020
from pyhf.typing import Literal, Shape
21+
from typing import cast
2122

2223
T = TypeVar("T", bound=NBitBase)
2324

@@ -26,27 +27,32 @@
2627
log = logging.getLogger(__name__)
2728

2829

29-
class _BasicPoisson:
30+
class _BasicPoisson(Generic[T]):
3031
def __init__(self, rate: Tensor[T]):
3132
self.rate = rate
3233

3334
def sample(self, sample_shape: Shape) -> ArrayLike:
34-
return poisson(self.rate).rvs(size=sample_shape + self.rate.shape) # type: ignore[no-any-return]
35+
return cast(
36+
ArrayLike, poisson(self.rate).rvs(size=sample_shape + self.rate.shape)
37+
)
3538

36-
def log_prob(self, value: NDArray[np.number[T]]) -> ArrayLike:
39+
def log_prob(self, value: Tensor[T]) -> ArrayLike:
3740
tensorlib: numpy_backend[T] = numpy_backend()
3841
return tensorlib.poisson_logpdf(value, self.rate)
3942

4043

41-
class _BasicNormal:
44+
class _BasicNormal(Generic[T]):
4245
def __init__(self, loc: Tensor[T], scale: Tensor[T]):
4346
self.loc = loc
4447
self.scale = scale
4548

4649
def sample(self, sample_shape: Shape) -> ArrayLike:
47-
return norm(self.loc, self.scale).rvs(size=sample_shape + self.loc.shape) # type: ignore[no-any-return]
50+
return cast(
51+
ArrayLike,
52+
norm(self.loc, self.scale).rvs(size=sample_shape + self.loc.shape),
53+
)
4854

49-
def log_prob(self, value: NDArray[np.number[T]]) -> ArrayLike:
55+
def log_prob(self, value: Tensor[T]) -> ArrayLike:
5056
tensorlib: numpy_backend[T] = numpy_backend()
5157
return tensorlib.normal_logpdf(value, self.loc, self.scale)
5258

@@ -125,7 +131,7 @@ def erf(self, tensor_in: Tensor[T]) -> ArrayLike:
125131
Returns:
126132
NumPy ndarray: The values of the error function at the given points.
127133
"""
128-
return special.erf(tensor_in) # type: ignore[no-any-return]
134+
return cast(ArrayLike, special.erf(tensor_in))
129135

130136
def erfinv(self, tensor_in: Tensor[T]) -> ArrayLike:
131137
"""
@@ -145,7 +151,7 @@ def erfinv(self, tensor_in: Tensor[T]) -> ArrayLike:
145151
Returns:
146152
NumPy ndarray: The values of the inverse of the error function at the given points.
147153
"""
148-
return special.erfinv(tensor_in) # type: ignore[no-any-return]
154+
return cast(ArrayLike, special.erfinv(tensor_in))
149155

150156
def tile(self, tensor_in: Tensor[T], repeats: int | Sequence[int]) -> ArrayLike:
151157
"""
@@ -207,7 +213,7 @@ def tolist(self, tensor_in: Tensor[T] | list[T]) -> list[T]:
207213
raise
208214

209215
def outer(self, tensor_in_1: Tensor[T], tensor_in_2: Tensor[T]) -> ArrayLike:
210-
return np.outer(tensor_in_1, tensor_in_2) # type: ignore[arg-type]
216+
return cast(ArrayLike, np.outer(tensor_in_1, tensor_in_2))
211217

212218
def gather(self, tensor: Tensor[T], indices: NDArray[np.integer[T]]) -> ArrayLike:
213219
return tensor[indices]
@@ -255,7 +261,7 @@ def sum(self, tensor_in: Tensor[T], axis: int | None = None) -> ArrayLike:
255261
return np.sum(tensor_in, axis=axis)
256262

257263
def product(self, tensor_in: Tensor[T], axis: Shape | None = None) -> ArrayLike:
258-
return np.prod(tensor_in, axis=axis) # type: ignore[arg-type]
264+
return cast(ArrayLike, np.prod(tensor_in, axis=axis))
259265

260266
def abs(self, tensor: Tensor[T]) -> ArrayLike:
261267
return np.abs(tensor)
@@ -345,7 +351,7 @@ def percentile(
345351
.. versionadded:: 0.7.0
346352
"""
347353
# see https://github.com/numpy/numpy/issues/22125
348-
return np.percentile(tensor_in, q, axis=axis, interpolation=interpolation) # type: ignore[call-overload,no-any-return]
354+
return cast(ArrayLike, np.percentile(tensor_in, q, axis=axis, interpolation=interpolation)) # type: ignore[call-overload]
349355

350356
def stack(self, sequence: Sequence[Tensor[T]], axis: int = 0) -> ArrayLike:
351357
return np.stack(sequence, axis=axis)
@@ -392,7 +398,7 @@ def simple_broadcast(self, *args: Sequence[Tensor[T]]) -> Sequence[Tensor[T]]:
392398
return np.broadcast_arrays(*args)
393399

394400
def shape(self, tensor: Tensor[T]) -> Shape:
395-
return tensor.shape
401+
return cast(Shape, tensor.shape)
396402

397403
def reshape(self, tensor: Tensor[T], newshape: Shape) -> ArrayLike:
398404
return np.reshape(tensor, newshape)
@@ -434,10 +440,10 @@ def einsum(self, subscripts: str, *operands: Sequence[Tensor[T]]) -> ArrayLike:
434440
Returns:
435441
tensor: the calculation based on the Einstein summation convention
436442
"""
437-
return np.einsum(subscripts, *operands) # type: ignore[arg-type,no-any-return]
443+
return cast(ArrayLike, np.einsum(subscripts, *operands))
438444

439445
def poisson_logpdf(self, n: Tensor[T], lam: Tensor[T]) -> ArrayLike:
440-
return xlogy(n, lam) - lam - gammaln(n + 1.0) # type: ignore[no-any-return]
446+
return cast(ArrayLike, xlogy(n, lam) - lam - gammaln(n + 1.0))
441447

442448
def poisson(self, n: Tensor[T], lam: Tensor[T]) -> ArrayLike:
443449
r"""
@@ -481,7 +487,7 @@ def poisson(self, n: Tensor[T], lam: Tensor[T]) -> ArrayLike:
481487
"""
482488
_n = np.asarray(n)
483489
_lam = np.asarray(lam)
484-
return np.exp(xlogy(_n, _lam) - _lam - gammaln(_n + 1.0)) # type: ignore[no-any-return,operator]
490+
return cast(ArrayLike, np.exp(xlogy(_n, _lam) - _lam - gammaln(_n + 1)))
485491

486492
def normal_logpdf(self, x: Tensor[T], mu: Tensor[T], sigma: Tensor[T]) -> ArrayLike:
487493
# this is much faster than
@@ -491,7 +497,7 @@ def normal_logpdf(self, x: Tensor[T], mu: Tensor[T], sigma: Tensor[T]) -> ArrayL
491497
root2pi = np.sqrt(2 * np.pi)
492498
prefactor = -np.log(sigma * root2pi)
493499
summand = -np.square(np.divide((x - mu), (root2 * sigma)))
494-
return prefactor + summand # type: ignore[no-any-return]
500+
return cast(ArrayLike, prefactor + summand)
495501

496502
# def normal_logpdf(self, x, mu, sigma):
497503
# return norm.logpdf(x, loc=mu, scale=sigma)
@@ -522,7 +528,7 @@ def normal(self, x: Tensor[T], mu: Tensor[T], sigma: Tensor[T]) -> ArrayLike:
522528
Returns:
523529
NumPy float: Value of Normal(x|mu, sigma)
524530
"""
525-
return norm.pdf(x, loc=mu, scale=sigma) # type: ignore[no-any-return]
531+
return cast(ArrayLike, norm.pdf(x, loc=mu, scale=sigma))
526532

527533
def normal_cdf(
528534
self, x: Tensor[T], mu: float | Tensor[T] = 0, sigma: float | Tensor[T] = 1
@@ -548,9 +554,9 @@ def normal_cdf(
548554
Returns:
549555
NumPy float: The CDF
550556
"""
551-
return norm.cdf(x, loc=mu, scale=sigma) # type: ignore[no-any-return]
557+
return cast(ArrayLike, norm.cdf(x, loc=mu, scale=sigma))
552558

553-
def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson:
559+
def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson[T]:
554560
r"""
555561
The Poisson distribution with rate parameter :code:`rate`.
556562
@@ -571,7 +577,7 @@ def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson:
571577
"""
572578
return _BasicPoisson(rate)
573579

574-
def normal_dist(self, mu: Tensor[T], sigma: Tensor[T]) -> _BasicNormal:
580+
def normal_dist(self, mu: Tensor[T], sigma: Tensor[T]) -> _BasicNormal[T]:
575581
r"""
576582
The Normal distribution with mean :code:`mu` and standard deviation :code:`sigma`.
577583

0 commit comments

Comments
 (0)