Skip to content

Commit 8f344bd

Browse files
feat: Retain is_running=True during cancellation (#235)
Co-authored-by: Cursor Agent <[email protected]>
1 parent 11e39a0 commit 8f344bd

File tree

5 files changed

+65
-4
lines changed

5 files changed

+65
-4
lines changed

.changeset/wide-islands-start.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"llama-index-workflows": patch
3+
---
4+
5+
Fix resuming from serialized context for workflows that uses typed events

packages/llama-index-workflows/src/workflows/runtime/control_loop.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,9 @@ def _process_cancel_run_tick(
648648
tick: TickCancelRun, init: BrokerState
649649
) -> tuple[BrokerState, list[WorkflowCommand]]:
650650
state = init.deepcopy()
651-
state.is_running = False
651+
# retain running state, for resumption.
652+
# TODO - when/if we persist stream events, this StopEvent should be reconsidered, as there should only ever be one stop event.
653+
# Perhaps on resumption, if the workflow is running, then any existing stop events of a "cancellation" type should be omitted from the stream.
652654
return state, [
653655
CommandPublishEvent(event=StopEvent()),
654656
CommandHalt(exception=WorkflowCancelledByUser()),

packages/llama-index-workflows/src/workflows/runtime/types/internal_state.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,9 @@ def from_serialized(
166166

167167
# Start with a base state from the workflow
168168
base_state = BrokerState.from_workflow(workflow)
169-
# Always set is_running to False on deserialization - the workflow will set it to True when it starts
170-
base_state.is_running = False
169+
# Unfortunately, important to preserve this state, since the workflow needs to know this to decide
170+
# whether to create a start_event from kwargs (it only constructs and passes a start event if not already running)
171+
base_state.is_running = serialized.is_running
171172

172173
# Restore worker state (queues, collected events, waiters)
173174
# We do this regardless of is_running state so workflows can resume from where they left off

packages/llama-index-workflows/tests/runtime/test_control_loop_transformations.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,10 @@ def test_cancel_run(base_state: BrokerState) -> None:
421421
tick = TickCancelRun()
422422
new_state, commands = _process_cancel_run_tick(tick, base_state)
423423

424-
assert new_state.is_running is False
424+
# This is perhaps unintuitive, but it's important to be able to cancel and resume a workflow
425+
# based on this state--Workflow uses this as a signal to determine whether to pass or construct
426+
# a start event
427+
assert new_state.is_running is True
425428
assert len(commands) == 2
426429
assert isinstance(commands[0], CommandPublishEvent)
427430
assert isinstance(commands[1], CommandHalt)

packages/llama-index-workflows/tests/test_workflow.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ class MyStop(StopEvent):
5555
outcome: str
5656

5757

58+
class ResumeStartEvent(StartEvent):
59+
topic: str
60+
61+
5862
def test_fn() -> None:
5963
print("test_fn")
6064

@@ -557,6 +561,52 @@ async def test_human_in_the_loop_with_resume() -> None:
557561
assert step2_runs == 1
558562

559563

564+
@pytest.mark.asyncio
565+
async def test_human_in_the_loop_resume_custom_start_event_inactive_ctx() -> None:
566+
class CustomHumanWorkflow(Workflow):
567+
@step
568+
async def ask(self, ctx: Context, ev: ResumeStartEvent) -> InputRequiredEvent:
569+
runs = await ctx.store.get("ask_runs", default=0)
570+
await ctx.store.set("ask_runs", runs + 1)
571+
return InputRequiredEvent(prefix=ev.topic) # type: ignore[arg-type]
572+
573+
@step
574+
async def complete(self, ctx: Context, ev: HumanResponseEvent) -> StopEvent:
575+
runs = await ctx.store.get("complete_runs", default=0)
576+
await ctx.store.set("complete_runs", runs + 1)
577+
return StopEvent(result=ev.response)
578+
579+
workflow = CustomHumanWorkflow()
580+
handler: WorkflowHandler = workflow.run(topic="pizza")
581+
assert handler.ctx
582+
583+
async for event in handler.stream_events():
584+
if isinstance(event, InputRequiredEvent):
585+
break
586+
587+
await handler.cancel_run()
588+
ctx_dict = handler.ctx.to_dict()
589+
assert ctx_dict["is_running"]
590+
591+
resumed_ctx = Context.from_dict(workflow, ctx_dict)
592+
resumed_handler = workflow.run(ctx=resumed_ctx)
593+
resumed_handler.ctx.send_event(HumanResponseEvent(response="42")) # type: ignore[arg-type]
594+
595+
events = []
596+
async for event in resumed_handler.stream_events():
597+
events.append(event)
598+
599+
assert events == [StopEvent(result="42")]
600+
601+
final_result = await resumed_handler
602+
assert final_result == "42"
603+
604+
ask_runs = await resumed_handler.ctx.store.get("ask_runs") # type: ignore[arg-type]
605+
complete_runs = await resumed_handler.ctx.store.get("complete_runs") # type: ignore[arg-type]
606+
assert ask_runs == 1
607+
assert complete_runs == 1
608+
609+
560610
class DummyWorkflowForConcurrentRunsTest(Workflow):
561611
def __init__(self, **kwargs: Any) -> None:
562612
super().__init__(**kwargs)

0 commit comments

Comments
 (0)