@@ -88,6 +88,7 @@ class InstanceType(ExternType):
8888class FuncType (ExternType ):
8989 params : list [tuple [str ,ValType ]]
9090 result : list [ValType | tuple [str ,ValType ]]
91+ async_ : bool = False
9192 def param_types (self ):
9293 return self .extract_types (self .params )
9394 def result_type (self ):
@@ -566,8 +567,13 @@ def trap_if_on_the_stack(self, inst):
566567 def needs_exclusive (self ):
567568 return not self .opts .async_ or self .opts .callback
568569
570+ def must_not_suspend (self ):
571+ return not self .ft .async_ and self .state != Task .State .RESOLVED
572+
569573 def enter (self , thread ):
570574 assert (thread in self .threads and thread .task is self )
575+ if not self .ft .async_ :
576+ return True
571577 def has_backpressure ():
572578 return self .inst .backpressure > 0 or (self .needs_exclusive () and self .inst .exclusive )
573579 if has_backpressure () or self .inst .num_waiting_to_enter > 0 :
@@ -584,6 +590,8 @@ def has_backpressure():
584590
585591 def exit (self ):
586592 assert (len (self .threads ) > 0 )
593+ if not self .ft .async_ :
594+ return
587595 if self .needs_exclusive ():
588596 assert (self .inst .exclusive )
589597 self .inst .exclusive = False
@@ -2023,12 +2031,17 @@ def thread_func(thread):
20232031 inst .exclusive = False
20242032 match code :
20252033 case CallbackCode .YIELD :
2026- event = task .yield_until (lambda : not inst .exclusive , thread , cancellable = True )
2034+ if thread .task .must_not_suspend ():
2035+ event = (EventCode .NONE , 0 , 0 )
2036+ else :
2037+ event = task .yield_until (lambda : not inst .exclusive , thread , cancellable = True )
20272038 case CallbackCode .WAIT :
2039+ trap_if (thread .task .must_not_suspend ())
20282040 wset = inst .table .get (si )
20292041 trap_if (not isinstance (wset , WaitableSet ))
20302042 event = task .wait_until (lambda : not inst .exclusive , thread , wset , cancellable = True )
20312043 case CallbackCode .POLL :
2044+ trap_if (thread .task .must_not_suspend ())
20322045 wset = inst .table .get (si )
20332046 trap_if (not isinstance (wset , WaitableSet ))
20342047 event = task .poll_until (lambda : not inst .exclusive , thread , wset , cancellable = True )
@@ -2069,6 +2082,7 @@ def call_and_trap_on_throw(callee, thread, args):
20692082
20702083def canon_lower (opts , ft , callee : FuncInst , thread , flat_args ):
20712084 trap_if (not thread .task .inst .may_leave )
2085+ trap_if (ft .async_ and not opts .async_ and thread .task .must_not_suspend ())
20722086 subtask = Subtask ()
20732087 cx = LiftLowerContext (opts , thread .task .inst , subtask )
20742088
@@ -2108,6 +2122,7 @@ def on_resolve(result):
21082122 flat_results = lower_flat_values (cx , max_flat_results , result , ft .result_type (), flat_args )
21092123
21102124 subtask .callee = callee (thread .task , on_start , on_resolve )
2125+ assert (ft .async_ or subtask .resolved ())
21112126
21122127 if not opts .async_ :
21132128 if not subtask .resolved ():
@@ -2142,31 +2157,30 @@ def canon_resource_new(rt, thread, rep):
21422157
21432158### `canon resource.drop`
21442159
2145- def canon_resource_drop (rt , async_ , thread , i ):
2160+ def canon_resource_drop (rt , thread , i ):
21462161 trap_if (not thread .task .inst .may_leave )
21472162 inst = thread .task .inst
21482163 h = inst .table .remove (i )
21492164 trap_if (not isinstance (h , ResourceHandle ))
21502165 trap_if (h .rt is not rt )
21512166 trap_if (h .num_lends != 0 )
2152- flat_results = [] if not async_ else [0 ]
21532167 if h .own :
21542168 assert (h .borrow_scope is None )
21552169 if inst is rt .impl :
21562170 if rt .dtor :
21572171 rt .dtor (h .rep )
21582172 else :
21592173 if rt .dtor :
2160- caller_opts = CanonicalOptions (async_ = async_ )
2174+ caller_opts = CanonicalOptions (async_ = False )
21612175 callee_opts = CanonicalOptions (async_ = rt .dtor_async , callback = rt .dtor_callback )
2162- ft = FuncType ([U32Type ()],[])
2176+ ft = FuncType ([U32Type ()],[], async_ = False )
21632177 callee = partial (canon_lift , callee_opts , rt .impl , ft , rt .dtor )
2164- flat_results = canon_lower (caller_opts , ft , callee , thread , [h .rep ])
2178+ [] = canon_lower (caller_opts , ft , callee , thread , [h .rep ])
21652179 else :
21662180 thread .task .trap_if_on_the_stack (rt .impl )
21672181 else :
21682182 h .borrow_scope .num_borrows -= 1
2169- return flat_results
2183+ return []
21702184
21712185### `canon resource.rep`
21722186
@@ -2244,6 +2258,7 @@ def canon_waitable_set_new(thread):
22442258
22452259def canon_waitable_set_wait (cancellable , mem , thread , si , ptr ):
22462260 trap_if (not thread .task .inst .may_leave )
2261+ trap_if (thread .task .must_not_suspend ())
22472262 wset = thread .task .inst .table .get (si )
22482263 trap_if (not isinstance (wset , WaitableSet ))
22492264 event = thread .task .wait_until (lambda : True , thread , wset , cancellable )
@@ -2260,6 +2275,7 @@ def unpack_event(mem, thread, ptr, e: EventTuple):
22602275
22612276def canon_waitable_set_poll (cancellable , mem , thread , si , ptr ):
22622277 trap_if (not thread .task .inst .may_leave )
2278+ trap_if (thread .task .must_not_suspend ())
22632279 wset = thread .task .inst .table .get (si )
22642280 trap_if (not isinstance (wset , WaitableSet ))
22652281 event = thread .task .poll_until (lambda : True , thread , wset , cancellable )
@@ -2294,6 +2310,7 @@ def canon_waitable_join(thread, wi, si):
22942310
22952311def canon_subtask_cancel (async_ , thread , i ):
22962312 trap_if (not thread .task .inst .may_leave )
2313+ trap_if (not async_ and thread .task .must_not_suspend ())
22972314 subtask = thread .task .inst .table .get (i )
22982315 trap_if (not isinstance (subtask , Subtask ))
22992316 trap_if (subtask .resolve_delivered ())
@@ -2350,6 +2367,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23502367
23512368def stream_copy (EndT , BufferT , event_code , stream_t , opts , thread , i , ptr , n ):
23522369 trap_if (not thread .task .inst .may_leave )
2370+ trap_if (not opts .async_ and thread .task .must_not_suspend ())
23532371 e = thread .task .inst .table .get (i )
23542372 trap_if (not isinstance (e , EndT ))
23552373 trap_if (e .shared .t != stream_t .t )
@@ -2401,6 +2419,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24012419
24022420def future_copy (EndT , BufferT , event_code , future_t , opts , thread , i , ptr ):
24032421 trap_if (not thread .task .inst .may_leave )
2422+ trap_if (not opts .async_ and thread .task .must_not_suspend ())
24042423 e = thread .task .inst .table .get (i )
24052424 trap_if (not isinstance (e , EndT ))
24062425 trap_if (e .shared .t != future_t .t )
@@ -2451,6 +2470,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24512470
24522471def cancel_copy (EndT , event_code , stream_or_future_t , async_ , thread , i ):
24532472 trap_if (not thread .task .inst .may_leave )
2473+ trap_if (not async_ and thread .task .must_not_suspend ())
24542474 e = thread .task .inst .table .get (i )
24552475 trap_if (not isinstance (e , EndT ))
24562476 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2527,6 +2547,7 @@ def canon_thread_switch_to(cancellable, thread, i):
25272547
25282548def canon_thread_suspend (cancellable , thread ):
25292549 trap_if (not thread .task .inst .may_leave )
2550+ trap_if (thread .task .must_not_suspend ())
25302551 suspend_result = thread .task .suspend (thread , cancellable )
25312552 return [suspend_result ]
25322553
@@ -2554,6 +2575,8 @@ def canon_thread_yield_to(cancellable, thread, i):
25542575
25552576def canon_thread_yield (cancellable , thread ):
25562577 trap_if (not thread .task .inst .may_leave )
2578+ if thread .task .must_not_suspend ():
2579+ return [SuspsendResult .COMPLETED ]
25572580 event_code ,_ ,_ = thread .task .yield_until (lambda : True , thread , cancellable )
25582581 match event_code :
25592582 case EventCode .NONE :
0 commit comments