Skip to content

Commit fa4807b

Browse files
authored
[eudsl-python-extras] remove bytecode dep (#240)
This PR removes our dependency on `bytecode` (which turned out to be completely superfluous 🤦).
1 parent 8dd0292 commit fa4807b

File tree

5 files changed

+27
-63
lines changed

5 files changed

+27
-63
lines changed

projects/eudsl-python-extras/mlir/extras/ast/canonicalize.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from opcode import opmap
1313
from typing import List, Union, Sequence, get_type_hints
1414

15-
from bytecode import ConcreteBytecode
1615

1716
from ..ast.util import get_module_cst, set_lineno, find_func_in_code_object
1817

@@ -166,25 +165,23 @@ def to_str(self: OpCode):
166165
setattr(OpCode, "__str__", to_str)
167166

168167

169-
class BytecodePatcher(ABC):
168+
class FunctionPatcher(ABC):
170169
def __init__(self, context=None):
171170
self.context = context
172171

173172
@abstractmethod
174-
def patch_bytecode(self, code: ConcreteBytecode, original_f) -> ConcreteBytecode:
173+
def patch_function(self, original_f):
175174
pass
176175

177176

178-
def patch_bytecode(f, patchers: List[type(BytecodePatcher)] = None):
177+
def patch_function(f, patchers: List[type(FunctionPatcher)] = None):
179178
if patchers is None:
180179
return f
181-
code = ConcreteBytecode.from_code(f.__code__)
182180
context = types.SimpleNamespace()
183181
for patcher in patchers:
184-
code = patcher(context).patch_bytecode(code, f)
182+
new_f = patcher(context).patch_function(f)
185183

186-
f.__code__ = code.to_code()
187-
return f
184+
return new_f
188185

189186

190187
class Canonicalizer(ABC):
@@ -195,22 +192,22 @@ def cst_transformers(self) -> List[StrictTransformer]:
195192

196193
@property
197194
@abstractmethod
198-
def bytecode_patchers(self) -> List[BytecodePatcher]:
195+
def function_patchers(self) -> List[FunctionPatcher]:
199196
pass
200197

201198

202199
def canonicalize(*, using: Union[Canonicalizer, Sequence[Canonicalizer]]):
203200
if not isinstance(using, Sequence):
204201
using = [using]
205202
cst_transformers = []
206-
bytecode_patchers = []
203+
function_patchers = []
207204
for u in using:
208205
cst_transformers.extend(u.cst_transformers)
209-
bytecode_patchers.extend(u.bytecode_patchers)
206+
function_patchers.extend(u.function_patchers)
210207

211208
def wrapper(f):
212209
f = transform_ast(f, cst_transformers)
213-
f = patch_bytecode(f, bytecode_patchers)
210+
f = patch_function(f, function_patchers)
214211
return f
215212

216213
return wrapper

projects/eudsl-python-extras/mlir/extras/ast/util.py

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
import inspect
77
import io
88
import types
9-
from opcode import opmap
109
from textwrap import dedent
1110
from typing import Dict
1211

13-
from bytecode import ConcreteBytecode
1412
from cloudpickle import cloudpickle
1513

1614
from ...ir import Type
@@ -87,39 +85,6 @@ def make_cell(value=_empty_cell_value):
8785
return cell
8886

8987

90-
# based on https://github.com/python/cpython/blob/a4b44d39cd6941cc03590fee7538776728bdfd0a/Lib/test/test_code.py#L197
91-
def replace_closure(code, new_closure: Dict):
92-
COPY_FREE_VARS = opmap["COPY_FREE_VARS"]
93-
LOAD_DEREF = opmap["LOAD_DEREF"]
94-
95-
# get the orig localplus that will be loaded from by the orig bytecode LOAD_DEREF arg_i
96-
localsplus, _localsplus_name_to_idx = get_localsplus_name_to_idx(code)
97-
98-
# closure vars go into co_freevars
99-
new_code = code.replace(co_freevars=tuple(new_closure.keys()))
100-
# closure is a tuple of cells
101-
closure = tuple(
102-
make_cell(v) if not isinstance(v, types.CellType) else v
103-
for v in new_closure.values()
104-
)
105-
106-
new_code = ConcreteBytecode.from_code(new_code)
107-
# update how many closure vars are loaded from frame
108-
# see https://github.com/python/cpython/blob/6078f2033ea15a16cf52fe8d644a95a3be72d2e3/Python/bytecodes.c#L1571
109-
assert new_code[0].opcode == COPY_FREE_VARS, f"{new_code[0].opcode=}"
110-
new_code[0].arg = len(closure)
111-
112-
# map orig localsplus arg_i to new localplus position/arg_i
113-
new_localsplus = new_code.varnames + new_code.cellvars + new_code.freevars
114-
new_localsplus_name_to_idx = {v: i for i, v in enumerate(new_localsplus)}
115-
for c in new_code:
116-
if c.opcode == LOAD_DEREF and c.arg < len(localsplus):
117-
c.arg = new_localsplus_name_to_idx[localsplus[c.arg]]
118-
new_code = new_code.to_code()
119-
120-
return new_code, closure
121-
122-
12388
def unpickle_mlir_type(v):
12489
return Type.parse(v)
12590

@@ -145,7 +110,13 @@ def copy_object(obj):
145110
# potentially more complete approach https://stackoverflow.com/a/56901529/9045206
146111
def copy_func(f, new_closure: Dict = None):
147112
if new_closure is not None:
148-
code, closure = replace_closure(f.__code__, new_closure)
113+
# closure vars go into co_freevars
114+
code = f.__code__.replace(co_freevars=tuple(new_closure.keys()))
115+
# closure is a tuple of cells
116+
closure = tuple(
117+
make_cell(v) if not isinstance(v, types.CellType) else v
118+
for v in new_closure.values()
119+
)
149120
else:
150121
closure = copy_object(f.__closure__)
151122
code = f.__code__

projects/eudsl-python-extras/mlir/extras/dialects/arith.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
from typing import Optional, Tuple, Union
1111

1212
import numpy as np
13-
from bytecode import ConcreteBytecode
1413

15-
from ..ast.canonicalize import StrictTransformer, Canonicalizer, BytecodePatcher
14+
from ..ast.canonicalize import StrictTransformer, Canonicalizer, FunctionPatcher
1615
from ..util import infer_mlir_type, mlir_type_to_np_dtype
1716
from ..._mlir_libs._mlir import register_value_caster
1817
from ...dialects import complex as complex_dialect
@@ -596,19 +595,19 @@ def visit_AugAssign(
596595
return updated_node
597596

598597

599-
class ArithPatchByteCode(BytecodePatcher):
600-
def patch_bytecode(self, code: ConcreteBytecode, f):
598+
class ArithPatchFunction(FunctionPatcher):
599+
def patch_function(self, f):
601600
# TODO(max): this is bad and should be in the closure rather than as a global
602601
from ...dialects import arith, math
603602

604603
f.__globals__["math_dialect"] = math
605604
f.__globals__["arith_dialect"] = arith
606-
return code
605+
return f
607606

608607

609608
class ArithCanonicalizer(Canonicalizer):
610609
cst_transformers = [CanonicalizeFMA]
611-
bytecode_patchers = [ArithPatchByteCode]
610+
function_patchers = [ArithPatchFunction]
612611

613612

614613
canonicalizer = ArithCanonicalizer()

projects/eudsl-python-extras/mlir/extras/dialects/scf.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,11 @@
77
from copy import deepcopy
88
from typing import List, Union, Optional, Sequence
99

10-
from bytecode import ConcreteBytecode
11-
1210
from .arith import constant as _ext_arith_constant, index_cast
13-
from ..ast.canonicalize import BytecodePatcher, Canonicalizer, StrictTransformer
11+
from ..ast.canonicalize import FunctionPatcher, Canonicalizer, StrictTransformer
1412
from ..ast.util import ast_call, set_lineno, append_hidden_node
1513
from ..meta import region_op
16-
from ..util import get_user_code_loc, region_adder
14+
from ..util import region_adder
1715
from ...dialects._ods_common import (
1816
_cext,
1917
get_default_loc_context,
@@ -567,14 +565,14 @@ def visit_If(self, updated_node: ast.If) -> Union[ast.With, List[ast.With]]:
567565
return then_with
568566

569567

570-
class RemoveJumpsAndInsertGlobals(BytecodePatcher):
571-
def patch_bytecode(self, code: ConcreteBytecode, f):
568+
class RemoveJumpsAndInsertGlobals(FunctionPatcher):
569+
def patch_function(self, f):
572570
# TODO(max): this is bad and should be in the closure rather than as a global
573571
f.__globals__[yield_.__name__] = yield_
574572
f.__globals__[if_ctx_manager.__name__] = if_ctx_manager
575573
f.__globals__[else_ctx_manager.__name__] = else_ctx_manager
576574
f.__globals__[placeholder_opaque_t.__name__] = placeholder_opaque_t
577-
return code
575+
return f
578576

579577

580578
class SCFCanonicalizer(Canonicalizer):
@@ -586,7 +584,7 @@ class SCFCanonicalizer(Canonicalizer):
586584
CanonicalizeWhile,
587585
]
588586

589-
bytecode_patchers = [RemoveJumpsAndInsertGlobals]
587+
function_patchers = [RemoveJumpsAndInsertGlobals]
590588

591589

592590
canonicalizer = SCFCanonicalizer()
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
PyYAML>=5.4.0
2-
bytecode>=0.17.0
32
cloudpickle>=3.0.0
43
numpy>=1.19.5, <=2.1.2

0 commit comments

Comments
 (0)