Skip to content

Commit d63c33a

Browse files
authored
Avoid echoing onto a captured FD (#1111)
1 parent a0e887a commit d63c33a

File tree

4 files changed

+190
-53
lines changed

4 files changed

+190
-53
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ jobs:
150150
uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1
151151
with:
152152
dependency_type: minimum
153+
154+
- name: List installed packages
155+
run: |
156+
hatch run test:list
157+
153158
- name: Run the unit tests
154159
run: |
155160
hatch run test:nowarn || hatch run test:nowarn --lf

ipykernel/iostream.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def __init__(
364364
echo : bool
365365
whether to echo output
366366
watchfd : bool (default, True)
367-
Watch the file descripttor corresponding to the replaced stream.
367+
Watch the file descriptor corresponding to the replaced stream.
368368
This is useful if you know some underlying code will write directly
369369
the file descriptor by its number. It will spawn a watching thread,
370370
that will swap the give file descriptor for a pipe, read from the
@@ -408,19 +408,39 @@ def __init__(
408408

409409
if (
410410
watchfd
411-
and (sys.platform.startswith("linux") or sys.platform.startswith("darwin"))
412-
and ("PYTEST_CURRENT_TEST" not in os.environ)
411+
and (
412+
(sys.platform.startswith("linux") or sys.platform.startswith("darwin"))
413+
# Pytest set its own capture. Don't redirect from within pytest.
414+
and ("PYTEST_CURRENT_TEST" not in os.environ)
415+
)
416+
# allow forcing watchfd (mainly for tests)
417+
or watchfd == "force"
413418
):
414-
# Pytest set its own capture. Dont redirect from within pytest.
415-
416419
self._should_watch = True
417420
self._setup_stream_redirects(name)
418421

419422
if echo:
420423
if hasattr(echo, "read") and hasattr(echo, "write"):
424+
# make sure we aren't trying to echo on the FD we're watching!
425+
# that would cause an infinite loop, always echoing on itself
426+
if self._should_watch:
427+
try:
428+
echo_fd = echo.fileno()
429+
except Exception:
430+
echo_fd = None
431+
432+
if echo_fd is not None and echo_fd == self._original_stdstream_fd:
433+
# echo on the _copy_ we made during
434+
# this is the actual terminal FD now
435+
echo = io.TextIOWrapper(
436+
io.FileIO(
437+
self._original_stdstream_copy,
438+
"w",
439+
)
440+
)
421441
self.echo = echo
422442
else:
423-
msg = "echo argument must be a file like object"
443+
msg = "echo argument must be a file-like object"
424444
raise ValueError(msg)
425445

426446
def isatty(self):
@@ -433,7 +453,7 @@ def isatty(self):
433453

434454
def _setup_stream_redirects(self, name):
435455
pr, pw = os.pipe()
436-
fno = getattr(sys, name).fileno()
456+
fno = self._original_stdstream_fd = getattr(sys, name).fileno()
437457
self._original_stdstream_copy = os.dup(fno)
438458
os.dup2(pw, fno)
439459

@@ -455,7 +475,13 @@ def close(self):
455475
"""Close the stream."""
456476
if self._should_watch:
457477
self._should_watch = False
478+
# thread won't wake unless there's something to read
479+
# writing something after _should_watch will not be echoed
480+
os.write(self._original_stdstream_fd, b'\0')
458481
self.watch_fd_thread.join()
482+
# restore original FDs
483+
os.dup2(self._original_stdstream_copy, self._original_stdstream_fd)
484+
os.close(self._original_stdstream_copy)
459485
if self._exc:
460486
etype, value, tb = self._exc
461487
traceback.print_exception(etype, value, tb)

ipykernel/tests/test_io.py

Lines changed: 151 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
"""Test IO capturing functionality"""
22

33
import io
4+
import os
5+
import subprocess
6+
import sys
7+
import time
48
import warnings
9+
from unittest import mock
510

611
import pytest
712
import zmq
@@ -10,20 +15,28 @@
1015
from ipykernel.iostream import MASTER, BackgroundSocket, IOPubThread, OutStream
1116

1217

13-
def test_io_api():
14-
"""Test that wrapped stdout has the same API as a normal TextIO object"""
15-
session = Session()
18+
@pytest.fixture
19+
def ctx():
1620
ctx = zmq.Context()
17-
pub = ctx.socket(zmq.PUB)
18-
thread = IOPubThread(pub)
19-
thread.start()
21+
yield ctx
22+
ctx.destroy()
2023

21-
stream = OutStream(session, thread, "stdout")
2224

23-
# cleanup unused zmq objects before we start testing
24-
thread.stop()
25-
thread.close()
26-
ctx.term()
25+
@pytest.fixture
26+
def iopub_thread(ctx):
27+
with ctx.socket(zmq.PUB) as pub:
28+
thread = IOPubThread(pub)
29+
thread.start()
30+
31+
yield thread
32+
thread.stop()
33+
thread.close()
34+
35+
36+
def test_io_api(iopub_thread):
37+
"""Test that wrapped stdout has the same API as a normal TextIO object"""
38+
session = Session()
39+
stream = OutStream(session, iopub_thread, "stdout")
2740

2841
assert stream.errors is None
2942
assert not stream.isatty()
@@ -43,69 +56,161 @@ def test_io_api():
4356
stream.write(b"") # type:ignore
4457

4558

46-
def test_io_isatty():
59+
def test_io_isatty(iopub_thread):
4760
session = Session()
48-
ctx = zmq.Context()
49-
pub = ctx.socket(zmq.PUB)
50-
thread = IOPubThread(pub)
51-
thread.start()
52-
53-
stream = OutStream(session, thread, "stdout", isatty=True)
61+
stream = OutStream(session, iopub_thread, "stdout", isatty=True)
5462
assert stream.isatty()
5563

5664

57-
def test_io_thread():
58-
ctx = zmq.Context()
59-
pub = ctx.socket(zmq.PUB)
60-
thread = IOPubThread(pub)
65+
def test_io_thread(iopub_thread):
66+
thread = iopub_thread
6167
thread._setup_pipe_in()
6268
msg = [thread._pipe_uuid, b"a"]
6369
thread._handle_pipe_msg(msg)
6470
ctx1, pipe = thread._setup_pipe_out()
6571
pipe.close()
6672
thread._pipe_in.close()
67-
thread._check_mp_mode = lambda: MASTER # type:ignore
73+
thread._check_mp_mode = lambda: MASTER
6874
thread._really_send([b"hi"])
6975
ctx1.destroy()
7076
thread.close()
7177
thread.close()
7278
thread._really_send(None)
7379

7480

75-
def test_background_socket():
76-
ctx = zmq.Context()
77-
pub = ctx.socket(zmq.PUB)
78-
thread = IOPubThread(pub)
79-
sock = BackgroundSocket(thread)
81+
def test_background_socket(iopub_thread):
82+
sock = BackgroundSocket(iopub_thread)
8083
assert sock.__class__ == BackgroundSocket
8184
with warnings.catch_warnings():
8285
warnings.simplefilter("ignore", DeprecationWarning)
8386
sock.linger = 101
84-
assert thread.socket.linger == 101
85-
assert sock.io_thread == thread
87+
assert iopub_thread.socket.linger == 101
88+
assert sock.io_thread == iopub_thread
8689
sock.send(b"hi")
8790

8891

89-
def test_outstream():
92+
def test_outstream(iopub_thread):
9093
session = Session()
91-
ctx = zmq.Context()
92-
pub = ctx.socket(zmq.PUB)
93-
thread = IOPubThread(pub)
94-
thread.start()
95-
94+
pub = iopub_thread.socket
9695
with warnings.catch_warnings():
9796
warnings.simplefilter("ignore", DeprecationWarning)
9897
stream = OutStream(session, pub, "stdout")
99-
stream = OutStream(session, thread, "stdout", pipe=object())
98+
stream.close()
99+
stream = OutStream(session, iopub_thread, "stdout", pipe=object())
100+
stream.close()
100101

101-
stream = OutStream(session, thread, "stdout", watchfd=False)
102+
stream = OutStream(session, iopub_thread, "stdout", watchfd=False)
102103
stream.close()
103104

104-
stream = OutStream(session, thread, "stdout", isatty=True, echo=io.StringIO())
105-
with pytest.raises(io.UnsupportedOperation):
106-
stream.fileno()
107-
stream._watch_pipe_fd()
108-
stream.flush()
109-
stream.write("hi")
110-
stream.writelines(["ab", "cd"])
111-
assert stream.writable()
105+
stream = OutStream(session, iopub_thread, "stdout", isatty=True, echo=io.StringIO())
106+
107+
with stream:
108+
with pytest.raises(io.UnsupportedOperation):
109+
stream.fileno()
110+
stream._watch_pipe_fd()
111+
stream.flush()
112+
stream.write("hi")
113+
stream.writelines(["ab", "cd"])
114+
assert stream.writable()
115+
116+
117+
def subprocess_test_echo_watch():
118+
# handshake Pub subscription
119+
session = Session(key=b'abc')
120+
121+
# use PUSH socket to avoid subscription issues
122+
with zmq.Context() as ctx, ctx.socket(zmq.PUSH) as pub:
123+
pub.connect(os.environ["IOPUB_URL"])
124+
iopub_thread = IOPubThread(pub)
125+
iopub_thread.start()
126+
stdout_fd = sys.stdout.fileno()
127+
sys.stdout.flush()
128+
stream = OutStream(
129+
session,
130+
iopub_thread,
131+
"stdout",
132+
isatty=True,
133+
echo=sys.stdout,
134+
watchfd="force",
135+
)
136+
save_stdout = sys.stdout
137+
with stream, mock.patch.object(sys, "stdout", stream):
138+
# write to low-level FD
139+
os.write(stdout_fd, b"fd\n")
140+
# print (writes to stream)
141+
print("print\n", end="")
142+
sys.stdout.flush()
143+
# write to unwrapped __stdout__ (should also go to original FD)
144+
sys.__stdout__.write("__stdout__\n")
145+
sys.__stdout__.flush()
146+
# write to original sys.stdout (should be the same as __stdout__)
147+
save_stdout.write("stdout\n")
148+
save_stdout.flush()
149+
# is there another way to flush on the FD?
150+
fd_file = os.fdopen(stdout_fd, "w")
151+
fd_file.flush()
152+
# we don't have a sync flush on _reading_ from the watched pipe
153+
time.sleep(1)
154+
stream.flush()
155+
iopub_thread.stop()
156+
iopub_thread.close()
157+
158+
159+
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows")
160+
def test_echo_watch(ctx):
161+
"""Test echo on underlying FD while capturing the same FD
162+
163+
Test runs in a subprocess to avoid messing with pytest output capturing.
164+
"""
165+
s = ctx.socket(zmq.PULL)
166+
port = s.bind_to_random_port("tcp://127.0.0.1")
167+
url = f"tcp://127.0.0.1:{port}"
168+
session = Session(key=b'abc')
169+
messages = []
170+
stdout_chunks = []
171+
with s:
172+
env = dict(os.environ)
173+
env["IOPUB_URL"] = url
174+
env["PYTHONUNBUFFERED"] = "1"
175+
env.pop("PYTEST_CURRENT_TEST", None)
176+
p = subprocess.run(
177+
[
178+
sys.executable,
179+
"-c",
180+
f"import {__name__}; {__name__}.subprocess_test_echo_watch()",
181+
],
182+
env=env,
183+
capture_output=True,
184+
text=True,
185+
timeout=10,
186+
)
187+
print(f"{p.stdout=}")
188+
print(f"{p.stderr}=", file=sys.stderr)
189+
assert p.returncode == 0
190+
while s.poll(timeout=100):
191+
ident, msg = session.recv(s)
192+
assert msg is not None # for type narrowing
193+
if msg["header"]["msg_type"] == "stream" and msg["content"]["name"] == "stdout":
194+
stdout_chunks.append(msg["content"]["text"])
195+
196+
# check outputs
197+
# use sets of lines to ignore ordering issues with
198+
# async flush and watchfd thread
199+
200+
# Check the stream output forwarded over zmq
201+
zmq_stdout = "".join(stdout_chunks)
202+
assert set(zmq_stdout.strip().splitlines()) == {
203+
"fd",
204+
"print",
205+
"stdout",
206+
"__stdout__",
207+
}
208+
209+
# Check what was written to the process stdout (kernel terminal)
210+
# just check that each output source went to the terminal
211+
assert set(p.stdout.strip().splitlines()) == {
212+
"fd",
213+
"print",
214+
"stdout",
215+
"__stdout__",
216+
}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ api = "sphinx-apidoc -o docs/api -f -E ipykernel ipykernel/tests ipykernel/inpro
9191
[tool.hatch.envs.test]
9292
features = ["test"]
9393
[tool.hatch.envs.test.scripts]
94+
list = "python -m pip freeze"
9495
test = "python -m pytest -vv {args}"
9596
nowarn = "test -W default {args}"
9697

0 commit comments

Comments
 (0)