Skip to content

Commit d525119

Browse files
committed
wasm3: validate function type
1 parent 25c1061 commit d525119

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

test/utils/wasm3_engine.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
#include <test/utils/adler32.hpp>
88
#include <test/utils/wasm_engine.hpp>
9+
#include <algorithm>
910
#include <cassert>
1011
#include <cstring>
12+
#include <stdexcept>
1113

1214
namespace fizzy::test
1315
{
@@ -115,14 +117,30 @@ fizzy::bytes_view Wasm3Engine::get_memory() const
115117
}
116118

117119
std::optional<WasmEngine::FuncRef> Wasm3Engine::find_function(
118-
std::string_view name, std::string_view) const
120+
std::string_view name, std::string_view signature) const
119121
{
120122
IM3Function function;
121-
if (m3_FindFunction(&function, m_runtime, name.data()) == m3Err_none)
122-
// TODO: validate input/output types
123-
// (m3_GetArgCount/m3_GetArgType/m3_GetRetCount/m3_GetRetType)
124-
return reinterpret_cast<WasmEngine::FuncRef>(function);
125-
return std::nullopt;
123+
if (m3_FindFunction(&function, m_runtime, name.data()) != m3Err_none)
124+
return std::nullopt;
125+
126+
std::vector<M3ValueType> inputs;
127+
std::vector<M3ValueType> outputs;
128+
std::tie(inputs, outputs) = translate_function_signature<M3ValueType, M3ValueType::c_m3Type_i32,
129+
M3ValueType::c_m3Type_i64>(signature);
130+
131+
if (inputs.size() != m3_GetArgCount(function))
132+
return std::nullopt;
133+
for (unsigned i = 0; i < m3_GetArgCount(function); i++)
134+
if (inputs[i] != m3_GetArgType(function, i))
135+
return std::nullopt;
136+
137+
if (outputs.size() != m3_GetRetCount(function))
138+
return std::nullopt;
139+
for (unsigned i = 0; i < m3_GetRetCount(function); i++)
140+
if (outputs[i] != m3_GetRetType(function, i))
141+
return std::nullopt;
142+
143+
return reinterpret_cast<WasmEngine::FuncRef>(function);
126144
}
127145

128146
WasmEngine::Result Wasm3Engine::execute(
@@ -137,7 +155,7 @@ WasmEngine::Result Wasm3Engine::execute(
137155

138156
// This ensures input count/type matches. For the return value we assume find_function did the
139157
// validation.
140-
if (m3_Call(function, static_cast<uint32_t>(args.size()), argPtrs.data()) == m3Err_none)
158+
if (m3_Call(function, static_cast<uint32_t>(argPtrs.size()), argPtrs.data()) == m3Err_none)
141159
{
142160
if (m3_GetRetCount(function) == 0)
143161
return {false, std::nullopt};

0 commit comments

Comments
 (0)