Skip to content

Commit 7daa19b

Browse files
committed
Add support for attaching data to the context.
1 parent 8e5b486 commit 7daa19b

File tree

5 files changed

+99
-11
lines changed

5 files changed

+99
-11
lines changed

examples/data_in_context.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder};
2+
3+
#[derive(serde::Deserialize)]
4+
struct Input {}
5+
6+
async fn execute(context: Context, _input: Input) -> anyhow::Result<()> {
7+
assert_eq!(context.datum::<Datum>().number, 10);
8+
Ok(())
9+
}
10+
11+
struct Datum {
12+
number: usize,
13+
}
14+
15+
#[tokio::main]
16+
async fn main() -> anyhow::Result<()> {
17+
dotenv::dotenv().ok();
18+
tracing_subscriber::fmt()
19+
.with_target(false)
20+
.with_env_filter(
21+
tracing_subscriber::EnvFilter::from_default_env()
22+
.add_directive("hatchet_sdk=debug".parse()?),
23+
)
24+
.init();
25+
26+
let client = Client::new()?;
27+
let mut worker = client
28+
.worker("example_data_in_context")
29+
.datum(Datum { number: 10 })
30+
.build();
31+
worker.register_workflow(
32+
WorkflowBuilder::default()
33+
.name("example_data_in_context")
34+
.step(
35+
StepBuilder::default()
36+
.name("compute")
37+
.function(&execute)
38+
.build()?,
39+
)
40+
.build()?,
41+
);
42+
worker.start().await?;
43+
Ok(())
44+
}

examples/fibonacci.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder};
1+
use hatchet_sdk::{Client, StepBuilder, WorkflowBuilder};
22

