diff --git a/src/libmodelbox/base/include/modelbox/base/executor.h b/src/libmodelbox/base/include/modelbox/base/executor.h index 84f27253f..b0eaaddea 100644 --- a/src/libmodelbox/base/include/modelbox/base/executor.h +++ b/src/libmodelbox/base/include/modelbox/base/executor.h @@ -33,6 +33,8 @@ class Executor { virtual ~Executor(); + void SetThreadCount(int thread_count); + template auto Run(func &&fun, int32_t priority, ts &&...params) -> std::future::type> { diff --git a/src/libmodelbox/engine/flowunit_data_executor.cc b/src/libmodelbox/engine/flowunit_data_executor.cc index e6633d3e4..2eeed9d0a 100644 --- a/src/libmodelbox/engine/flowunit_data_executor.cc +++ b/src/libmodelbox/engine/flowunit_data_executor.cc @@ -34,6 +34,10 @@ Executor::Executor(int thread_count) { Executor::~Executor() { thread_pool_ = nullptr; } +void Executor::SetThreadCount(int thread_count){ + thread_pool_->SetThreadSize(thread_count); +} + FlowUnitExecContext::FlowUnitExecContext( std::shared_ptr data_ctx) : data_ctx_(std::move(data_ctx)) {} diff --git a/src/libmodelbox/engine/flowunit_group.cc b/src/libmodelbox/engine/flowunit_group.cc index 5995bcc4f..b45b9d48e 100644 --- a/src/libmodelbox/engine/flowunit_group.cc +++ b/src/libmodelbox/engine/flowunit_group.cc @@ -46,6 +46,11 @@ void FlowUnitGroup::InitTrace() { } } +uint32_t FlowUnitGroup::GetBatchSize() const +{ + return batch_size_; +} + std::shared_ptr FlowUnitGroup::StartTrace( FUExecContextList &exec_ctx_list) { std::call_once(trace_init_flag_, &FlowUnitGroup::InitTrace, this); diff --git a/src/libmodelbox/engine/flowunit_manager.cc b/src/libmodelbox/engine/flowunit_manager.cc index a72786ef9..24d09a27d 100644 --- a/src/libmodelbox/engine/flowunit_manager.cc +++ b/src/libmodelbox/engine/flowunit_manager.cc @@ -59,6 +59,13 @@ Status FlowUnitManager::Initialize( SetDeviceManager(std::move(device_mgr)); Status status; status = InitFlowUnitFactory(driver); + + if (config != nullptr){ + max_executor_thread_num = config->GetUint32("graph.max_executor_thread_num", 0); + } else { + max_executor_thread_num = 0; + } + if (status != STATUS_SUCCESS) { return status; } @@ -407,6 +414,9 @@ std::shared_ptr FlowUnitManager::CreateSingleFlowUnit( return nullptr; } + MBLOG_INFO << "max_executor_thread_num: " << max_executor_thread_num; + device->GetDeviceExecutor()->SetThreadCount(max_executor_thread_num); + flowunit->SetBindDevice(device); std::vector &in_list = flowunit_desc->GetFlowUnitInput(); for (auto &in_item : in_list) { diff --git a/src/libmodelbox/engine/node.cc b/src/libmodelbox/engine/node.cc index 9d7b52d56..52ffc2e9f 100644 --- a/src/libmodelbox/engine/node.cc +++ b/src/libmodelbox/engine/node.cc @@ -762,30 +762,48 @@ void Node::CleanDataContext() { } Status Node::Run(RunType type) { + std::list> data_ctx_list; + size_t process_count = 0; auto ret = Recv(type, data_ctx_list); - if (!ret) { - return ret; - } - ret = Process(data_ctx_list); if (!ret) { return ret; } - if (!GetOutputNames().empty()) { - ret = Send(data_ctx_list); + std::list> process_ctx_list; + + for(auto& ctx: data_ctx_list){ + + process_count++; + process_ctx_list.push_back(ctx); + + if (process_ctx_list.size() < flowunit_group_->GetBatchSize()){ + if (process_count < data_ctx_list.size()){ + continue; + } + } + + ret = Process(process_ctx_list); if (!ret) { return ret; } - } else { - SetLastError(data_ctx_list); + + if (!GetOutputNames().empty()) { + ret = Send(process_ctx_list); + if (!ret) { + return ret; + } + } else { + SetLastError(process_ctx_list); + } + + process_ctx_list.clear(); } Clean(data_ctx_list); return STATUS_SUCCESS; } - void Node::SetLastError( std::list>& data_ctx_list) { for (auto& data_ctx : data_ctx_list) { diff --git a/src/libmodelbox/include/modelbox/flowunit.h b/src/libmodelbox/include/modelbox/flowunit.h index bb8a8726f..846cfa2f7 100644 --- a/src/libmodelbox/include/modelbox/flowunit.h +++ b/src/libmodelbox/include/modelbox/flowunit.h @@ -612,6 +612,8 @@ class FlowUnitManager { std::shared_ptr GetDeviceManager(); + int max_executor_thread_num; + private: Status CheckParams(const std::string &unit_name, const std::string &unit_type, const std::string &unit_device_id); diff --git a/src/libmodelbox/include/modelbox/flowunit_group.h b/src/libmodelbox/include/modelbox/flowunit_group.h index 341dfbc29..1bf1ba2c2 100644 --- a/src/libmodelbox/include/modelbox/flowunit_group.h +++ b/src/libmodelbox/include/modelbox/flowunit_group.h @@ -64,6 +64,8 @@ class FlowUnitGroup { Status Close(); + uint32_t GetBatchSize() const; + private: std::weak_ptr node_; uint32_t batch_size_;