@@ -18,6 +18,7 @@
 from scipy.stats import norm, poisson

 from pyhf.typing import Literal, Shape
+from typing import cast

 T = TypeVar("T", bound=NBitBase)

@@ -26,27 +27,32 @@
 log = logging.getLogger(__name__)


-class _BasicPoisson:
+class _BasicPoisson(Generic[T]):
     def __init__(self, rate: Tensor[T]):
         self.rate = rate

     def sample(self, sample_shape: Shape) -> ArrayLike:
-        return poisson(self.rate).rvs(size=sample_shape + self.rate.shape)  # type: ignore[no-any-return]
+        return cast(
+            ArrayLike, poisson(self.rate).rvs(size=sample_shape + self.rate.shape)
+        )

-    def log_prob(self, value: NDArray[np.number[T]]) -> ArrayLike:
+    def log_prob(self, value: Tensor[T]) -> ArrayLike:
         tensorlib: numpy_backend[T] = numpy_backend()
         return tensorlib.poisson_logpdf(value, self.rate)


-class _BasicNormal:
+class _BasicNormal(Generic[T]):
     def __init__(self, loc: Tensor[T], scale: Tensor[T]):
         self.loc = loc
         self.scale = scale

     def sample(self, sample_shape: Shape) -> ArrayLike:
-        return norm(self.loc, self.scale).rvs(size=sample_shape + self.loc.shape)  # type: ignore[no-any-return]
+        return cast(
+            ArrayLike,
+            norm(self.loc, self.scale).rvs(size=sample_shape + self.loc.shape),
+        )

-    def log_prob(self, value: NDArray[np.number[T]]) -> ArrayLike:
+    def log_prob(self, value: Tensor[T]) -> ArrayLike:
         tensorlib: numpy_backend[T] = numpy_backend()
         return tensorlib.normal_logpdf(value, self.loc, self.scale)

@@ -125,7 +131,7 @@ def erf(self, tensor_in: Tensor[T]) -> ArrayLike:
         Returns:
             NumPy ndarray: The values of the error function at the given points.
         """
-        return special.erf(tensor_in)  # type: ignore[no-any-return]
+        return cast(ArrayLike, special.erf(tensor_in))

     def erfinv(self, tensor_in: Tensor[T]) -> ArrayLike:
         """
@@ -145,7 +151,7 @@ def erfinv(self, tensor_in: Tensor[T]) -> ArrayLike:
         Returns:
             NumPy ndarray: The values of the inverse of the error function at the given points.
         """
-        return special.erfinv(tensor_in)  # type: ignore[no-any-return]
+        return cast(ArrayLike, special.erfinv(tensor_in))

     def tile(self, tensor_in: Tensor[T], repeats: int | Sequence[int]) -> ArrayLike:
         """
@@ -207,7 +213,7 @@ def tolist(self, tensor_in: Tensor[T] | list[T]) -> list[T]:
             raise

     def outer(self, tensor_in_1: Tensor[T], tensor_in_2: Tensor[T]) -> ArrayLike:
-        return np.outer(tensor_in_1, tensor_in_2)  # type: ignore[arg-type]
+        return cast(ArrayLike, np.outer(tensor_in_1, tensor_in_2))

     def gather(self, tensor: Tensor[T], indices: NDArray[np.integer[T]]) -> ArrayLike:
         return tensor[indices]
@@ -255,7 +261,7 @@ def sum(self, tensor_in: Tensor[T], axis: int | None = None) -> ArrayLike:
         return np.sum(tensor_in, axis=axis)

     def product(self, tensor_in: Tensor[T], axis: Shape | None = None) -> ArrayLike:
-        return np.prod(tensor_in, axis=axis)  # type: ignore[arg-type]
+        return cast(ArrayLike, np.prod(tensor_in, axis=axis))

     def abs(self, tensor: Tensor[T]) -> ArrayLike:
         return np.abs(tensor)
@@ -345,7 +351,7 @@ def percentile(
         .. versionadded:: 0.7.0
         """
         # see https://github.com/numpy/numpy/issues/22125
-        return np.percentile(tensor_in, q, axis=axis, interpolation=interpolation)  # type: ignore[call-overload,no-any-return]
+        return cast(ArrayLike, np.percentile(tensor_in, q, axis=axis, interpolation=interpolation))  # type: ignore[call-overload]

     def stack(self, sequence: Sequence[Tensor[T]], axis: int = 0) -> ArrayLike:
         return np.stack(sequence, axis=axis)
@@ -392,7 +398,7 @@ def simple_broadcast(self, *args: Sequence[Tensor[T]]) -> Sequence[Tensor[T]]:
         return np.broadcast_arrays(*args)

     def shape(self, tensor: Tensor[T]) -> Shape:
-        return tensor.shape
+        return cast(Shape, tensor.shape)

     def reshape(self, tensor: Tensor[T], newshape: Shape) -> ArrayLike:
         return np.reshape(tensor, newshape)
@@ -434,10 +440,10 @@ def einsum(self, subscripts: str, *operands: Sequence[Tensor[T]]) -> ArrayLike:
         Returns:
             tensor: the calculation based on the Einstein summation convention
         """
-        return np.einsum(subscripts, *operands)  # type: ignore[arg-type,no-any-return]
+        return cast(ArrayLike, np.einsum(subscripts, *operands))

     def poisson_logpdf(self, n: Tensor[T], lam: Tensor[T]) -> ArrayLike:
-        return xlogy(n, lam) - lam - gammaln(n + 1.0)  # type: ignore[no-any-return]
+        return cast(ArrayLike, xlogy(n, lam) - lam - gammaln(n + 1.0))

     def poisson(self, n: Tensor[T], lam: Tensor[T]) -> ArrayLike:
         r"""
@@ -481,7 +487,7 @@ def poisson(self, n: Tensor[T], lam: Tensor[T]) -> ArrayLike:
         """
         _n = np.asarray(n)
         _lam = np.asarray(lam)
-        return np.exp(xlogy(_n, _lam) - _lam - gammaln(_n + 1.0))  # type: ignore[no-any-return,operator]
+        return cast(ArrayLike, np.exp(xlogy(_n, _lam) - _lam - gammaln(_n + 1)))

     def normal_logpdf(self, x: Tensor[T], mu: Tensor[T], sigma: Tensor[T]) -> ArrayLike:
         # this is much faster than
@@ -491,7 +497,7 @@ def normal_logpdf(self, x: Tensor[T], mu: Tensor[T], sigma: Tensor[T]) -> ArrayLike:
         root2pi = np.sqrt(2 * np.pi)
         prefactor = -np.log(sigma * root2pi)
         summand = -np.square(np.divide((x - mu), (root2 * sigma)))
-        return prefactor + summand  # type: ignore[no-any-return]
+        return cast(ArrayLike, prefactor + summand)

     # def normal_logpdf(self, x, mu, sigma):
     #     return norm.logpdf(x, loc=mu, scale=sigma)
@@ -522,7 +528,7 @@ def normal(self, x: Tensor[T], mu: Tensor[T], sigma: Tensor[T]) -> ArrayLike:
         Returns:
             NumPy float: Value of Normal(x|mu, sigma)
         """
-        return norm.pdf(x, loc=mu, scale=sigma)  # type: ignore[no-any-return]
+        return cast(ArrayLike, norm.pdf(x, loc=mu, scale=sigma))

     def normal_cdf(
         self, x: Tensor[T], mu: float | Tensor[T] = 0, sigma: float | Tensor[T] = 1
@@ -548,9 +554,9 @@ def normal_cdf(
         Returns:
             NumPy float: The CDF
         """
-        return norm.cdf(x, loc=mu, scale=sigma)  # type: ignore[no-any-return]
+        return cast(ArrayLike, norm.cdf(x, loc=mu, scale=sigma))

-    def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson:
+    def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson[T]:
         r"""
         The Poisson distribution with rate parameter :code:`rate`.

@@ -571,7 +577,7 @@ def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson:
         """
         return _BasicPoisson(rate)

-    def normal_dist(self, mu: Tensor[T], sigma: Tensor[T]) -> _BasicNormal:
+    def normal_dist(self, mu: Tensor[T], sigma: Tensor[T]) -> _BasicNormal[T]:
         r"""
         The Normal distribution with mean :code:`mu` and standard deviation :code:`sigma`.

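For orientation, a minimal usage sketch of the now-generic distribution helpers touched by this diff (not part of the commit; the `pyhf.tensor.numpy_backend` import path and the example values are assumptions, the method names come from the diff above):

import numpy as np
from pyhf.tensor.numpy_backend import numpy_backend  # assumed module path

backend = numpy_backend()
rates = np.asarray([2.0, 5.0])

# poisson_dist now returns _BasicPoisson[T]; the distribution is parametrized
# by the element type of the rate tensor instead of being untyped.
pois = backend.poisson_dist(rates)
print(pois.log_prob(np.asarray([1.0, 4.0])))  # element-wise Poisson log-pdf
print(np.shape(pois.sample((3,))))            # (3, 2): sample_shape + rates.shape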