66
77#include < test/utils/adler32.hpp>
88#include < test/utils/wasm_engine.hpp>
9+ #include < algorithm>
910#include < cassert>
1011#include < cstring>
12+ #include < stdexcept>
1113
1214namespace fizzy ::test
1315{
@@ -115,14 +117,30 @@ fizzy::bytes_view Wasm3Engine::get_memory() const
115117}
116118
117119std::optional<WasmEngine::FuncRef> Wasm3Engine::find_function (
118- std::string_view name, std::string_view) const
120+ std::string_view name, std::string_view signature ) const
119121{
120122 IM3Function function;
121- if (m3_FindFunction (&function, m_runtime, name.data ()) == m3Err_none)
122- // TODO: validate input/output types
123- // (m3_GetArgCount/m3_GetArgType/m3_GetRetCount/m3_GetRetType)
124- return reinterpret_cast <WasmEngine::FuncRef>(function);
125- return std::nullopt ;
123+ if (m3_FindFunction (&function, m_runtime, name.data ()) != m3Err_none)
124+ return std::nullopt ;
125+
126+ std::vector<M3ValueType> inputs;
127+ std::vector<M3ValueType> outputs;
128+ std::tie (inputs, outputs) = translate_function_signature<M3ValueType, M3ValueType::c_m3Type_i32,
129+ M3ValueType::c_m3Type_i64>(signature);
130+
131+ if (inputs.size () != m3_GetArgCount (function))
132+ return std::nullopt ;
133+ for (unsigned i = 0 ; i < m3_GetArgCount (function); i++)
134+ if (inputs[i] != m3_GetArgType (function, i))
135+ return std::nullopt ;
136+
137+ if (outputs.size () != m3_GetRetCount (function))
138+ return std::nullopt ;
139+ for (unsigned i = 0 ; i < m3_GetRetCount (function); i++)
140+ if (outputs[i] != m3_GetRetType (function, i))
141+ return std::nullopt ;
142+
143+ return reinterpret_cast <WasmEngine::FuncRef>(function);
126144}
127145
128146WasmEngine::Result Wasm3Engine::execute (
@@ -137,7 +155,7 @@ WasmEngine::Result Wasm3Engine::execute(
137155
138156 // This ensures input count/type matches. For the return value we assume find_function did the
139157 // validation.
140- if (m3_Call (function, static_cast <uint32_t >(args .size ()), argPtrs.data ()) == m3Err_none)
158+ if (m3_Call (function, static_cast <uint32_t >(argPtrs .size ()), argPtrs.data ()) == m3Err_none)
141159 {
142160 if (m3_GetRetCount (function) == 0 )
143161 return {false , std::nullopt };
0 commit comments