Skip to content

Commit 045d07d

Browse files
committed
fix: cancellation of synchronous part of previous elaboration (#7882)
This PR fixes a regression where elaboration of a previous document version is not cancelled on changes to the document. Done by removing the default from `SnapshotTask.cancelTk?` and consistently passing the current thread's token for synchronous elaboration steps. (cherry picked from commit 1421b61)
1 parent aa6e135 commit 045d07d

File tree

11 files changed

+237
-29
lines changed

11 files changed

+237
-29
lines changed

src/Lean/Elab/Command.lean

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,14 +510,15 @@ where go := do
510510
let oldCmds? := oldSnap?.map fun old =>
511511
if old.newStx.isOfKind nullKind then old.newStx.getArgs else #[old.newStx]
512512
let cmdPromises ← cmds.mapM fun _ => IO.Promise.new
513+
let cancelTk? := (← read).cancelTk?
513514
snap.new.resolve <| .ofTyped {
514515
diagnostics := .empty
515516
macroDecl := decl
516517
newStx := stxNew
517518
newNextMacroScope := nextMacroScope
518519
hasTraces
519520
next := Array.zipWith (fun cmdPromise cmd =>
520-
{ stx? := some cmd, task := cmdPromise.resultD default }) cmdPromises cmds
521+
{ stx? := some cmd, task := cmdPromise.resultD default, cancelTk? }) cmdPromises cmds
521522
: MacroExpandedSnapshot
522523
}
523524
-- After the first command whose syntax tree changed, we must disable

