@@ -19,22 +19,59 @@ pub struct Step {
1919 pub ( crate ) timeout : std:: time:: Duration ,
2020}
2121
22- impl StepBuilder {
23- pub fn function < I , O , Fut , F > ( mut self , function : & ' static F ) -> Self
24- where
25- I : serde:: de:: DeserializeOwned ,
26- O : serde:: ser:: Serialize ,
27- Fut : std:: future:: Future < Output = anyhow:: Result < O > > + ' static ,
28- F : Fn ( Context , I ) -> Fut ,
29- {
22+ pub trait UserStepFunction < I , O , H > {
23+ fn to_step_function ( self ) -> Arc < StepFunction > ;
24+ }
25+
26+ pub struct NoArguments ;
27+ pub struct ContextArgument ;
28+
29+ impl < I , O , Fut , F > UserStepFunction < I , O , ContextArgument > for & ' static F
30+ where
31+ I : serde:: de:: DeserializeOwned ,
32+ O : serde:: ser:: Serialize ,
33+ Fut : std:: future:: Future < Output = anyhow:: Result < O > > + ' static ,
34+ F : Fn ( Context , I ) -> Fut ,
35+ {
36+ fn to_step_function ( self ) -> Arc < StepFunction > {
3037 use futures_util:: FutureExt ;
31- self . function = Some ( Arc :: new ( |context, value| {
32- let result = function (
38+ Arc :: new ( |context, value| {
39+ let result = ( self ) (
3340 context,
3441 serde_json:: from_value ( value) . expect ( "must succeed" ) ,
3542 ) ;
3643 async { Ok ( serde_json:: to_value ( result. await ?) . expect ( "must succeed" ) ) } . boxed_local ( )
37- } ) ) ;
44+ } )
45+ }
46+ }
47+
48+ impl < I , O , Fut , F > UserStepFunction < I , O , NoArguments > for & ' static F
49+ where
50+ I : serde:: de:: DeserializeOwned ,
51+ O : serde:: ser:: Serialize ,
52+ Fut : std:: future:: Future < Output = anyhow:: Result < O > > + ' static ,
53+ F : Fn ( I ) -> Fut ,
54+ {
55+ fn to_step_function ( self ) -> Arc < StepFunction > {
56+ use futures_util:: FutureExt ;
57+ Arc :: new ( |_context, value| {
58+ let result = ( self ) ( serde_json:: from_value ( value) . expect ( "must succeed" ) ) ;
59+ async { Ok ( serde_json:: to_value ( result. await ?) . expect ( "must succeed" ) ) } . boxed_local ( )
60+ } )
61+ }
62+ }
63+
64+ impl StepBuilder {
65+ pub fn function <
66+ AnyVariant ,
67+ I : serde:: de:: DeserializeOwned ,
68+ O : serde:: ser:: Serialize ,
69+ F : UserStepFunction < I , O , AnyVariant > ,
70+ > (
71+ mut self ,
72+ function : F ,
73+ ) -> Self {
74+ self . function = Some ( function. to_step_function ( ) ) ;
3875 self
3976 }
4077}
0 commit comments