Skip to content

Commit ab1b78a

Browse files
committed
Enable access to the parent step's output value.
1 parent 7daa19b commit ab1b78a

File tree

8 files changed

+308
-69
lines changed

8 files changed

+308
-69
lines changed

Cargo.lock

Lines changed: 80 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ thiserror = "1"
2020
tokio = { version = "1", features = ["fs", "macros", "rt-multi-thread", "sync"] }
2121
tonic = { version = "0.12", features = ["tls", "tls-native-roots"] }
2222
tracing = "0.1"
23+
ustr = { version = "1", features = ["serde"] }
2324

2425
[build-dependencies]
2526
tonic-build = "0.12"

examples/multiple_steps.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder};
2+
3+
#[derive(serde::Serialize, serde::Deserialize)]
4+
struct HelloOutput {
5+
text: String,
6+
}
7+
8+
async fn execute_hello(_context: Context, _: serde_json::Value) -> anyhow::Result<HelloOutput> {
9+
Ok(HelloOutput {
10+
text: "Hello".to_owned(),
11+
})
12+
}
13+
14+
async fn execute_world(
15+
mut context: Context,
16+
_: serde_json::Value,
17+
) -> anyhow::Result<serde_json::Value> {
18+
let hello_result: HelloOutput = context.pop_parent_output("hello");
19+
Ok(serde_json::json!({
20+
"text": format!("{} World!", hello_result.text)
21+
}))
22+
}
23+
24+
#[tokio::main]
25+
async fn main() -> anyhow::Result<()> {
26+
dotenv::dotenv().ok();
27+
tracing_subscriber::fmt()
28+
.with_target(false)
29+
.with_env_filter(
30+
tracing_subscriber::EnvFilter::from_default_env()
31+
.add_directive("hatchet_sdk=debug".parse()?),
32+
)
33+
.init();
34+
35+
let client = Client::new()?;
36+
let mut worker = client.worker("example_spawn_workflow").build();
37+
worker.register_workflow(
38+
WorkflowBuilder::default()
39+
.name("hello-world")
40+
.step(
41+
StepBuilder::default()
42+
.name("hello")
43+
.function(&execute_hello)
44+
.build()?,
45+
)
46+
.step(
47+
StepBuilder::default()
48+
.name("world")
49+
.function(&execute_world)
50+
.parent("hello")
51+
.build()?,
52+
)
53+
.build()?,
54+
);
55+
worker.start().await?;
56+
Ok(())
57+
}

examples/panicking_step.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder};
2+
3+
#[derive(serde::Serialize, serde::Deserialize)]
4+
struct HelloOutput {
5+
text: String,
6+
}
7+
8+
async fn execute_hello(_context: Context, _: serde_json::Value) -> anyhow::Result<HelloOutput> {
9+
Ok(HelloOutput {
10+
text: "Hello".to_owned(),
11+
})
12+
}
13+
14+
async fn execute_panic(
15+
mut context: Context,
16+
_: serde_json::Value,
17+
) -> anyhow::Result<serde_json::Value> {
18+
let hello_result: HelloOutput = context.pop_parent_output("hello");
19+
panic!("Panic {}", hello_result.text);
20+
}
21+
22+
#[tokio::main]
23+
async fn main() -> anyhow::Result<()> {
24+
dotenv::dotenv().ok();
25+
tracing_subscriber::fmt()
26+
.with_target(false)
27+
.with_env_filter(
28+
tracing_subscriber::EnvFilter::from_default_env()
29+
.add_directive("hatchet_sdk=debug".parse()?),
30+
)
31+
.init();
32+
33+
let client = Client::new()?;
34+
let mut worker = client.worker("example_spawn_workflow").build();
35+
worker.register_workflow(
36+
WorkflowBuilder::default()
37+
.name("hello-panic")
38+
.step(
39+
StepBuilder::default()
40+
.name("hello")
41+
.function(&execute_hello)
42+
.build()?,
43+
)
44+
.step(
45+
StepBuilder::default()
46+
.name("panic")
47+
.function(&execute_panic)
48+
.parent("hello")
49+
.build()?,
50+
)
51+
.build()?,
52+
);
53+
worker.start().await?;
54+
Ok(())
55+
}

src/step_function.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::{
66

77
use futures_util::lock::Mutex;
88
use tracing::info;
9+
use ustr::Ustr;
910

1011
use crate::worker::{grpc, ServiceWithAuthorization};
1112

@@ -24,10 +25,12 @@ pub struct Context {
2425
u16,
2526
)>,
2627
data: Arc<DataMap>,
28+
parents_outputs: HashMap<Ustr, serde_json::Value>,
2729
}
2830

2931
impl Context {
3032
pub(crate) fn new(
33+
parents_outputs: HashMap<Ustr, serde_json::Value>,
3134
workflow_run_id: String,
3235
workflow_step_run_id: String,
3336
workflow_service_client: grpc::workflow_service_client::WorkflowServiceClient<
@@ -39,13 +42,23 @@ impl Context {
3942
data: Arc<DataMap>,
4043
) -> Self {
4144
Self {
45+
parents_outputs,
4246
workflow_run_id,
4347
workflow_service_client_and_spawn_index: Mutex::new((workflow_service_client, 0)),
4448
workflow_step_run_id,
4549
data,
4650
}
4751
}
4852

53+
pub fn pop_parent_output<O: serde::de::DeserializeOwned>(&mut self, step_name: &str) -> O {
54+
serde_json::from_value(
55+
self.parents_outputs
56+
.remove(&step_name.into())
57+
.unwrap_or_else(|| panic!("could not find the output for step '{step_name}'")),
58+
)
59+
.expect("could not deserialize from JSON")
60+
}
61+
4962
pub fn datum<D: std::any::Any + Send + Sync>(&self) -> &D {
5063
let type_id = TypeId::of::<D>();
5164
self.data
@@ -87,8 +100,10 @@ impl Context {
87100
}
88101
}
89102

90-
pub(crate) type StepFunction =
91-
dyn Fn(
103+
pub(crate) type StepFunction = dyn Fn(
92104
Context,
93105
serde_json::Value,
94-
) -> futures_util::future::LocalBoxFuture<'static, anyhow::Result<serde_json::Value>>;
106+
) -> std::panic::AssertUnwindSafe<
107+
futures_util::future::BoxFuture<'static, anyhow::Result<serde_json::Value>>,
108+
> + Send
109+
+ Sync;

0 commit comments

Comments
 (0)