33
fn fibonacci(n: u32) -> u32 {
44
(1..=n)

src/step_function.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1+
use std::{
2+
any::{Any, TypeId},
3+
collections::HashMap,
4+
sync::Arc,
5+
};
6+
17
use futures_util::lock::Mutex;
28
use tracing::info;
39

410
use crate::worker::{grpc, ServiceWithAuthorization};
511

12+
pub(crate) type DataMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;
13+
614
pub struct Context {
715
workflow_run_id: String,
816
workflow_step_run_id: String,
@@ -15,6 +23,7 @@ pub struct Context {
1523
>,
1624
u16,
1725
)>,
26+
data: Arc<DataMap>,
1827
}
1928

2029
impl Context {
@@ -27,14 +36,24 @@ impl Context {
2736
ServiceWithAuthorization,
2837
>,
2938
>,
39+
data: Arc<DataMap>,
3040
) -> Self {
3141
Self {
3242
workflow_run_id,
3343
workflow_service_client_and_spawn_index: Mutex::new((workflow_service_client, 0)),
3444
workflow_step_run_id,
45+
data,
3546
}
3647
}
3748

49+
pub fn datum<D: std::any::Any + Send + Sync>(&self) -> &D {
50+
let type_id = TypeId::of::<D>();
51+
self.data
52+
.get(&type_id)
53+
.and_then(|value| value.downcast_ref())
54+
.unwrap_or_else(|| panic!("could not find an attached datum of the type: {type_id:?}"))
55+
}
56+
3857
pub async fn trigger_workflow<I: serde::Serialize>(
3958
&self,
4059
workflow_name: &str,

src/worker/listener.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use futures_util::FutureExt;
24
use tokio::task::LocalSet;
35
use tonic::IntoRequest;
@@ -14,7 +16,7 @@ use super::{
1416
dispatcher_client::DispatcherClient, AssignedAction, StepActionEvent, StepActionEventType,
1517
WorkerListenRequest,
1618
},
17-
ListenStrategy, ServiceWithAuthorization,
19+
DataMap, ListenStrategy, ServiceWithAuthorization,
1820
};
1921

2022
const DEFAULT_ACTION_LISTENER_RETRY_INTERVAL: std::time::Duration =
@@ -63,6 +65,7 @@ async fn handle_start_step_run(
6365
worker_id: &str,
6466
workflows: &[Workflow],
6567
action: AssignedAction,
68+
data: Arc<DataMap>,
6669
) -> crate::InternalResult<()> {
6770
let Some(action_callable) = workflows
6871
.iter()
@@ -101,6 +104,7 @@ async fn handle_start_step_run(
101104
workflow_run_id,
102105
workflow_step_run_id,
103106
workflow_service_client,
107+
data,
104108
);
105109
action_callable(context, input.input).await
106110
})
@@ -155,7 +159,7 @@ pub(crate) async fn run(
155159
workflows: Vec<Workflow>,
156160
listener_v2_timeout: Option<u64>,
157161
mut interrupt_receiver: tokio::sync::mpsc::Receiver<()>,
158-
_heartbeat_interrupt_sender: tokio::sync::mpsc::Sender<()>,
162+
data: Arc<DataMap>,
159163
) -> crate::InternalResult<()> {
160164
use futures_util::StreamExt;
161165

@@ -253,7 +257,7 @@ pub(crate) async fn run(
253257

254258
match action_type {
255259
ActionType::StartStepRun => {
256-
handle_start_step_run(&mut dispatcher, workflow_service_client.clone(), &local_set, namespace, worker_id, &workflows, action).await?;
260+
handle_start_step_run(&mut dispatcher, workflow_service_client.clone(), &local_set, namespace, worker_id, &workflows, action, data.clone()).await?;
257261
}
258262
ActionType::CancelStepRun => {
259263
todo!()

src/worker/mod.rs

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
mod heartbeat;
22
mod listener;
33

4+
use std::sync::Arc;
5+
46
use grpc::{
57
CreateWorkflowJobOpts, CreateWorkflowStepOpts, CreateWorkflowVersionOpts, PutWorkflowRequest,
68
WorkerRegisterRequest, WorkerRegisterResponse, WorkflowKind,
79
};
810
use tonic::transport::Certificate;
911
use tracing::info;
1012

11-
use crate::{client::Environment, ClientTlStrategy, Workflow};
13+
use crate::{client::Environment, step_function::DataMap, ClientTlStrategy, Workflow};
1214

1315
#[derive(Clone)]
1416
pub(crate) struct ServiceWithAuthorization {
@@ -55,9 +57,18 @@ pub struct Worker<'a> {
5557
environment: &'a super::client::Environment,
5658
#[builder(default, setter(skip))]
5759
workflows: Vec<Workflow>,
60+
#[builder(default, setter(custom))]
61+
data: DataMap,
5862
}
5963

6064
impl<'a> WorkerBuilder<'a> {
65+
pub fn datum<D: std::any::Any + Send + Sync>(mut self, datum: D) -> Self {
66+
self.data
67+
.get_or_insert_default()
68+
.insert(std::any::TypeId::of::<D>(), Box::new(datum));
69+
self
70+
}
71+
6172
pub fn build(self) -> Worker<'a> {
6273
self.build_private().expect("must succeed")
6374
}
@@ -172,7 +183,7 @@ impl<'a> Worker<'a> {
172183
let (listening_interrupt_sender1, listening_interrupt_receiver) =
173184
tokio::sync::mpsc::channel(1);
174185
let _listening_interrupt_sender2 = listening_interrupt_sender1.clone();
175-
let heartbeat_interrupt_sender2 = heartbeat_interrupt_sender1.clone();
186+
let _heartbeat_interrupt_sender2 = heartbeat_interrupt_sender1.clone();
176187

177188
tokio::spawn(async move {
178189
tokio::signal::ctrl_c().await.unwrap();
@@ -182,6 +193,14 @@ impl<'a> Worker<'a> {
182193
let _ = listening_interrupt_sender1.send(()).await;
183194
});
184195

196+
let Self {
197+
workflows,
198+
data,
199+
name,
200+
max_runs,
201+
environment,
202+
} = self;
203+
185204
let Environment {
186205
token,
187206
host_port,
@@ -195,7 +214,7 @@ impl<'a> Worker<'a> {
195214
tls_server_name,
196215
namespace,
197216
listener_v2_timeout,
198-
} = self.environment;
217+
} = environment;
199218

200219
let endpoint = construct_endpoint(
201220
tls_server_name.as_deref(),
@@ -220,7 +239,9 @@ impl<'a> Worker<'a> {
220239

221240
let mut all_actions = vec![];
222241

223-
for workflow in &self.workflows {
242+
let data = Arc::new(data);
243+
244+
for workflow in &workflows {
224245
let namespaced_workflow_name =
225246
format!("{namespace}{workflow_name}", workflow_name = workflow.name);
226247

@@ -284,8 +305,8 @@ impl<'a> Worker<'a> {
284305

285306
let request = {
286307
let mut request: tonic::Request<WorkerRegisterRequest> = WorkerRegisterRequest {
287-
worker_name: self.name.clone(),
288-
max_runs: self.max_runs,
308+
worker_name: name,
309+
max_runs,
289310
services: vec!["default".to_owned()],
290311
actions: all_actions,
291312
// FIXME: Implement.
@@ -304,7 +325,7 @@ impl<'a> Worker<'a> {
304325

305326
futures_util::try_join! {
306327
heartbeat::run(dispatcher.clone(), &worker_id, heartbeat_interrupt_receiver),
307-
listener::run(dispatcher, workflow_service_client, namespace, &worker_id, self.workflows, *listener_v2_timeout, listening_interrupt_receiver, heartbeat_interrupt_sender2),
328+
listener::run(dispatcher, workflow_service_client, namespace, &worker_id, workflows, *listener_v2_timeout, listening_interrupt_receiver, data)
308329
}?;
309330

310331
Ok(())

0 commit comments

Comments
 (0)