+ "source": "def log(*args):\n print(*args, file=sys.stderr)\n sys.stderr.flush()\n\n\ndef run(f):\n log(\"\\nTEST:\", f.__name__)\n f()\n gc.collect()\n\n\n@run\ndef testapis():\n with Context():\n module = Module.parse(\n dedent(\n \"\"\"\n module attributes {llvm.target_triple = \"wasm32-unknown-emscripten\"} {\n llvm.func @none(%arg0: i32) -> i32 {\n %0 = llvm.mlir.constant(333 : i32) : i32\n %t0 = llvm.add %arg0, %0 : i32\n llvm.return %t0 : i32\n }\n }\n \"\"\"\n )\n )\n wasm_ee = WasmExecutionEngine(module.operation)\n func = _mlirWasmExecutionEngine.get_symbol_address(\"none\")\n func = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int)(func)\n assert func(20) == 353\n\n\ndef lowerToLLVM(module):\n pm = PassManager.parse(\n \"builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm{index-bitwidth=32},convert-func-to-llvm{index-bitwidth=32},convert-arith-to-llvm{index-bitwidth=32},convert-cf-to-llvm{index-bitwidth=32},reconcile-unrealized-casts)\"\n )\n pm.run(module.operation)\n return module\n\n\n@run\ndef testMemrefAdd():\n with Context():\n module = Module.parse(\n dedent(\n \"\"\"\n module {\n func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n %2 = memref.load %arg1[] : memref<f32>\n %3 = arith.addf %1, %2 : f32\n memref.store %3, %arg2[%0] : memref<1xf32>\n return %3 : f32\n }\n func.func @main2(%arg0: memref<f32>) -> (f32) attributes { llvm.emit_c_interface } {\n %1 = memref.load %arg0[] : memref<f32>\n return %1 : f32\n }\n func.func @main3(%arg0: memref<1xf32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n return %1 : f32\n }\n func.func @main4(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n return %1 : f32\n }\n func.func @main5(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg2: memref<f32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n return %1 : f32\n }\n func.func @main6(%arg0: memref<1xf32>, %arg2: memref<1xf32>, %arg1: memref<f32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n %2 = memref.load %arg1[] : memref<f32>\n %3 = arith.addf %1, %2 : f32\n memref.store %3, %arg2[%0] : memref<1xf32>\n return %3 : f32\n }\n func.func @main7(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) -> (f32) attributes { llvm.emit_c_interface } {\n %0 = arith.constant 0 : index\n %1 = memref.load %arg0[%0] : memref<1xf32>\n %2 = memref.load %arg1[] : memref<f32>\n %3 = arith.addf %1, %2 : f32\n memref.store %3, %arg2[%0] : memref<1xf32>\n return %3 : f32\n }\n }\n \"\"\"\n )\n )\n\n module = lowerToLLVM(module)\n\n arg1 = np.array([32.5]).astype(np.float32)\n arg2 = np.array(6).astype(np.float32)\n res = np.array([0]).astype(np.float32)\n\n arg1_memref_ptr = ctypes.pointer(\n ctypes.pointer(get_ranked_memref_descriptor(arg1))\n )\n arg2_memref_ptr = ctypes.pointer(\n ctypes.pointer(get_ranked_memref_descriptor(arg2))\n )\n res_memref_ptr = ctypes.pointer(\n ctypes.pointer(get_ranked_memref_descriptor(res))\n )\n\n # print(module)\n wasm_ee = WasmExecutionEngine(module.operation, module_name=\"bar\")\n try:\n print(wasm_ee.lookup(\"main\"))\n except ValueError as e:\n assert e.args[0] == \"functions named `main` are not supported on wasm\"\n\n res_ = wasm_ee.invoke_with_return_type(\n \"_mlir_ciface_main\",\n [arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr],\n return_type=ctypes.c_float,\n )\n print(res_)\n # CHECK: [32.5] + 6.0 = [38.5]\n print(\"{0} + {1} = {2}\".format(arg1, arg2, res))\n\n ctp = as_ctype(arg2.dtype)\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main2\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n )(func)\n res_ = func(\n arg2.ctypes.data,\n arg2.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n )\n print(res_)\n\n ctp = as_ctype(arg2.dtype)\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main3\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n )(func)\n res_ = func(\n arg1.ctypes.data,\n arg1.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n )\n print(res_)\n\n size_of_void_p = ctypes.sizeof(ctypes.c_void_p)\n print(f\"The size of ctypes.c_void_p is: {size_of_void_p} bytes\")\n\n size_of_longlong = ctypes.sizeof(ctypes.c_longlong)\n print(f\"The size of ctypes.c_longlong is: {size_of_longlong} bytes\")\n\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main4\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n # arg1\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # res\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n )(func)\n res_ = func(\n # arg1\n arg1.ctypes.data,\n arg1.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # # # arg2\n # arg2.ctypes.data,\n # arg2.ctypes.data_as(ctypes.POINTER(ctp)),\n # 0,\n # res\n res.ctypes.data,\n res.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n )\n print(res_)\n\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main5\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n # arg1\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # res\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # arg2\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n )(func)\n res_ = func(\n # arg1\n arg1.ctypes.data,\n arg1.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # res\n res.ctypes.data,\n res.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # arg2\n arg2.ctypes.data,\n arg2.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n )\n print(res_)\n\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main6\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n # arg1\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # res\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # arg2\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n )(func)\n res_ = func(\n # arg1\n arg1.ctypes.data,\n arg1.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # res\n res.ctypes.data,\n res.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # arg2\n arg2.ctypes.data,\n arg2.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n )\n print(res_)\n\n func = _mlirWasmExecutionEngine.get_symbol_address(\"main7\")\n func = ctypes.CFUNCTYPE(\n ctypes.c_float,\n # arg1\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n # arg2\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n # res\n ctypes.c_long,\n ctypes.POINTER(ctp),\n ctypes.c_long,\n ctypes.c_long,\n ctypes.c_long,\n )(func)\n res_ = func(\n # arg1\n arg1.ctypes.data,\n arg1.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n # arg2\n arg2.ctypes.data,\n arg2.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n # res\n res.ctypes.data,\n res.ctypes.data_as(ctypes.POINTER(ctp)),\n 0,\n 1,\n 1,\n )\n print(res_)\n\n wasm_ee.invoke_with_return_type(\n \"_mlir_ciface_main7\",\n [arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr],\n return_type=ctypes.c_float,\n )\n # CHECK: [32.5] + 6.0 = [38.5]\n print(\"{0} + {1} = {2}\".format(arg1, arg2, res))\n assert res[0] == 38.5\n\n\n@run\ndef testSharedLibLoad():\n with Context():\n module = Module.parse(\n dedent(\n \"\"\"\n module {\n func.func @foo(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {\n %c0 = arith.constant 0 : index\n %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>\n call @myPrintMemrefShapeF32(%u_memref) : (memref<*xf32>) -> ()\n return\n }\n func.func private @myPrintMemrefShapeF32(memref<*xf32>) attributes { llvm.emit_c_interface }\n }\n \"\"\"\n )\n )\n\n arg0 = np.array([0.0]).astype(np.float32)\n arg0_memref_ptr = ctypes.pointer(\n ctypes.pointer(get_ranked_memref_descriptor(arg0))\n )\n\n execution_engine = WasmExecutionEngine(\n lowerToLLVM(module),\n opt_level=3,\n shared_libs=[\n str(Path(_mlir_libs.__file__).parent / \"lib32b_mlir_runner_utils.so\")\n ],\n )\n execution_engine.invoke(\"foo\", arg0_memref_ptr)\n # Unranked Memref base@ = 0 rank = 1 offset = 0 sizes = [0] strides = [0]",
0 commit comments