Skip to content

Commit 5341ae3

Browse files
authored
feat: adding overflow disable option to cat axes (#883)
Signed-off-by: Henry Schreiner <[email protected]>
1 parent 0a8e283 commit 5341ae3

File tree

6 files changed

+52
-15
lines changed

6 files changed

+52
-15
lines changed

include/bh_python/axis.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,20 @@ BHP_SPECIALIZE_NAME(integer_oflow)
9595
BHP_SPECIALIZE_NAME(integer_growth)
9696
BHP_SPECIALIZE_NAME(integer_circular)
9797

98+
using category_int_none = bh::axis::category<int, metadata_t, option::none_t>;
9899
using category_int = bh::axis::category<int, metadata_t>;
99100
using category_int_growth = bh::axis::category<int, metadata_t, option::growth_t>;
100101

102+
BHP_SPECIALIZE_NAME(category_int_none)
101103
BHP_SPECIALIZE_NAME(category_int)
102104
BHP_SPECIALIZE_NAME(category_int_growth)
103105

106+
using category_str_none = bh::axis::category<std::string, metadata_t, option::none_t>;
104107
using category_str = bh::axis::category<std::string, metadata_t, option::overflow_t>;
105108
using category_str_growth
106109
= bh::axis::category<std::string, metadata_t, option::growth_t>;
107110

111+
BHP_SPECIALIZE_NAME(category_str_none)
108112
BHP_SPECIALIZE_NAME(category_str)
109113
BHP_SPECIALIZE_NAME(category_str_growth)
110114

@@ -306,7 +310,9 @@ using axis_variant = bh::axis::variant<axis::regular_uoflow,
306310
axis::category_int_growth,
307311
axis::category_str,
308312
axis::category_str_growth,
309-
axis::boolean>;
313+
axis::boolean,
314+
axis::category_int_none,
315+
axis::category_str_none>;
310316

311317
// This saves a little typing
312318
using vector_axis_variant = std::vector<axis_variant>;

src/boost_histogram/_core/axis/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class _BaseCatInt(_BaseAxis):
9999
def __iter__(self) -> Iterator[int]: ...
100100
def bin(self, arg0: int) -> int: ...
101101

102+
class category_int_none(_BaseCatInt): ...
102103
class category_int(_BaseCatInt): ...
103104
class category_int_growth(_BaseCatInt): ...
104105

@@ -107,6 +108,7 @@ class _BaseCatStr(_BaseAxis):
107108
def __iter__(self) -> Iterator[str]: ...
108109
def bin(self, arg0: int) -> str: ...
109110

111+
class category_str_none(_BaseCatStr): ...
110112
class category_str(_BaseCatStr): ...
111113
class category_str_growth(_BaseCatStr): ...
112114

src/boost_histogram/_internal/axis.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -583,15 +583,15 @@ def _repr_args_(self) -> list[str]:
583583

584584
if self.traits.growth:
585585
ret.append("growth=True")
586-
elif self.traits.circular:
587-
ret.append("circular=True")
586+
elif not self.traits.overflow:
587+
ret.append("overflow=False")
588588

589589
ret += super()._repr_args_()
590590
return ret
591591

592592

593593
@set_module("boost_histogram.axis")
594-
@register({ca.category_str_growth, ca.category_str})
594+
@register({ca.category_str_growth, ca.category_str, ca.category_str_none})
595595
class StrCategory(BaseCategory, family=boost_histogram):
596596
__slots__ = ()
597597

@@ -601,6 +601,7 @@ def __init__(
601601
*,
602602
metadata: Any = None,
603603
growth: bool = False,
604+
overflow: bool = True,
604605
__dict__: dict[str, Any] | None = None,
605606
):
606607
"""
@@ -618,21 +619,25 @@ def __init__(
618619
growth : bool = False
619620
Allow the axis to grow if a value is encountered out of range.
620621
Be careful, the axis will grow as large as needed.
622+
overflow : bool = True
623+
Include an overflow bin for "missed" hits. Ignored if growth=True.
621624
__dict__: Optional[Dict[str, Any]] = None
622625
The full metadata dictionary
623626
"""
624627

625-
options = _opts(growth=growth)
628+
options = _opts(growth=growth, overflow=overflow)
626629

627630
ax: ca._BaseCatStr
628631

629632
# henryiii: We currently expand "abc" to "a", "b", "c" - some
630633
# Python interfaces protect against that
631634

632-
if options == {"growth"}:
635+
if "growth" in options:
633636
ax = ca.category_str_growth(tuple(categories))
634-
elif options == set():
637+
elif options == {"overflow"}:
635638
ax = ca.category_str(tuple(categories))
639+
elif not options:
640+
ax = ca.category_str_none(tuple(categories))
636641
else:
637642
raise KeyError("Unsupported collection of options")
638643

@@ -659,7 +664,7 @@ def _repr_args_(self) -> list[str]:
659664

660665

661666
@set_module("boost_histogram.axis")
662-
@register({ca.category_int, ca.category_int_growth})
667+
@register({ca.category_int, ca.category_int_growth, ca.category_int_none})
663668
class IntCategory(BaseCategory, family=boost_histogram):
664669
__slots__ = ()
665670

@@ -669,6 +674,7 @@ def __init__(
669674
*,
670675
metadata: Any = None,
671676
growth: bool = False,
677+
overflow: bool = True,
672678
__dict__: dict[str, Any] | None = None,
673679
):
674680
"""
@@ -686,17 +692,21 @@ def __init__(
686692
growth : bool = False
687693
Allow the axis to grow if a value is encountered out of range.
688694
Be careful, the axis will grow as large as needed.
695+
overflow : bool = True
696+
Include an overflow bin for "missed" hits. Ignored if growth=True.
689697
__dict__: Optional[Dict[str, Any]] = None
690698
The full metadata dictionary
691699
"""
692700

693-
options = _opts(growth=growth)
701+
options = _opts(growth=growth, overflow=overflow)
694702
ax: ca._BaseCatInt
695703

696-
if options == {"growth"}:
704+
if "growth" in options:
697705
ax = ca.category_int_growth(tuple(categories))
698-
elif options == set():
706+
elif options == {"overflow"}:
699707
ax = ca.category_int(tuple(categories))
708+
elif not options:
709+
ax = ca.category_int_none(tuple(categories))
700710
else:
701711
raise KeyError("Unsupported collection of options")
702712

src/register_axis.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,14 @@ void register_axes(py::module& mod) {
7676
axis::integer_circular>(
7777
mod, [](auto ax) { ax.def(py::init<int, int>(), "start"_a, "stop"_a); });
7878

79-
register_axis_each<axis::category_int, axis::category_int_growth>(
79+
register_axis_each<axis::category_int,
80+
axis::category_int_growth,
81+
axis::category_int_none>(
8082
mod, [](auto ax) { ax.def(py::init<std::vector<int>>(), "categories"_a); });
8183

82-
register_axis_each<axis::category_str, axis::category_str_growth>(mod, [](auto ax) {
84+
register_axis_each<axis::category_str,
85+
axis::category_str_growth,
86+
axis::category_str_none>(mod, [](auto ax) {
8387
ax.def(py::init<std::vector<std::string>>(), "categories"_a);
8488
});
8589

tests/test_axis.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,14 @@
3636
(bh.axis.Integer, (1, 2), "g", {}),
3737
(bh.axis.Integer, (1, 2), "", {"circular": True}),
3838
(bh.axis.IntCategory, ((1, 2, 3),), "", {}),
39+
(bh.axis.IntCategory, ((1, 2, 3),), "o", {}),
3940
(bh.axis.IntCategory, ((1, 2, 3),), "g", {}),
41+
(bh.axis.IntCategory, ((1, 2, 3),), "go", {}),
4042
(bh.axis.IntCategory, ((),), "g", {}),
4143
(bh.axis.StrCategory, (tuple("ABC"),), "", {}),
44+
(bh.axis.StrCategory, (tuple("ABC"),), "o", {}),
4245
(bh.axis.StrCategory, (tuple("ABC"),), "g", {}),
46+
(bh.axis.StrCategory, (tuple("ABC"),), "go", {}),
4347
(bh.axis.StrCategory, ((),), "g", {}),
4448
],
4549
)

tests/test_histogram.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,17 @@ def test_growing_cats():
208208
assert h.size == 4
209209

210210

211+
def test_noflow_cats():
212+
h = bh.Histogram(
213+
bh.axis.IntCategory([1, 2, 3], overflow=False),
214+
bh.axis.StrCategory(["hi"], overflow=False),
215+
)
216+
217+
h.fill([1, 2, 3, 4], ["hi", "ho", "hi", "ho"])
218+
219+
assert h.sum() == 2
220+
221+
211222
def test_metadata_add():
212223
h1 = bh.Histogram(
213224
bh.axis.IntCategory([1, 2, 3]), bh.axis.StrCategory(["1", "2", "3"])
@@ -655,9 +666,9 @@ def test_rebin_nd():
655666

656667

657668
# CLASSIC: This used to have metadata too, but that does not compare equal
658-
def test_pickle_0():
669+
def test_pickle_0(flow):
659670
a = bh.Histogram(
660-
bh.axis.IntCategory([0, 1, 2]),
671+
bh.axis.IntCategory([0, 1, 2], overflow=flow),
661672
bh.axis.Integer(0, 20),
662673
bh.axis.Regular(2, 0.0, 20.0, underflow=False, overflow=False),
663674
bh.axis.Variable([0.0, 1.0, 2.0]),

0 commit comments

Comments
 (0)