diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index c205dc0..780a1f8 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -873,6 +873,15 @@ def get_ast_ctx(self): """Return the ast context.""" return self.ast_ctx + def __get__(self, obj, objtype=None): + """Support descriptor protocol so class attributes bind to instances.""" + if obj is None: + return self + # we use weak references when we bind the method calls to the instance inst; + # otherwise these self references cause the object to not be deleted until + # it is later garbage collected + return EvalFuncVarClassInst(self.func, self.ast_ctx, weakref.ref(obj)) + def __del__(self): """On deletion, stop any triggers for this function.""" if self.func: @@ -1983,22 +1992,6 @@ async def call_func(self, func, func_name, *args, **kwargs): if inspect.isclass(func) and hasattr(func, "__init__evalfunc_wrap__"): has_init_wrapper = getattr(func, "__init__evalfunc_wrap__") is not None inst = func(*args, **kwargs) if not has_init_wrapper else func() - # - # we use weak references when we bind the method calls to the instance inst; - # otherwise these self references cause the object to not be deleted until - # it is later garbage collected - # - inst_weak = weakref.ref(inst) - for name in dir(inst): - try: - value = getattr(inst, name) - except AttributeError: - # same effect as hasattr (which also catches AttributeError) - # dir() may list names that aren't actually accessible attributes - continue - if type(value) is not EvalFuncVar: - continue - setattr(inst, name, EvalFuncVarClassInst(value.get_func(), value.get_ast_ctx(), inst_weak)) if has_init_wrapper: # # since our __init__ function is async, call the renamed one diff --git a/tests/test_unit_eval.py b/tests/test_unit_eval.py index 2bef2df..527e479 100644 --- a/tests/test_unit_eval.py +++ b/tests/test_unit_eval.py @@ -181,6 +181,58 @@ class Color(Enum): ], [ """ +from enum import Enum + +class HomeState(Enum): + HOME = "home" + AWAY = "away" + + def name_and_value(self): + return f"{self.name}:{self.value}" + +[HomeState.HOME.name_and_value(), HomeState.AWAY.name_and_value()] +""", + ["HOME:home", "AWAY:away"], + ], + [ + """ +class Device: + def greet(self, name): + return f"hi {name} from {self.__class__.__name__}" + +d = Device() +d.greet("Alice") +""", + "hi Alice from Device", + ], + [ + """ +def add(self, x, y=0): + return x + y + +class Calc: + pass + +Calc.add = add +Calc().add(2, 3) +""", + 5, + ], + [ + """ +class Base: + def tag(self): + return "base" + +class Child(Base): + pass + +Child().tag() +""", + "base", + ], + [ + """ from dataclasses import dataclass @dataclass()