Skip to content

Commit f7f4fd9

Browse files
committed
Use a tokio::task::LocalSet.
1 parent f66b357 commit f7f4fd9

File tree

2 files changed

+58
-42
lines changed

2 files changed

+58
-42
lines changed

src/worker/listener.rs

Lines changed: 56 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ struct ActionInput<T> {
5151
#[allow(clippy::too_many_arguments)]
5252
async fn handle_start_step_run(
5353
action_function_task_join_set: &mut tokio::task::JoinSet<crate::InternalResult<()>>,
54+
local_set: &tokio::task::LocalSet,
5455
abort_handles: &mut HashMap<String, tokio::task::AbortHandle>,
5556
dispatcher: &mut DispatcherClient<
5657
tonic::service::interceptor::InterceptedService<
@@ -103,43 +104,56 @@ async fn handle_start_step_run(
103104

104105
let worker_id = worker_id.to_string();
105106
let step_run_id = action.step_run_id.clone();
106-
let abort_handle = action_function_task_join_set.spawn_local(async move {
107-
let context = Context::new(
108-
input.parents,
109-
workflow_run_id,
110-
workflow_step_run_id,
111-
workflow_service_client,
112-
data,
113-
);
114-
let action_event = match action_callable(context, input.input).catch_unwind().await {
115-
Ok(Ok(output_value)) => step_action_event(
116-
&worker_id,
117-
&action,
118-
StepActionEventType::StepEventTypeCompleted,
119-
serde_json::to_string(&output_value).expect("must succeed"),
120-
),
121-
Ok(Err(error)) => step_action_event(
122-
&worker_id,
123-
&action,
124-
StepActionEventType::StepEventTypeFailed,
125-
error.to_string(),
126-
),
127-
Err(_) => step_action_event(
128-
&worker_id,
129-
&action,
130-
StepActionEventType::StepEventTypeFailed,
131-
"action panicked".to_owned(),
132-
),
133-
};
134-
135-
dispatcher
136-
.send_step_action_event(action_event)
137-
.await
138-
.map_err(crate::InternalError::CouldNotSendStepStatus)?
139-
.into_inner();
140-
141-
Ok(())
142-
});
107+
let abort_handle = action_function_task_join_set.spawn_local_on(
108+
async move {
109+
let context = Context::new(
110+
input.parents,
111+
workflow_run_id,
112+
workflow_step_run_id,
113+
workflow_service_client,
114+
data,
115+
);
116+
let action_event = match action_callable(context, input.input).catch_unwind().await {
117+
Ok(Ok(output_value)) => step_action_event(
118+
&worker_id,
119+
&action,
120+
StepActionEventType::StepEventTypeCompleted,
121+
serde_json::to_string(&output_value).expect("must succeed"),
122+
),
123+
Ok(Err(error)) => step_action_event(
124+
&worker_id,
125+
&action,
126+
StepActionEventType::StepEventTypeFailed,
127+
error.to_string(),
128+
),
129+
Err(error) => {
130+
let message = error
131+
.downcast_ref::<&str>()
132+
.map(|value| value.to_string())
133+
.or_else(|| error.downcast_ref::<String>().map(|value| value.clone()));
134+
step_action_event(
135+
&worker_id,
136+
&action,
137+
StepActionEventType::StepEventTypeFailed,
138+
message.unwrap_or_else(|| {
139+
String::from(
140+
"task panicked with a payload that was not a `&str` nor a `String`",
141+
)
142+
}),
143+
)
144+
}
145+
};
146+
147+
dispatcher
148+
.send_step_action_event(action_event)
149+
.await
150+
.map_err(crate::InternalError::CouldNotSendStepStatus)?
151+
.into_inner();
152+
153+
Ok(())
154+
},
155+
local_set,
156+
);
143157
abort_handles.insert(step_run_id, abort_handle);
144158

145159
Ok(())
@@ -182,14 +196,15 @@ pub(crate) async fn run(
182196
listener_v2_timeout: Option<u64>,
183197
mut interrupt_receiver: tokio::sync::mpsc::Receiver<()>,
184198
data: Arc<DataMap>,
185-
) -> crate::InternalResult<()> {
199+
) -> crate::InternalResult<tokio::task::LocalSet> {
186200
use futures_util::StreamExt;
187201

188202
let mut retries: usize = 0;
189203
let mut listen_strategy = ListenStrategy::V2;
190204

191205
let connection_attempt = tokio::time::Instant::now();
192206

207+
let local_set = tokio::task::LocalSet::new();
193208
let mut abort_handles = HashMap::new();
194209

195210
'main_loop: loop {
@@ -252,7 +267,7 @@ pub(crate) async fn run(
252267
let action = match result {
253268
Err(status) => match status.code() {
254269
tonic::Code::Cancelled => {
255-
return Ok(());
270+
return Ok(local_set);
256271
}
257272
tonic::Code::DeadlineExceeded => {
258273
continue 'main_loop;
@@ -279,7 +294,7 @@ pub(crate) async fn run(
279294

280295
match action_type {
281296
ActionType::StartStepRun => {
282-
handle_start_step_run(action_function_task_join_set, &mut abort_handles, &mut dispatcher, workflow_service_client.clone(), namespace, worker_id, &workflows, action, data.clone()).await?;
297+
handle_start_step_run(action_function_task_join_set, &local_set, &mut abort_handles, &mut dispatcher, workflow_service_client.clone(), namespace, worker_id, &workflows, action, data.clone()).await?;
283298
}
284299
ActionType::CancelStepRun => {
285300
handle_cancel_step_run(&mut abort_handles, action).await?;
@@ -298,5 +313,5 @@ pub(crate) async fn run(
298313
}
299314
}
300315

301-
Ok(())
316+
Ok(local_set)
302317
}

src/worker/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,12 +326,13 @@ impl Worker<'_> {
326326
let mut action_function_task_join_set =
327327
tokio::task::JoinSet::<crate::InternalResult<()>>::new();
328328

329-
futures_util::try_join! {
329+
let (_, local_set) = futures_util::try_join! {
330330
heartbeat::run(dispatcher.clone(), &worker_id, heartbeat_interrupt_receiver),
331331
listener::run(&mut action_function_task_join_set, dispatcher, workflow_service_client, namespace, &worker_id, workflows, *listener_v2_timeout, listening_interrupt_receiver, data)
332332
}?;
333333

334334
action_function_task_join_set.shutdown().await;
335+
local_set.await;
335336

336337
Ok(())
337338
}

0 commit comments

Comments
 (0)