src/Lean/Elab/MutualDef.lean

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
165165
-- no syntax guard to store, we already did the necessary checks
166166
oldBodySnap? := guard reuseBody *> pure ⟨.missing, old.bodySnap⟩
167167
if oldBodySnap?.isNone then
168+
-- NOTE: this will eagerly cancel async tasks not associated with an inner snapshot, most
169+
-- importantly kernel checking and compilation of the top-level declaration
168170
old.bodySnap.cancelRec
169171
oldTacSnap? := do
170172
guard reuseTac
@@ -217,6 +219,7 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
217219
return newHeader
218220
if let some snap := view.headerSnap? then
219221
let (tacStx?, newTacTask?) ← mkTacTask view.value tacPromise
222+
let cancelTk? := (← readThe Core.Context).cancelTk?
220223
let bodySnap := {
221224
stx? := view.value
222225
reportingRange? :=
@@ -227,6 +230,8 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
227230
else
228231
getBodyTerm? view.value |>.getD view.value |>.getRange?
229232
task := bodyPromise.resultD default
233+
-- We should not cancel the entire body early if we have tactics
234+
cancelTk? := guard newTacTask?.isNone *> cancelTk?
230235
}
231236
snap.new.resolve <| some {
232237
diagnostics :=
@@ -269,7 +274,8 @@ where
269274
:= do
270275
if let some e := getBodyTerm? body then
271276
if let `(by $tacs*) := e then
272-
return (e, some { stx? := mkNullNode tacs, task := tacPromise.resultD default })
277+
let cancelTk? := (← readThe Core.Context).cancelTk?
278+
return (e, some { stx? := mkNullNode tacs, task := tacPromise.resultD default, cancelTk? })
273279
tacPromise.resolve default
274280
return (none, none)
275281

@@ -432,8 +438,7 @@ private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr
432438
snap.new.resolve <| some old
433439
reusableResult? := some (old.value, old.state)
434440
else
435-
-- NOTE: this will eagerly cancel async tasks not associated with an inner snapshot, most
436-
-- importantly kernel checking and compilation of the top-level declaration
441+
-- make sure to cancel any async tasks that may still be running (e.g. kernel and codegen)
437442
old.val.cancelRec
438443

439444
let (val, state) ← withRestoreOrSaveFull reusableResult? header.tacSnap? do
@@ -1197,6 +1202,7 @@ private def logGoalsAccomplishedSnapshotTask (views : Array DefView)
11971202
-- Use first line of the mutual block to avoid covering the progress of the whole mutual block
11981203
reportingRange? := (← getRef).getPos?.map fun pos => ⟨pos, pos⟩
11991204
task := logGoalsAccomplishedTask
1205+
cancelTk? := none
12001206
}
12011207

12021208
end Term
@@ -1235,9 +1241,10 @@ def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
12351241
} }
12361242
if snap.old?.isSome && (view.headerSnap?.bind (·.old?)).isNone then
12371243
snap.old?.forM (·.val.cancelRec)
1244+
let cancelTk? := (← read).cancelTk?
12381245
defs := defs.push {
12391246
fullHeaderRef
1240-
headerProcessedSnap := { stx? := d, task := headerPromise.resultD default }
1247+
headerProcessedSnap := { stx? := d, task := headerPromise.resultD default, cancelTk? }
12411248
}
12421249
reusedAllHeaders := reusedAllHeaders && view.headerSnap?.any (·.old?.isSome)
12431250
views := views.push view

src/Lean/Elab/Tactic/Basic.lean

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ structure EvalTacticFailure where
165165
state : SavedState
166166

167167
partial def evalTactic (stx : Syntax) : TacticM Unit := do
168+
checkSystem "tactic execution"
168169
profileitM Exception "tactic execution" (decl := stx.getKind) (← getOptions) <|
169170
withRef stx <| withIncRecDepth <| withFreshMacroScope <| match stx with
170171
| .node _ k _ =>
@@ -240,6 +241,7 @@ where
240241
snap.old?.forM (·.val.cancelRec)
241242
let promise ← IO.Promise.new
242243
-- Store new unfolding in the snapshot tree
244+
let cancelTk? := (← readThe Core.Context).cancelTk?
243245
snap.new.resolve {
244246
stx := stx'
245247
diagnostics := .empty
@@ -249,7 +251,7 @@ where
249251
state? := (← Tactic.saveState)
250252
moreSnaps := #[]
251253
}
252-
next := #[{ stx? := stx', task := promise.resultD default }]
254+
next := #[{ stx? := stx', task := promise.resultD default, cancelTk? }]
253255
}
254256
-- Update `tacSnap?` to old unfolding
255257
withTheReader Term.Context ({ · with tacSnap? := some {

src/Lean/Elab/Tactic/BuiltinTactic.lean

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,14 @@ where
7878
let next ← IO.Promise.new
7979
let finished ← IO.Promise.new
8080
let inner ← IO.Promise.new
81+
let cancelTk? := (← readThe Core.Context).cancelTk?
8182
snap.new.resolve {
8283
desc := tac.getKind.toString
8384
diagnostics := .empty
8485
stx := tac
85-
inner? := some { stx? := tac, task := inner.resultD default }
86-
finished := { stx? := tac, task := finished.resultD default }
87-
next := #[{ stx? := stxs, task := next.resultD default }]
86+
inner? := some { stx? := tac, task := inner.resultD default, cancelTk? }
87+
finished := { stx? := tac, task := finished.resultD default, cancelTk? }
88+
next := #[{ stx? := stxs, task := next.resultD default, cancelTk? }]
8889
}
8990
-- Run `tac` in a fresh info tree state and store resulting state in snapshot for
9091
-- incremental reporting, then add back saved trees. Here we rely on `evalTactic`

src/Lean/Elab/Tactic/Induction.lean

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,14 +286,15 @@ where
286286
-- them, eventually put each of them back in `Context.tacSnap?` in `applyAltStx`
287287
let finished ← IO.Promise.new
288288
let altPromises ← altStxs.mapM fun _ => IO.Promise.new
289+
let cancelTk? := (← readThe Core.Context).cancelTk?
289290
tacSnap.new.resolve {
290291
-- save all relevant syntax here for comparison with next document version
291292
stx := mkNullNode altStxs
292293
diagnostics := .empty
293294
inner? := none
294-
finished := { stx? := mkNullNode altStxs, reportingRange? := none, task := finished.resultD default }
295+
finished := { stx? := mkNullNode altStxs, reportingRange? := none, task := finished.resultD default, cancelTk? }
295296
next := Array.zipWith
296-
(fun stx prom => { stx? := some stx, task := prom.resultD default })
297+
(fun stx prom => { stx? := some stx, task := prom.resultD default, cancelTk? })
297298
altStxs altPromises
298299
}
299300
goWithIncremental <| altPromises.mapIdx fun i prom => {

src/Lean/Language/Basic.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,17 @@ structure SnapshotTask (α : Type) where
9393
Cancellation token that can be set by the server to cancel the task when it detects the results
9494
are not needed anymore.
9595
-/
96-
cancelTk? : Option IO.CancelToken := none
96+
cancelTk? : Option IO.CancelToken
9797
/-- Underlying task producing the snapshot. -/
9898
task : Task α
9999
deriving Nonempty, Inhabited
100100

101101
/-- Creates a snapshot task from the syntax processed by the task and a `BaseIO` action. -/
102-
def SnapshotTask.ofIO (stx? : Option Syntax)
102+
def SnapshotTask.ofIO (stx? : Option Syntax) (cancelTk? : Option IO.CancelToken)
103103
(reportingRange? : Option String.Range := defaultReportingRange? stx?) (act : BaseIO α) :
104104
BaseIO (SnapshotTask α) := do
105105
return {
106-
stx?
107-
reportingRange?
106+
stx?, reportingRange?, cancelTk?
108107
task := (← BaseIO.asTask act)
109108
}
110109

@@ -114,6 +113,7 @@ def SnapshotTask.finished (stx? : Option Syntax) (a : α) : SnapshotTask α wher
114113
-- irrelevant when already finished
115114
reportingRange? := none
116115
task := .pure a
116+
cancelTk? := none
117117

118118
/-- Transforms a task's output without changing the processed syntax. -/
119119
def SnapshotTask.map (t : SnapshotTask α) (f : α → β) (stx? : Option Syntax := t.stx?)

src/Lean/Language/Lean.lean

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ where
397397
diagnostics := oldProcessed.diagnostics
398398
result? := some {
399399
cmdState := oldProcSuccess.cmdState
400-
firstCmdSnap := { stx? := none, task := prom.result! } } }
400+
firstCmdSnap := { stx? := none, task := prom.result!, cancelTk? := cancelTk } } }
401401
else
402402
return .finished newStx oldProcessed) } }
403403
else return old
@@ -450,7 +450,7 @@ where
450450
processHeader (stx : Syntax) (parserState : Parser.ModuleParserState) :
451451
LeanProcessingM (SnapshotTask HeaderProcessedSnapshot) := do
452452
let ctx ← read
453-
SnapshotTask.ofIO stx (some ⟨0, ctx.input.endPos⟩) <|
453+
SnapshotTask.ofIO stx none (some ⟨0, ctx.input.endPos⟩) <|
454454
ReaderT.run (r := ctx) <| -- re-enter reader in new task
455455
withHeaderExceptions (α := HeaderProcessedSnapshot) ({ · with result? := none }) do
456456
let setup ← match (← setupImports stx) with
@@ -507,7 +507,7 @@ where
507507
infoTree? := cmdState.infoState.trees[0]!
508508
result? := some {
509509
cmdState
510-
firstCmdSnap := { stx? := none, task := prom.result! }
510+
firstCmdSnap := { stx? := none, task := prom.result!, cancelTk? := cancelTk }
511511
}
512512
}
513513

@@ -523,17 +523,19 @@ where
523523
-- from `old`
524524
if let some oldNext := old.nextCmdSnap? then do
525525
let newProm ← IO.Promise.new
526+
let cancelTk ← IO.CancelToken.new
526527
-- can reuse range, syntax unchanged
527528
BaseIO.chainTask (sync := true) old.resultSnap.task fun oldResult =>
528529
-- also wait on old command parse snapshot as parsing is cheap and may allow for
529530
-- elaboration reuse
530531
BaseIO.chainTask (sync := true) oldNext.task fun oldNext => do
531-
let cancelTk ← IO.CancelToken.new
532532
parseCmd oldNext newParserState oldResult.cmdState newProm sync cancelTk ctx
533533
prom.resolve <| { old with nextCmdSnap? := some {
534534
stx? := none
535535
reportingRange? := some ⟨newParserState.pos, ctx.input.endPos⟩
536-
task := newProm.result! } }
536+
task := newProm.result!
537+
cancelTk? := cancelTk
538+
} }
537539
else prom.resolve old -- terminal command, we're done!
538540

539541
-- fast path, do not even start new task for this snapshot (see [Incremental Parsing])
@@ -615,15 +617,16 @@ where
615617
})
616618
let diagnostics ← Snapshot.Diagnostics.ofMessageLog msgLog
617619

618-
-- use per-command cancellation token for elaboration so that
620+
-- use per-command cancellation token for elaboration so that cancellation of further commands
621+
-- does not affect current command
619622
let elabCmdCancelTk ← IO.CancelToken.new
620623
prom.resolve {
621624
diagnostics, nextCmdSnap?
622625
stx := stx', parserState := parserState'
623626
elabSnap := { stx? := stx', task := elabPromise.result!, cancelTk? := some elabCmdCancelTk }
624-
resultSnap := { stx? := stx', reportingRange? := initRange?, task := resultPromise.result! }
625-
infoTreeSnap := { stx? := stx', reportingRange? := initRange?, task := finishedPromise.result! }
626-
reportSnap := { stx? := none, reportingRange? := initRange?, task := reportPromise.result! }
627+
resultSnap := { stx? := stx', reportingRange? := initRange?, task := resultPromise.result!, cancelTk? := none }
628+
infoTreeSnap := { stx? := stx', reportingRange? := initRange?, task := finishedPromise.result!, cancelTk? := none }
629+
reportSnap := { stx? := none, reportingRange? := initRange?, task := reportPromise.result!, cancelTk? := none }
627630
}
628631
let cmdState ← doElab stx cmdState beginPos
629632
{ old? := old?.map fun old => ⟨old.stx, old.elabSnap⟩, new := elabPromise }
@@ -665,8 +668,8 @@ where
665668
-- We want to trace all of `CommandParsedSnapshot` but `traceTask` is part of it, so let's
666669
-- create a temporary snapshot tree containing all tasks but it
667670
let snaps := #[
668-
{ stx? := stx', task := elabPromise.result!.map (sync := true) toSnapshotTree },
669-
{ stx? := stx', task := resultPromise.result!.map (sync := true) toSnapshotTree }] ++
671+
{ stx? := stx', task := elabPromise.result!.map (sync := true) toSnapshotTree, cancelTk? := none },
672+
{ stx? := stx', task := resultPromise.result!.map (sync := true) toSnapshotTree, cancelTk? := none }] ++
670673
cmdState.snapshotTasks
671674
let tree := SnapshotTree.mk { diagnostics := .empty } snaps
672675
BaseIO.bindTask (← tree.waitAll) fun _ => do
@@ -690,6 +693,7 @@ where
690693
stx? := none
691694
reportingRange? := initRange?
692695
task := traceTask
696+
cancelTk? := none
693697
}
694698
if let some next := next? then
695699
-- We're definitely off the fast-forwarding path now

src/Lean/Meta/Basic.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,13 +2279,15 @@ def realizeConst (forConst : Name) (constName : Name) (realize : MetaM Unit) :
22792279
initHeartbeats := (← IO.getNumHeartbeats)
22802280
}
22812281
let (env, exTask, dyn) ← env.realizeConst forConst constName (realizeAndReport coreCtx)
2282+
-- Realizations cannot be cancelled as their result is shared across elaboration runs
22822283
let exAct ← Core.wrapAsyncAsSnapshot (cancelTk? := none) fun
22832284
| none => return
22842285
| some ex => do
22852286
logError <| ex.toMessageData (← getOptions)
22862287
Core.logSnapshotTask {
22872288
stx? := none
22882289
task := (← BaseIO.mapTask (t := exTask) exAct)
2290+
cancelTk? := none
22892291
}
22902292
if let some res := dyn.get? RealizeConstantResult then
22912293
let mut snap := res.snap

src/Lean/Server/Test/Cancel.lean

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,36 @@ elab_rules : tactic
5656
-- can't use a naked promise in `initialize` as marking it persistent would block
5757
initialize unblockedCancelTk : IO.CancelToken ← IO.CancelToken.new
5858

59+
/--
60+
Waits for `unblock` to be called, which is expected to happen in a subsequent document version that
61+
does not invalidate this tactic. Complains if cancellation token was set before unblocking, i.e. if
62+
the tactic was invalidated after all.
63+
-/
64+
scoped syntax "wait_for_unblock" : tactic
65+
@[incremental]
66+
elab_rules : tactic
67+
| `(tactic| wait_for_unblock) => do
68+
let ctx ← readThe Core.Context
69+
let some cancelTk := ctx.cancelTk? | unreachable!
70+
71+
dbg_trace "blocked!"
72+
log "blocked"
73+
let ctx ← readThe Elab.Term.Context
74+
let some tacSnap := ctx.tacSnap? | unreachable!
75+
tacSnap.new.resolve {
76+
diagnostics := (← Language.Snapshot.Diagnostics.ofMessageLog (← Core.getMessageLog))
77+
stx := default
78+
finished := default
79+
}
80+
81+
while true do
82+
if (← unblockedCancelTk.isSet) then
83+
break
84+
IO.sleep 30
85+
if (← cancelTk.isSet) then
86+
IO.eprintln "cancelled!"
87+
log "cancelled (should never be visible)"
88+
5989
/--
6090
Spawns a `logSnapshotTask` that waits for `unblock` to be called, which is expected to happen in a
6191
subsequent document version that does not invalidate this tactic. Complains if cancellation token
@@ -83,6 +113,10 @@ scoped elab "unblock" : tactic => do
83113
dbg_trace "unblocking!"
84114
unblockedCancelTk.set
85115

116+
/--
117+
Like `wait_for_cancel_once` but does the waiting in a separate task and waits for its
118+
cancellation.
119+
-/
86120
scoped syntax "wait_for_cancel_once_async" : tactic
87121
@[incremental]
88122
elab_rules : tactic
@@ -110,3 +144,35 @@ elab_rules : tactic
110144

111145
dbg_trace "blocked!"
112146
log "blocked"
147+
148+
/--
149+
Like `wait_for_cancel_once_async` but waits for the main thread's cancellation token. This is useful
150+
to test main thread cancellation in non-incremental contexts because we otherwise wouldn't be able
151+
to send out the "blocked" message from there.
152+
-/
153+
scoped syntax "wait_for_main_cancel_once_async" : tactic
154+
@[incremental]
155+
elab_rules : tactic
156+
| `(tactic| wait_for_main_cancel_once_async) => do
157+
let prom ← IO.Promise.new
158+
if let some t := (← onceRef.modifyGet (fun old => (old, old.getD prom.result!))) then
159+
IO.wait t
160+
return
161+
162+
let some cancelTk := (← readThe Core.Context).cancelTk? | unreachable!
163+
let act ← Elab.Term.wrapAsyncAsSnapshot (cancelTk? := none) fun _ => do
164+
let ctx ← readThe Core.Context
165+
-- TODO: `CancelToken` should probably use `Promise`
166+
while true do
167+
if (← cancelTk.isSet) then
168+
break
169+
IO.sleep 30
170+
IO.eprintln "cancelled!"
171+
log "cancelled (should never be visible)"
172+
prom.resolve ()
173+
Core.checkInterrupted
174+
let t ← BaseIO.asTask (act ())
175+
Core.logSnapshotTask { stx? := none, task := t, cancelTk? := cancelTk }
176+
177+
dbg_trace "blocked!"
178+
log "blocked"

0 commit comments

Comments
 (0)