99public import Init.Data.String.Pattern.Basic
1010public import Init.Data.Iterators.Internal.Termination
1111public import Init.Data.Iterators.Consumers.Monadic.Loop
12+ public import Init.Data.Vector.Basic
1213
1314set_option doc.verso true
1415
@@ -21,67 +22,87 @@ public section
2122
2223namespace String.Slice.Pattern
2324
24- inductive ForwardSliceSearcher (s : Slice) where
25- | emptyBefore (pos : s.Pos)
26- | emptyAt (pos : s.Pos) (h : pos ≠ s.endPos)
27- | proper (needle : Slice) (table : Array String.Pos.Raw) (stackPos : String.Pos.Raw) (needlePos : String.Pos.Raw)
28- | atEnd
29- deriving Inhabited
30-
3125namespace ForwardSliceSearcher
3226
33- partial def buildTable (pat : Slice) : Array String.Pos.Raw :=
34- if pat.utf8ByteSize = = 0 then
35- #[]
27+ def buildTable (pat : Slice) : Vector Nat pat.utf8ByteSize :=
28+ if h : pat.utf8ByteSize = 0 then
29+ #v[].cast h.symm
3630 else
3731 let arr := Array.emptyWithCapacity pat.utf8ByteSize
38- let arr := arr.push 0
39- go ⟨ 1 ⟩ arr
32+ let arr' := arr.push 0
33+ go arr' ( by simp [ arr']) ( by simp [arr', arr]; omega) ( by simp [arr', arr])
4034where
41- go (pos : String.Pos.Raw) (table : Array String.Pos.Raw) :=
42- if h : pos < pat.rawEndPos then
43- let patByte := pat.getUTF8Byte pos h
44- let distance := computeDistance table[table.size - 1 ]! patByte table
45- let distance := if patByte = pat.getUTF8Byte! distance then distance.inc else distance
46- go pos.inc (table.push distance)
35+ go (table : Array Nat) (ht₀ : 0 < table.size) (ht : table.size ≤ pat.utf8ByteSize) (h : ∀ (i : Nat) hi, table[i]'hi ≤ i) :
36+ Vector Nat pat.utf8ByteSize :=
37+ if hs : table.size < pat.utf8ByteSize then
38+ let patByte := pat.getUTF8Byte ⟨table.size⟩ hs
39+ let dist := computeDistance patByte table ht h (table[table.size - 1 ])
40+ (by have := h (table.size - 1 ) (by omega); omega)
41+ let dist' := if pat.getUTF8Byte ⟨dist.1 ⟩ (by simp [Pos.Raw.lt_iff]; omega) = patByte then dist.1 + 1 else dist
42+ go (table.push dist') (by simp) (by simp; omega) (by
43+ intro i hi
44+ by_cases hi' : i = table.size
45+ · subst hi'
46+ simp [dist']
47+ have := dist.2
48+ split <;> omega
49+ · rw [Array.getElem_push_lt]
50+ · apply h
51+ · simp at hi
52+ omega)
4753 else
48- table
49-
50- computeDistance (distance : String.Pos.Raw) (patByte : UInt8) (table : Array String.Pos.Raw) :
51- String.Pos.Raw :=
52- if distance > 0 && patByte != pat.getUTF8Byte! distance then
53- computeDistance table[distance.byteIdx - 1 ]! patByte table
54+ Vector.mk table (by omega)
55+
56+ computeDistance (patByte : UInt8) (table : Array Nat)
57+ (ht : table.size ≤ pat.utf8ByteSize)
58+ (h : ∀ (i : Nat) hi, table[i]'hi ≤ i) (guess : Nat) (hg : guess < table.size) :
59+ { n : Nat // n < table.size } :=
60+ if h' : guess = 0 ∨ pat.getUTF8Byte ⟨guess⟩ (by simp [Pos.Raw.lt_iff]; omega) = patByte then
61+ ⟨guess, hg⟩
5462 else
55- distance
63+ have : table[guess - 1 ] < guess := by have := h (guess - 1 ) (by omega); omega
64+ computeDistance patByte table ht h table[guess - 1 ] (by omega)
65+
66+ theorem getElem_buildTable_le (pat : Slice) (i : Nat) (hi) : (buildTable pat)[i]'hi ≤ i := by
67+ rw [buildTable]
68+ split <;> rename_i h
69+ · simp [h] at hi
70+ · simp only [Array.emptyWithCapacity_eq, List.push_toArray, List.nil_append]
71+ suffices ∀ pat' table ht₀ ht h (i : Nat) hi, (buildTable.go pat' table ht₀ ht h)[i]'hi ≤ i from this ..
72+ intro pat' table ht₀ ht h i hi
73+ fun_induction buildTable.go with
74+ | case1 => assumption
75+ | case2 table ht₀ ht ht' ht'' => apply ht'
76+
77+ inductive _root_.String.Slice.Pattern.ForwardSliceSearcher (s : Slice) where
78+ | emptyBefore (pos : s.Pos)
79+ | emptyAt (pos : s.Pos) (h : pos ≠ s.endPos)
80+ | proper (needle : Slice) (table : Vector Nat needle.utf8ByteSize) (ht : table = buildTable needle)
81+ (stackPos : String.Pos.Raw) (needlePos : String.Pos.Raw) (hn : needlePos < needle.rawEndPos)
82+ | atEnd
83+ deriving Inhabited
5684
5785@[inline]
5886def iter (s : Slice) (pat : Slice) : Std.Iter (α := ForwardSliceSearcher s) (SearchStep s) :=
59- if pat.utf8ByteSize = = 0 then
87+ if h : pat.utf8ByteSize = 0 then
6088 { internalState := .emptyBefore s.startPos }
6189 else
62- { internalState := .proper pat (buildTable pat) s.startPos.offset pat.startPos.offset }
63-
64- partial def backtrackIfNecessary (pat : Slice) (table : Array String.Pos.Raw) (stackByte : UInt8)
65- (needlePos : String.Pos.Raw) : String.Pos.Raw :=
66- if needlePos != 0 && stackByte != pat.getUTF8Byte! needlePos then
67- backtrackIfNecessary pat table stackByte table[needlePos.byteIdx - 1 ]!
68- else
69- needlePos
90+ { internalState := .proper pat (buildTable pat) rfl s.startPos.offset pat.startPos.offset
91+ (by simp [Pos.Raw.lt_iff]; omega) }
7092
7193instance (s : Slice) : Std.Iterators.Iterator (ForwardSliceSearcher s) Id (SearchStep s) where
7294 IsPlausibleStep it
73- | .yield it' out =>
74- match it.internalState with
95+ | .yield it' out | .skip it' =>
96+ match it.internalState with
7597 | .emptyBefore pos => (∃ h, it'.internalState = .emptyAt pos h) ∨ it'.internalState = .atEnd
7698 | .emptyAt pos h => ∃ newPos, pos < newPos ∧ it'.internalState = .emptyBefore newPos
77- | .proper needle table stackPos needlePos =>
78- (∃ newStackPos newNeedlePos,
79- stackPos < newStackPos ∧
80- newStackPos ≤ s.rawEndPos ∧
81- it'.internalState = .proper needle table newStackPos newNeedlePos) ∨
99+ | .proper needle table ht stackPos needlePos hn =>
100+ (∃ newStackPos newNeedlePos hn ,
101+ it'.internalState = .proper needle table ht newStackPos newNeedlePos hn ∧
102+ ((s.utf8ByteSize - newStackPos.byteIdx < s.utf8ByteSize - stackPos.byteIdx) ∨
103+ ( newStackPos = stackPos ∧ newNeedlePos < needlePos)) ) ∨
82104 it'.internalState = .atEnd
83105 | .atEnd => False
84- | .skip _ => False
85106 | .done => True
86107 step := fun ⟨iter⟩ =>
87108 match iter with
@@ -94,67 +115,102 @@ instance (s : Slice) : Std.Iterators.Iterator (ForwardSliceSearcher s) Id (Searc
94115 | .emptyAt pos h =>
95116 let res := .rejected pos (pos.next h)
96117 pure (.deflate ⟨.yield ⟨.emptyBefore (pos.next h)⟩ res, by simp⟩)
97- | .proper needle table stackPos needlePos =>
98- let rec findNext (startPos : String.Pos.Raw)
99- (currStackPos : String.Pos.Raw) (needlePos : String.Pos.Raw) (h : stackPos ≤ currStackPos) :=
100- if h1 : currStackPos < s.rawEndPos then
101- let stackByte := s.getUTF8Byte currStackPos h1
102- let needlePos := backtrackIfNecessary needle table stackByte needlePos
103- let patByte := needle.getUTF8Byte! needlePos
104- if stackByte != patByte then
105- let nextStackPos := s.findNextPos currStackPos h1 |>.offset
106- let res := .rejected (s.pos! startPos) (s.pos! nextStackPos)
107- have hiter := by
108- left
109- exists nextStackPos
110- have haux := lt_offset_findNextPos h1
111- simp only [String.Pos.Raw.lt_iff, proper.injEq, true_and, exists_and_left, exists_eq', and_true,
112- nextStackPos]
113- constructor
114- · simp [String.Pos.Raw.le_iff, String.Pos.Raw.lt_iff] at h haux ⊢
115- omega
116- · apply Pos.Raw.IsValidForSlice.le_utf8ByteSize
117- apply Pos.isValidForSlice
118- .deflate ⟨.yield ⟨.proper needle table nextStackPos needlePos⟩ res, hiter⟩
118+ | .proper needle table htable stackPos needlePos hn =>
119+ -- **Invariant 1:** we have already covered everything up until `stackPos - needlePos` (exclusive),
120+ -- with matches and rejections.
121+ -- **Invariant 2:** `stackPos - needlePos` is a valid position
122+ -- **Invariant 3:** the range from from `stackPos - needlePos` to `stackPos` (exclusive) is a
123+ -- prefix of the pattern.
124+ if h₁ : stackPos < s.rawEndPos then
125+ let stackByte := s.getUTF8Byte stackPos h₁
126+ let patByte := needle.getUTF8Byte needlePos hn
127+ if stackByte = patByte then
128+ let nextStackPos := stackPos.inc
129+ let nextNeedlePos := needlePos.inc
130+ if h : nextNeedlePos = needle.rawEndPos then
131+ -- Safety: the section from `nextStackPos.decreaseBy needle.utf8ByteSize` to `nextStackPos`
132+ -- (exclusive) is exactly the needle, so it must represent a valid range.
133+ let res := .matched (s.pos! (nextStackPos.decreaseBy needle.utf8ByteSize)) (s.pos! nextStackPos)
134+ -- Invariants still satisfied
135+ pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos 0
136+ (by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
137+ by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega,
138+ Or.inl (by simp [nextStackPos, Pos.Raw.lt_iff] at h₁ ⊢; omega)⟩⟩)
119139 else
120- let needlePos := needlePos.inc
121- if needlePos == needle.rawEndPos then
122- let nextStackPos := currStackPos.inc
123- let res := .matched (s.pos! startPos) (s.pos! nextStackPos)
124- have hiter := by
125- left
126- exists nextStackPos
127- simp only [Pos.Raw.byteIdx_inc, proper.injEq, true_and, exists_and_left,
128- exists_eq', and_true, nextStackPos, String.Pos.Raw.lt_iff]
129- constructor
130- · simp [String.Pos.Raw.le_iff] at h ⊢
131- omega
132- · simp [String.Pos.Raw.le_iff, String.Pos.Raw.lt_iff] at h1 ⊢
133- omega
134- .deflate ⟨.yield ⟨.proper needle table nextStackPos 0 ⟩ res, hiter⟩
135- else
136- have hinv := by
137- simp [String.Pos.Raw.le_iff] at h ⊢
138- omega
139- findNext startPos currStackPos.inc needlePos hinv
140+ -- Invariants still satisfied
141+ pure (.deflate ⟨.skip ⟨.proper needle table htable nextStackPos nextNeedlePos
142+ (by simp [Pos.Raw.lt_iff, nextNeedlePos, Pos.Raw.ext_iff] at h hn ⊢; omega)⟩,
143+ by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [nextNeedlePos, Pos.Raw.lt_iff, Pos.Raw.ext_iff] at h hn ⊢; omega,
144+ Or.inl (by simp [nextStackPos, Pos.Raw.lt_iff] at h₁ ⊢; omega)⟩⟩)
140145 else
141- if startPos != s.rawEndPos then
142- let res := .rejected (s.pos! startPos) (s.pos! currStackPos)
143- .deflate ⟨.yield ⟨.atEnd⟩ res, by simp⟩
146+ if hnp : needlePos.byteIdx = 0 then
147+ -- Safety: by invariant 2
148+ let basePos := s.pos! stackPos
149+ -- Since we report (mis)matches by code point and not by byte, missing in the first byte
150+ -- means that we should skip ahead to the next code point.
151+ let nextStackPos := s.findNextPos stackPos h₁
152+ let res := .rejected basePos nextStackPos
153+ -- Invariants still satisfied
154+ pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos.offset 0
155+ (by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
156+ by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega,
157+ Or.inl (by
158+ have := lt_offset_findNextPos h₁
159+ have t₀ := (findNextPos _ _ h₁).isValidForSlice.le_utf8ByteSize
160+ simp [nextStackPos, Pos.Raw.lt_iff, Pos.Raw.le_iff] at this t₀ ⊢; omega)⟩⟩)
144161 else
145- .deflate ⟨.done, by simp⟩
146- termination_by s.utf8ByteSize - currStackPos.byteIdx
147- decreasing_by
148- simp [String.Pos.Raw.lt_iff] at h1 ⊢
149- omega
150-
151- findNext stackPos stackPos needlePos (by simp)
162+ let newNeedlePos := table[needlePos.byteIdx - 1 ]'(by simp [Pos.Raw.lt_iff] at hn; omega)
163+ if newNeedlePos = 0 then
164+ -- Safety: by invariant 2
165+ let basePos := s.pos! (stackPos.unoffsetBy needlePos)
166+ -- Since we report (mis)matches by code point and not by byte, missing in the first byte
167+ -- means that we should skip ahead to the next code point.
168+ let nextStackPos := (s.pos? stackPos).getD (s.findNextPos stackPos h₁)
169+ let res := .rejected basePos nextStackPos
170+ -- Invariants still satisfied
171+ pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos.offset 0
172+ (by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
173+ by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega, by
174+ simp only [pos?, Pos.Raw.isValidForSlice_eq_true_iff, nextStackPos]
175+ split
176+ · exact Or.inr (by simp [Pos.Raw.lt_iff]; omega)
177+ · refine Or.inl ?_
178+ have := lt_offset_findNextPos h₁
179+ have t₀ := (findNextPos _ _ h₁).isValidForSlice.le_utf8ByteSize
180+ simp [Pos.Raw.lt_iff, Pos.Raw.le_iff] at this t₀ ⊢; omega⟩⟩)
181+ else
182+ let oldBasePos := s.pos! (stackPos.decreaseBy needlePos.byteIdx)
183+ let newBasePos := s.pos! (stackPos.decreaseBy newNeedlePos)
184+ let res := .rejected oldBasePos newBasePos
185+ -- Invariants still satisfied by definition of the prefix table
186+ pure (.deflate ⟨.yield ⟨.proper needle table htable stackPos ⟨newNeedlePos⟩
187+ (by
188+ subst htable
189+ have := getElem_buildTable_le needle (needlePos.byteIdx - 1 ) (by simp [Pos.Raw.lt_iff] at hn; omega)
190+ simp [newNeedlePos, Pos.Raw.lt_iff] at hn ⊢
191+ omega)⟩ res,
192+ by
193+ simp only [proper.injEq, heq_eq_eq, true_and, exists_and_left, exists_prop,
194+ reduceCtorEq, or_false]
195+ refine ⟨_, _, ⟨rfl, rfl⟩, ?_, Or.inr ⟨rfl, ?_⟩⟩
196+ all_goals
197+ subst htable
198+ have := getElem_buildTable_le needle (needlePos.byteIdx - 1 ) (by simp [Pos.Raw.lt_iff] at hn; omega)
199+ simp [newNeedlePos, Pos.Raw.lt_iff] at hn ⊢
200+ omega⟩)
201+ else
202+ if 0 < needlePos then
203+ let basePos := stackPos.unoffsetBy needlePos
204+ let res := .rejected (s.pos! basePos) s.endPos
205+ pure (.deflate ⟨.yield ⟨.atEnd⟩ res, by simp⟩)
206+ else
207+ pure (.deflate ⟨.done, by simp⟩)
152208 | .atEnd => pure (.deflate ⟨.done, by simp⟩)
153209
154210private def toOption : ForwardSliceSearcher s → Option (Nat × Nat)
155211 | .emptyBefore pos => some (s.utf8ByteSize - pos.offset.byteIdx, 1 )
156212 | .emptyAt pos _ => some (s.utf8ByteSize - pos.offset.byteIdx, 0 )
157- | .proper _ _ sp _ => some (s.utf8ByteSize - sp.byteIdx, 0 )
213+ | .proper _ _ _ sp np _ => some (s.utf8ByteSize - sp.byteIdx, np.byteIdx )
158214 | .atEnd => none
159215
160216private instance : WellFoundedRelation (ForwardSliceSearcher s) where
@@ -172,7 +228,8 @@ private def finitenessRelation :
172228 simp_wf
173229 obtain ⟨step, h, h'⟩ := h
174230 cases step
175- · cases h
231+ all_goals try
232+ cases h
176233 revert h'
177234 simp only [Std.Iterators.IterM.IsPlausibleStep, Std.Iterators.Iterator.IsPlausibleStep]
178235 match it.internalState with
@@ -185,21 +242,21 @@ private def finitenessRelation :
185242 simp [h, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, Pos.lt_iff,
186243 Pos.Raw.lt_iff, Pos.Raw.le_iff] at hx ⊢ this
187244 omega
188- | .proper needle table stackPos needlePos =>
189- simp only [exists_and_left]
190- rintro (⟨newStackPos, h₁, h₂, ⟨x, hx⟩⟩|h)
191- · simp [hx, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, Pos.Raw.lt_iff,
192- Pos.Raw.le_iff] at ⊢ h₁ h₂
193- omega
245+ | .proper .. =>
246+ rintro (⟨newStackPos, newNeedlePos, h₁, h₂, (h|⟨rfl, h⟩)⟩|h)
247+ · simp [h₂, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, h]
248+ · simpa [h₂, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, Pos.Raw.lt_iff]
194249 · simp [h, ForwardSliceSearcher.toOption, Option.lt]
195250 | .atEnd .. => simp
196- · cases h'
197251 · cases h
198252
199253@[no_expose]
200254instance : Std.Iterators.Finite (ForwardSliceSearcher s) Id :=
201255 .of_finitenessRelation finitenessRelation
202256
257+ instance : Std.Iterators.IteratorCollect (ForwardSliceSearcher s) Id Id :=
258+ .defaultImplementation
259+
203260instance : Std.Iterators.IteratorLoop (ForwardSliceSearcher s) Id Id :=
204261 .defaultImplementation
205262
0 commit comments