|
1 | | -import Foundation |
2 | 1 | import NIOCore |
3 | 2 |
|
4 | | -extension AsyncSequence where Element == ByteBuffer { |
5 | | - public func getServerSentEvents(allocator: ByteBufferAllocator) -> AsyncThrowingStream<ServerSentEvent, Error> { |
6 | | - AsyncThrowingStream { continuation in |
7 | | - let task = Task { |
8 | | - var parser = SSEParser() |
9 | | - var text = allocator.buffer(capacity: 1024) |
10 | | - |
11 | | - for try await var buffer in self { |
12 | | - text.writeBuffer(&buffer) |
13 | | - |
14 | | - do { |
15 | | - for event in try parser.process(sse: &text) { |
16 | | - continuation.yield(event) |
17 | | - } |
| 3 | +public struct SSEStream: AsyncSequence { |
| 4 | + public struct AsyncIterator: AsyncIteratorProtocol { |
| 5 | + public typealias Element = ServerSentEvent |
| 6 | + private var bufferedEvents = [ServerSentEvent]() |
| 7 | + private var text: ByteBuffer |
| 8 | + private var parser = SSEParser() |
| 9 | + let produce: () async throws -> ByteBuffer? |
| 10 | + |
| 11 | + init( |
| 12 | + allocator: ByteBufferAllocator, |
| 13 | + produce: @escaping () async throws -> ByteBuffer? |
| 14 | + ) { |
| 15 | + self.text = allocator.buffer(capacity: 1024) |
| 16 | + self.produce = produce |
| 17 | + } |
18 | 18 |
|
19 | | - text.discardReadBytes() |
20 | | - } catch { |
21 | | - continuation.finish(throwing: error) |
22 | | - return |
23 | | - } |
| 19 | + private mutating func _next() async throws -> ServerSentEvent? { |
| 20 | + while bufferedEvents.isEmpty { |
| 21 | + guard var buffer = try await produce() else { |
| 22 | + return nil |
24 | 23 | } |
25 | 24 |
|
26 | | - continuation.finish() |
| 25 | + text.writeBuffer(&buffer) |
| 26 | + try bufferedEvents.append(contentsOf: parser.process(sse: &text)) |
| 27 | + text.discardReadBytes() |
27 | 28 | } |
28 | 29 |
|
29 | | - continuation.onTermination = { reason in |
30 | | - task.cancel() |
| 30 | + if bufferedEvents.isEmpty { |
| 31 | + return nil |
31 | 32 | } |
| 33 | + |
| 34 | + return bufferedEvents.removeFirst() |
32 | 35 | } |
33 | | - } |
34 | | -} |
35 | 36 |
|
36 | | -internal struct SSEParser { |
37 | | - var events = [ServerSentEvent]() |
38 | | - var type = "message" |
39 | | - var data = [String]() |
40 | | - var id: String? |
| 37 | + #if compiler(>=6.0) |
| 38 | + public mutating func next( |
| 39 | + isolation actor: isolated (any Actor)? = #isolation |
| 40 | + ) async throws -> ServerSentEvent? { |
| 41 | + try await _next() |
| 42 | + } |
| 43 | + #endif |
41 | 44 |
|
42 | | - enum ParsingStatus { |
43 | | - case nextField, haltParsing |
| 45 | + public mutating func next() async throws -> ServerSentEvent? { |
| 46 | + try await _next() |
| 47 | + } |
44 | 48 | } |
45 | 49 |
|
46 | | - init() {} |
| 50 | + private let iterator: AsyncIterator |
47 | 51 |
|
48 | | - mutating func reset() { |
49 | | - self = SSEParser() |
| 52 | + @_disfavoredOverload |
| 53 | + internal init<Sequence: AsyncSequence>( |
| 54 | + sequence: Sequence, |
| 55 | + allocator: ByteBufferAllocator |
| 56 | + ) where Sequence.Element == ByteBuffer { |
| 57 | + var iterator = sequence.makeAsyncIterator() |
| 58 | + self.iterator = AsyncIterator(allocator: allocator) { |
| 59 | + try await iterator.next() |
| 60 | + } |
50 | 61 | } |
51 | 62 |
|
52 | | - mutating func getEvents() -> [ServerSentEvent] { |
53 | | - let events = events |
54 | | - self.reset() |
55 | | - return events |
| 63 | + public func makeAsyncIterator() -> AsyncIterator { |
| 64 | + iterator |
56 | 65 | } |
| 66 | +} |
57 | 67 |
|
58 | | - mutating func process(sse text: inout ByteBuffer) throws -> [ServerSentEvent] { |
59 | | - func checkEndOfEventAndStream() -> ParsingStatus { |
60 | | - guard let nextCharacter: UInt8 = text.getInteger(at: text.readerIndex) else { |
61 | | - return .haltParsing |
62 | | - } |
63 | | - |
64 | | - // Blank lines must dispatch an event |
65 | | - if nextCharacter == 0x0a || nextCharacter == 0x0d { |
66 | | - if nextCharacter == 0x0d, text.getInteger(at: text.readerIndex + 1, as: UInt8.self) == 0x0a { |
67 | | - // Skip the 0x0a as well |
68 | | - // CRLF, CR and LF are all valid delimiters |
69 | | - text.moveReaderIndex(forwardBy: 2) |
70 | | - } else { |
71 | | - text.moveReaderIndex(forwardBy: 1) |
72 | | - } |
73 | | - |
74 | | - var event = ServerSentEvent(data: SSEValue(unchecked: data)) |
75 | | - event.type = type |
76 | | - event.id = id |
77 | | - events.append(event) |
78 | | - |
79 | | - // reset state |
80 | | - type = "message" |
81 | | - data.removeAll(keepingCapacity: true) |
82 | | - id = nil |
83 | | - |
84 | | - lastEventReaderIndex = text.readerIndex |
85 | | - |
86 | | - return text.readableBytes > 0 ? .nextField : .haltParsing |
87 | | - } |
88 | | - |
89 | | - return .nextField |
90 | | - } |
91 | | - |
92 | | - var lastEventReaderIndex = text.readerIndex |
93 | | - |
94 | | - repeat { |
95 | | - switch checkEndOfEventAndStream() { |
96 | | - case .nextField: |
97 | | - var value = "" |
98 | | - |
99 | | - let readableBytesView = text.readableBytesView |
100 | | - let colonIndex = readableBytesView.firstIndex(where: { byte in |
101 | | - byte == 0x3a // `:` |
102 | | - }) |
103 | | - |
104 | | - guard var lineEndingIndex = readableBytesView.firstIndex(where: { byte in |
105 | | - byte == 0x0a || byte == 0x0d // `\n` or `\r` |
106 | | - }) else { |
107 | | - // Reset to before this event, as we didn't fully process this |
108 | | - text.moveReaderIndex(to: lastEventReaderIndex) |
109 | | - return getEvents() |
110 | | - } |
111 | | - |
112 | | - // The indices are offset from the start of the buffer, not the start of the readable bytes |
113 | | - lineEndingIndex -= readableBytesView.startIndex |
114 | | - |
115 | | - if var colonIndex = colonIndex { |
116 | | - // The indices are offset from the start of the buffer, not the start of the readable bytes |
117 | | - colonIndex -= readableBytesView.startIndex |
118 | | - |
119 | | - guard let key = text.readString(length: colonIndex) else { |
120 | | - // Reset to before this event, as we didn't fully process this |
121 | | - text.moveReaderIndex(to: lastEventReaderIndex) |
122 | | - return getEvents() |
123 | | - } |
124 | | - |
125 | | - // Skip past colon |
126 | | - text.moveReaderIndex(forwardBy: 1) |
127 | | - |
128 | | - // Reduce the index by `key size + colon character` |
129 | | - lineEndingIndex -= colonIndex |
130 | | - lineEndingIndex -= 1 |
131 | | - |
132 | | - guard let readValue = text.readString(length: lineEndingIndex) else { |
133 | | - // Reset to before this event, as we didn't fully process this |
134 | | - text.moveReaderIndex(to: lastEventReaderIndex) |
135 | | - return getEvents() |
136 | | - } |
137 | | - |
138 | | - value = readValue.trimmingCharacters(in: .whitespacesAndNewlines) |
139 | | - |
140 | | - // see https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation |
141 | | - switch key { |
142 | | - case "event": |
143 | | - type = value |
144 | | - case "data": |
145 | | - data.append(value) |
146 | | - case "id": |
147 | | - id = value |
148 | | - // case "retry": |
149 | | - default: |
150 | | - () // Ignore field |
151 | | - } |
152 | | - } |
153 | | - |
154 | | - guard let byte: UInt8 = text.readInteger() else { |
155 | | - // Reset to before this event, as we didn't fully process this |
156 | | - text.moveReaderIndex(to: lastEventReaderIndex) |
157 | | - return getEvents() |
158 | | - } |
159 | | - |
160 | | - if byte == 0x0d, text.getInteger(at: text.readerIndex, as: UInt8.self) == 0x0a { |
161 | | - // Skip the 0x0a as well |
162 | | - // CRLF, CR and LF are all valid delimiters |
163 | | - text.moveReaderIndex(forwardBy: 1) |
164 | | - } |
165 | | - // TODO: What if we receive an `\r` here, and a `\n` in the next TCP read? Do we pair them up, or regard one as an empty event? |
166 | | - case .haltParsing: |
167 | | - text.moveReaderIndex(to: lastEventReaderIndex) |
168 | | - return getEvents() |
169 | | - } |
170 | | - } while text.readableBytes > 0 |
171 | | - |
172 | | - text.moveReaderIndex(to: lastEventReaderIndex) |
173 | | - return getEvents() |
| 68 | +extension AsyncSequence where Element == ByteBuffer { |
| 69 | + public func getServerSentEvents(allocator: ByteBufferAllocator) -> SSEStream { |
| 70 | + SSEStream(sequence: self, allocator: allocator) |
174 | 71 | } |
175 | 72 | } |
0 commit comments