Skip to content
Open
133 changes: 101 additions & 32 deletions modal/io_streams.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright Modal Labs 2022
import asyncio
import codecs
import contextlib
import sys
import time
from collections.abc import AsyncGenerator, AsyncIterator
from dataclasses import dataclass
Expand Down Expand Up @@ -377,7 +379,7 @@ async def _stdio_stream_from_command_router(
return


class _BytesStreamReaderThroughCommandRouter(Generic[T]):
class _BytesStreamReaderThroughCommandRouter:
"""
StreamReader implementation that will read directly from the worker that
hosts the sandbox.
Expand All @@ -396,27 +398,27 @@ def __init__(
def file_descriptor(self) -> int:
return self._params.file_descriptor

async def read(self) -> T:
async def read(self) -> bytes:
data_bytes = b""
async for part in self:
data_bytes += cast(bytes, part)
return cast(T, data_bytes)
return data_bytes

def __aiter__(self) -> AsyncIterator[T]:
def __aiter__(self) -> AsyncIterator[bytes]:
return self

async def __anext__(self) -> T:
async def __anext__(self) -> bytes:
if self._stream is None:
self._stream = _stdio_stream_from_command_router(self._params)
# This raises StopAsyncIteration if the stream is at EOF.
return cast(T, await self._stream.__anext__())
return await self._stream.__anext__()

async def aclose(self):
if self._stream:
await self._stream.aclose()


class _TextStreamReaderThroughCommandRouter(Generic[T]):
class _TextStreamReaderThroughCommandRouter:
"""
StreamReader implementation that will read directly from the worker
that hosts the sandbox.
Expand All @@ -437,30 +439,105 @@ def __init__(
def file_descriptor(self) -> int:
return self._params.file_descriptor

async def read(self) -> T:
async def read(self) -> str:
data_str = ""
async for part in self:
data_str += cast(str, part)
return cast(T, data_str)
return data_str

def __aiter__(self) -> AsyncIterator[T]:
def __aiter__(self) -> AsyncIterator[str]:
return self

async def __anext__(self) -> T:
async def __anext__(self) -> str:
if self._stream is None:
bytes_stream = _stdio_stream_from_command_router(self._params)
if self._by_line:
self._stream = _decode_bytes_stream_to_str(_stream_by_line(bytes_stream))
else:
self._stream = _decode_bytes_stream_to_str(bytes_stream)
# This raises StopAsyncIteration if the stream is at EOF.
return cast(T, await self._stream.__anext__())
return await self._stream.__anext__()

async def aclose(self):
if self._stream:
await self._stream.aclose()


class _StdoutPrintingStreamReaderThroughCommandRouter(Generic[T]):
"""
StreamReader implementation for StreamType.STDOUT when using the task command router.

This mirrors the behavior from the server-backed implementation: the stream is printed to
the local stdout immediately and is not readable via StreamReader methods.
"""

_reader: Union[_TextStreamReaderThroughCommandRouter, _BytesStreamReaderThroughCommandRouter]

def __init__(
self,
reader: Union[_TextStreamReaderThroughCommandRouter, _BytesStreamReaderThroughCommandRouter],
) -> None:
self._reader = reader
self._task: Optional[asyncio.Task[None]] = None
self._closed = False
# Kick off a background task that reads from the underlying text stream and prints to stdout.
self._start_printing_task()

@property
def file_descriptor(self) -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be a "api_pb2.FileDescriptor.ValueType" right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, the _StreamReader class has always had int as the return type here though

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see... so many mini cleanups. Maybe it makes sense to add it as a quick driveby improvement to all the _StreamReaders?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i just realized... maybe the issue here is we don't want to leak this protobuf value type to users who might do type-checking?

return self._reader.file_descriptor

def _start_printing_task(self) -> None:
async def _run():
try:

def print_text_part(part: Union[str, bytes]) -> None:
assert isinstance(part, str)
print(cast(str, part), end="")

def print_bytes_part(part: Union[str, bytes]) -> None:
assert isinstance(part, bytes)
sys.stdout.buffer.write(cast(bytes, part))
sys.stdout.buffer.flush()

if isinstance(self._reader, _BytesStreamReaderThroughCommandRouter):
print_part = print_bytes_part
elif isinstance(self._reader, _TextStreamReaderThroughCommandRouter):
print_part = print_text_part
else:
raise ValueError("Unsupported reader type")

async for part in self._reader:
print_part(part)
except Exception as e:
logger.exception(f"Error printing stream: {e}")
finally:
closed, self._closed = self._closed, True
if not closed:
await self._reader.aclose()

self._task = asyncio.create_task(_run())

async def read(self) -> T:
raise InvalidError("Output can only be retrieved using the PIPE stream type.")

def __aiter__(self) -> AsyncIterator[T]:
raise InvalidError("Output can only be retrieved using the PIPE stream type.")

async def __anext__(self) -> T:
raise InvalidError("Output can only be retrieved using the PIPE stream type.")

async def aclose(self):
if self._task is not None:
self._task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._task
self._task = None
closed, self._closed = self._closed, True
if not closed:
await self._reader.aclose()


class _DevnullStreamReader(Generic[T]):
"""StreamReader implementation for a stream configured with
StreamType.DEVNULL. Throws an error if read or any other method is
Expand Down Expand Up @@ -499,6 +576,7 @@ class _StreamReader(Generic[T]):
_DevnullStreamReader,
_TextStreamReaderThroughCommandRouter,
_BytesStreamReaderThroughCommandRouter,
_StdoutPrintingStreamReaderThroughCommandRouter,
]

def __init__(
Expand All @@ -523,35 +601,26 @@ def __init__(
file_descriptor, object_id, object_type, client, stream_type, text, by_line, deadline
)
else:
# The only reason task_id is optional is because StreamReader is
# also used for sandbox logs, which don't have a task ID available
# when the StreamReader is created.
# The only reason task_id is optional is because StreamReader is also used for sandbox
# logs, which don't have a task ID available when the StreamReader is created.
assert task_id is not None
assert object_type == "container_process"
if stream_type == StreamType.DEVNULL:
self._impl = _DevnullStreamReader(file_descriptor)
else:
assert stream_type == StreamType.PIPE or stream_type == StreamType.STDOUT
# TODO(saltzm): The original implementation of STDOUT StreamType in
# _StreamReaderThroughServer prints to stdout immediately. This doesn't match
# python subprocess.run, which uses None to print to stdout immediately, and uses
# STDOUT as an argument to stderr to redirect stderr to the stdout stream. We should
# implement the old behavior here before moving out of beta, but after that
# we should consider changing the API to match python subprocess.run. I don't expect
# many customers are using this in any case, so I think it's fine to leave this
# unimplemented for now.
if stream_type == StreamType.STDOUT:
raise NotImplementedError(
"Currently the STDOUT stream type is not supported when using exec "
"through a task command router, which is currently in beta."
)
params = _StreamReaderThroughCommandRouterParams(
file_descriptor, task_id, object_id, command_router_client, deadline
)
if text:
self._impl = _TextStreamReaderThroughCommandRouter(params, by_line)
reader = _TextStreamReaderThroughCommandRouter(params, by_line)
else:
reader = _BytesStreamReaderThroughCommandRouter(params)

if stream_type == StreamType.STDOUT:
self._impl = _StdoutPrintingStreamReaderThroughCommandRouter(reader)
else:
self._impl = _BytesStreamReaderThroughCommandRouter(params)
self._impl = reader

@property
def file_descriptor(self) -> int:
Expand All @@ -560,7 +629,7 @@ def file_descriptor(self) -> int:

async def read(self) -> T:
"""Fetch the entire contents of the stream until EOF."""
return await self._impl.read()
return cast(T, await self._impl.read())

# TODO(saltzm): I'd prefer to have the implementation classes only implement __aiter__
# and have them return generator functions directly, but synchronicity doesn't let us
Expand All @@ -572,7 +641,7 @@ def __aiter__(self) -> AsyncIterator[T]:

async def __anext__(self) -> T:
"""mdmd:hidden"""
return await self._impl.__anext__()
return cast(T, await self._impl.__anext__())

async def aclose(self):
"""mdmd:hidden"""
Expand Down
8 changes: 2 additions & 6 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,12 +931,8 @@ async def _exec_through_command_router(
elif stdout == StreamType.DEVNULL:
stdout_config = sr_pb2.TaskExecStdoutConfig.TASK_EXEC_STDOUT_CONFIG_DEVNULL
elif stdout == StreamType.STDOUT:
# TODO(saltzm): This is a behavior change from the old implementation. We should
# probably implement the old behavior of printing to stdout before moving out of beta.
raise NotImplementedError(
"Currently the STDOUT stream type is not supported when using exec "
"through a task command router, which is currently in beta."
)
# Stream stdout to the client so that it can be printed locally in the reader.
stdout_config = sr_pb2.TaskExecStdoutConfig.TASK_EXEC_STDOUT_CONFIG_PIPE
else:
raise ValueError("Unsupported StreamType for stdout")

Expand Down
41 changes: 40 additions & 1 deletion test/sandbox_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,14 +514,53 @@ def test_sandbox_gpu_fallbacks_support(client, servicer):


@skip_non_subprocess
def test_sandbox_exec_stdout(app, servicer, capsys):
@pytest.mark.parametrize("exec_backend", ["server", "router"], indirect=True)
def test_sandbox_exec_with_streamtype_stdout_and_text_true_prints_to_stdout(app, servicer, exec_backend, capsys):
sb = Sandbox.create("sleep", "infinity", app=app)

cp = sb.exec("bash", "-c", "echo hi", stdout=StreamType.STDOUT)
cp.wait()

assert capsys.readouterr().out == "hi\n"


@skip_non_subprocess
@pytest.mark.parametrize("exec_backend", ["server", "router"], indirect=True)
def test_sandbox_exec_with_streamtype_stdout_and_text_true_and_bufsize_1_prints_to_stdout(
app, servicer, exec_backend, capsys
):
sb = Sandbox.create("sleep", "infinity", app=app)

cp = sb.exec("bash", "-c", "echo hi && echo bye", stdout=StreamType.STDOUT, bufsize=1)
cp.wait()

assert capsys.readouterr().out == "hi\nbye\n"


@skip_non_subprocess
@pytest.mark.parametrize("exec_backend", ["server", "router"], indirect=True)
def test_sandbox_exec_with_streamtype_stdout_and_text_false_prints_to_stdout(app, servicer, exec_backend, capsysbinary):
sb = Sandbox.create("sleep", "infinity", app=app)

cp = sb.exec(
"bash",
"-c",
"printf '\\x01\\x02\\x03\\n\\x04\\x05\\x06\\n'",
stdout=StreamType.STDOUT,
text=False,
)
cp.wait()

assert capsysbinary.readouterr().out == b"\x01\x02\x03\n\x04\x05\x06\n"


@skip_non_subprocess
@pytest.mark.parametrize("exec_backend", ["server", "router"], indirect=True)
def test_sandbox_exec_with_streamtype_stdout_read_from_stdout_raises_error(app, servicer, exec_backend, capsys):
sb = Sandbox.create("sleep", "infinity", app=app)

cp = sb.exec("bash", "-c", "echo hi", stdout=StreamType.STDOUT)

with pytest.raises(InvalidError):
cp.stdout.read()

Expand Down
Loading