Skip to content

Commit 6fddf2b

Browse files
authored
[mlir-python-bindings-wasm] add runtime example (#274)
1 parent e4a4b60 commit 6fddf2b

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ Currently, there are five components:
3535

3636
We currently provide two online playgrounds where you can try out the WebAssembly version of mlir-python-bindings and eudsl-python-extras directly in your browser:
3737

38-
* https://llvm.github.io/eudsl/jupyter/ – A JupyterLite instance with a Pyodide kernel. You can install the MLIR Python bindings with: `await piplite.install("mlir-python-bindings")`.
38+
* [https://llvm.github.io/eudsl/jupyter](https://llvm.github.io/eudsl/jupyter/lab/index.html?path=mlir-python-starter.ipynb) – A JupyterLite instance with a Pyodide kernel. You can install the MLIR Python bindings with: `await piplite.install("mlir-python-bindings")`.
3939

40-
* https://llvm.github.io/eudsl/console/ – A Pyodide-based REPL with `mlir-python-bindings` and `eudsl-python-extras` preloaded. Just run: `from mlir.ir import *` to start coding.
40+
* https://llvm.github.io/eudsl/console – A Pyodide-based REPL with `mlir-python-bindings` and `eudsl-python-extras` preloaded. Just run: `from mlir.ir import *` to start coding.
4141

4242
## Getting started
4343

pages/jupyter/contents/mlir-python-starter.ipynb

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,27 @@
2424
{
2525
"id": "37e7b57b-06eb-4249-a88b-dfbb7dba510e",
2626
"cell_type": "code",
27-
"source": "%%capture\n\nimport piplite\nawait piplite.install('mlir-python-bindings')",
27+
"source": "%%capture\n\nimport piplite\nawait piplite.install('mlir-python-bindings')\nawait piplite.install('numpy')",
2828
"metadata": {
2929
"trusted": true
3030
},
3131
"outputs": [],
3232
"execution_count": 1
3333
},
34+
{
35+
"id": "bd12373c-7e45-4df1-a92b-a4260e2129d2",
36+
"cell_type": "code",
37+
"source": "import gc\nimport sys\nfrom pathlib import Path\nfrom textwrap import dedent\nimport ctypes\nimport numpy as np\n\nfrom mlir import _mlir_libs\n\nfrom mlir.wasm_execution_engine import (\n _mlirWasmExecutionEngine,\n WasmExecutionEngine,\n)\nfrom mlir.runtime.np_to_memref import get_ranked_memref_descriptor, as_ctype\nfrom mlir.ir import *\nfrom mlir.passmanager import *\nfrom mlir.dialects import func, arith",
38+
"metadata": {
39+
"trusted": true
40+
},
41+
"outputs": [],
42+
"execution_count": 4
43+
},
3444
{
3545
"id": "403e67d0-cda2-41ae-82e2-b634fa9c6739",
3646
"cell_type": "code",
37-
"source": "from mlir.ir import *\nfrom mlir.dialects import func, arith\n\nwith Context(), Location.unknown():\n f32 = F32Type.get()\n\n module = Module.create()\n with InsertionPoint(module.body):\n @func.func(f32, f32)\n def add(x, y):\n return arith.addf(x, y)\n\nprint(module)",
47+
"source": "with Context(), Location.unknown():\n f32 = F32Type.get()\n\n module = Module.create()\n with InsertionPoint(module.body):\n @func.func(f32, f32)\n def add(x, y):\n return arith.addf(x, y)\n\nprint(module)",
3848
"metadata": {
3949
"trusted": true
4050
},
@@ -45,11 +55,37 @@
4555
"text": "module {\n func.func @add(%arg0: f32, %arg1: f32) -> f32 {\n %0 = arith.addf %arg0, %arg1 : f32\n return %0 : f32\n }\n}\n\n"
4656
}
4757
],
48-
"execution_count": 2
58+
"execution_count": 5
4959
},
5060
{
5161
"id": "a40a3cfb-1990-41e0-b9d3-1f5dc669949a",
5262
"cell_type": "code",
63+
"source": "def log(*args):\n print(*args, file=sys.stderr)\n sys.stderr.flush()\n\n\ndef run(f):\n log(\"\\nTEST:\", f.__name__)\n f()\n gc.collect()\n\n\n@run\ndef testapis():\n with Context():\n module = Module.parse(\n dedent(\n \"\"\"\n module attributes {llvm.target_triple = \"wasm32-unknown-emscripten\"} {\n llvm.func @none(%arg0: i32) -> i32 {\n %0 = llvm.mlir.constant(333 : i32) : i32\n %t0 = llvm.add %arg0, %0 : i32\n llvm.return %t0 : i32\n }\n }\n \"\"\"\n )\n )\n wasm_ee = WasmExecutionEngine(module.operation)\n func = _mlirWasmExecutionEngine.get_symbol_address(\"none\")\n func = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int)(func)\n assert func(20) == 353\n\n\ndef lowerToLLVM(module):\n pm = PassManager.parse(\n \"builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm{index-bitwidth=32},convert-func-to-llvm{index-bitwidth=32},convert-arith-to-llvm{index-bitwidth=32},convert-cf-to-llvm{index-bitwidth=32},reconcile-unrealized-casts)\"\n )\n pm.run(module.operation)\n return module\n\n\n@run\ndef testMemrefAdd():\n with Context():\n module = Module.parse(\n dedent(\n \"\"\"\n module {\n func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n %2 = memref.load %arg1[] : memref<f32>\n %3 = arith.addf %1, %2 : f32\n memref.store %3, %arg2[%0] : memref<1xf32>\n return %3 : f32\n }\n func.func @main2(%arg0: memref<f32>) -> (f32) attributes { llvm.emit_c_interface } {\n %1 = memref.load %arg0[] : memref<f32>\n return %1 : f32\n }\n func.func @main3(%arg0: memref<1xf32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n return %1 : f32\n }\n func.func @main4(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n return %1 : f32\n }\n func.func @main5(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg2: memref<f32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n return %1 : f32\n }\n func.func @main6(%arg0: memref<1xf32>, %arg2: memref<1xf32>, %arg1: memref<f32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n %2 = memref.load %arg1[] : memref<f32>\n %3 = arith.addf %1, %2 : f32\n memref.store %3, %arg2[%0] : memref<1xf32>\n return %3 : f32\n }\n func.func @main7(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n %2 = memref.load %arg1[] : memref<f32>\n %3 = arith.addf %1, %2 : f32\n memref.store %3, %arg2[%0] : memref<1xf32>\n return %3 : f32\n }\n }\n \"\"\"\n )\n )\n\n module = lowerToLLVM(module)\n\n arg1 = np.array([32.5]).astype(np.float32)\n arg2 = np.array(6).astype(np.float32)\n res = np.array([0]).astype(np.float32)\n\n arg1_memref_ptr = ctypes.pointer(\n ctypes.pointer(get_ranked_memref_descriptor(arg1))\n )\n arg2_memref_ptr = ctypes.pointer(\n ctypes.pointer(get_ranked_memref_descriptor(arg2))\n )\n res_memref_ptr = ctypes.pointer(\n ctypes.pointer(get_ranked_memref_descriptor(res))\n )\n\n # print(module)\n wasm_ee = WasmExecutionEngine(module.operation, module_name=\"bar\")\n try:\n print(wasm_ee.lookup(\"main\"))\n except ValueError as e:\n assert e.args[0] == \"functions named `main` are not supported on wasm\"\n\n res_ = wasm_ee.invoke_with_return_type(\n \"_mlir_ciface_main\",\n [arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr],\n return_type=ctypes.c_float,\n )\n print(res_)\n # CHECK: [32.5] + 6.0 = [38.5]\n print(\"{0} + {1} = {2}\".format(arg1, arg2, res))\n\n ctp = as_ctype(arg2.dtype)\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main2\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n )(func)\n res_ = func(\n arg2.ctypes.data,\n arg2.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n )\n print(res_)\n\n ctp = as_ctype(arg2.dtype)\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main3\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n )(func)\n res_ = func(\n arg1.ctypes.data,\n arg1.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n )\n print(res_)\n\n size_of_void_p = ctypes.sizeof(ctypes.c_void_p)\n print(f\"The size of ctypes.c_void_p is: {size_of_void_p} bytes\")\n\n size_of_longlong = ctypes.sizeof(ctypes.c_longlong)\n print(f\"The size of ctypes.c_longlong is: {size_of_longlong} bytes\")\n\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main4\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n # arg1\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # res\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n )(func)\n res_ = func(\n # arg1\n arg1.ctypes.data,\n arg1.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # # # arg2\n # arg2.ctypes.data,\n # arg2.ctypes.data_as(ctypes.POINTER(ctp)),\n # 0,\n # res\n res.ctypes.data,\n res.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n )\n print(res_)\n\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main5\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n # arg1\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # res\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # arg2\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n )(func)\n res_ = func(\n # arg1\n arg1.ctypes.data,\n arg1.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # res\n res.ctypes.data,\n res.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # arg2\n arg2.ctypes.data,\n arg2.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n )\n print(res_)\n\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main6\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n # arg1\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # res\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # arg2\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n )(func)\n res_ = func(\n # arg1\n arg1.ctypes.data,\n arg1.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # res\n res.ctypes.data,\n res.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # arg2\n arg2.ctypes.data,\n arg2.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n )\n print(res_)\n\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main7\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n # arg1\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # arg2\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n # res\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n )(func)\n res_ = func(\n # arg1\n arg1.ctypes.data,\n arg1.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # arg2\n arg2.ctypes.data,\n arg2.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n # res\n res.ctypes.data,\n res.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n )\n print(res_)\n\n wasm_ee.invoke_with_return_type(\n \"_mlir_ciface_main7\",\n [arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr],\n return_type=ctypes.c_float,\n )\n # CHECK: [32.5] + 6.0 = [38.5]\n print(\"{0} + {1} = {2}\".format(arg1, arg2, res))\n assert res[0] == 38.5\n\n\n@run\ndef testSharedLibLoad():\n with Context():\n module = Module.parse(\n dedent(\n \"\"\"\n module {\n func.func @foo(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {\n %c0 = arith.constant 0 : index\n %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>\n call @myPrintMemrefShapeF32(%u_memref) : (memref<*xf32>) -> ()\n return\n }\n func.func private @myPrintMemrefShapeF32(memref<*xf32>) attributes { llvm.emit_c_interface }\n }\n \"\"\"\n )\n )\n\n arg0 = np.array([0.0]).astype(np.float32)\n arg0_memref_ptr = ctypes.pointer(\n ctypes.pointer(get_ranked_memref_descriptor(arg0))\n )\n\n execution_engine = WasmExecutionEngine(\n lowerToLLVM(module),\n opt_level=3,\n shared_libs=[\n str(Path(_mlir_libs.__file__).parent / \"lib32b_mlir_runner_utils.so\")\n ],\n )\n execution_engine.invoke(\"foo\", arg0_memref_ptr)\n # Unranked Memref base@ = 0 rank = 1 offset = 0 sizes = [0] strides = [0]",
64+
"metadata": {
65+
"trusted": true
66+
},
67+
"outputs": [
68+
{
69+
"name": "stderr",
70+
"output_type": "stream",
71+
"text": "\nTEST: testapis\n\nTEST: testMemrefAdd\n"
72+
},
73+
{
74+
"name": "stdout",
75+
"output_type": "stream",
76+
"text": "4.749539006749374e-34\n[32.5] + 6.0 = [0.]\n6.0\n32.5\nThe size of ctypes.c_void_p is: 4 bytes\nThe size of ctypes.c_longlong is: 8 bytes\n32.5\n32.5\n38.5\n38.5\n[32.5] + 6.0 = [38.5]\n"
77+
},
78+
{
79+
"name": "stderr",
80+
"output_type": "stream",
81+
"text": "\nTEST: testSharedLibLoad\n"
82+
}
83+
],
84+
"execution_count": 6
85+
},
86+
{
87+
"id": "293439ff-eb60-4260-a1fb-340eb568e9b6",
88+
"cell_type": "code",
5389
"source": "",
5490
"metadata": {
5591
"trusted": true

0 commit comments

Comments
 (0)