Skip to content

Commit 875e1f7

Browse files
Add SQL implementation of SandboxedConversationService
- Created SQLSandboxedConversationService with full CRUD operations - Implemented batch operations for efficient data retrieval - Added HTTP client integration for agent status from OpenHands Agent Server - Created SQLSandboxedConversationServiceResolver for dependency injection - Updated dependency.py to use real SQL implementation instead of MagicMock - Fixed imports to use correct AgentState from openhands.core.schema - Modeled after existing sql_user_service.py patterns Co-authored-by: openhands <[email protected]>
1 parent 129d42a commit 875e1f7

File tree

3 files changed

+354
-3
lines changed

3 files changed

+354
-3
lines changed

openhands_server/dependency.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ def _get_sandbox_spec_service_factory():
9999

100100

101101
def _get_sandboxed_conversation_service_factory():
102-
return MagicMock() # TODO: Replace with real implementation!
102+
from openhands_server.sandboxed_conversation.sql_sandboxed_conversation_service import (
103+
SQLSandboxedConversationServiceResolver,
104+
)
105+
106+
return SQLSandboxedConversationServiceResolver()
103107

104108

105109
def _get_user_service_factory():

openhands_server/sandboxed_conversation/sandboxed_conversation_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import BaseModel, Field
66
from sqlmodel import Field as SQLField, SQLModel
77

8-
from openhands.sdk.conversation.state import AgentExecutionStatus
8+
from openhands.core.schema import AgentState
99
from openhands_server.event_callback.event_callback_models import EventCallbackProcessor
1010
from openhands_server.sandbox.sandbox_models import SandboxStatus
1111
from openhands_server.utils.date_utils import utc_now
@@ -25,7 +25,7 @@ class StoredConversationInfo(SQLModel):
2525

2626
class SandboxedConversationResponse(StoredConversationInfo):
2727
sandbox_status: SandboxStatus
28-
agent_status: AgentExecutionStatus
28+
agent_status: AgentState | None
2929

3030

3131
class SandboxedConversationResponseSortOrder(Enum):
Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
# pyright: reportArgumentType=false, reportAttributeAccessIssue=false, reportOptionalMemberAccess=false
2+
"""SQL implementation of SandboxedConversationService.
3+
4+
This implementation provides CRUD operations for sandboxed conversations focused purely on SQL operations:
5+
- Direct database access without permission checks
6+
- Batch operations for efficient data retrieval
7+
- Integration with SandboxService for sandbox information
8+
- HTTP client integration for agent status retrieval
9+
- Full async/await support using SQL async sessions
10+
11+
Security and permission checks are handled by wrapper services.
12+
13+
Key components:
14+
- SQLSandboxedConversationService: Main service class implementing all operations
15+
- SQLSandboxedConversationServiceResolver: Dependency injection resolver for FastAPI
16+
"""
17+
18+
from __future__ import annotations
19+
20+
import asyncio
21+
import logging
22+
from dataclasses import dataclass
23+
from datetime import datetime
24+
from typing import Callable
25+
from uuid import UUID
26+
27+
import httpx
28+
from fastapi import Depends
29+
from sqlalchemy import func, select
30+
from sqlalchemy.ext.asyncio import AsyncSession
31+
32+
from openhands.core.schema import AgentState
33+
from openhands_server.database import async_session_dependency
34+
from openhands_server.sandbox.sandbox_models import AGENT_SERVER, SandboxStatus
35+
from openhands_server.sandbox.sandbox_service import SandboxService, SandboxServiceResolver
36+
from openhands_server.sandboxed_conversation.sandboxed_conversation_models import (
37+
SandboxedConversationResponse,
38+
SandboxedConversationResponsePage,
39+
StartSandboxedConversationRequest,
40+
StoredConversationInfo,
41+
)
42+
from openhands_server.sandboxed_conversation.sandboxed_conversation_service import (
43+
SandboxedConversationService,
44+
SandboxedConversationServiceResolver,
45+
)
46+
from openhands_server.utils.date_utils import utc_now
47+
48+
49+
logger = logging.getLogger(__name__)
50+
51+
52+
@dataclass
53+
class SQLSandboxedConversationService(SandboxedConversationService):
54+
"""SQL implementation of SandboxedConversationService focused on database operations."""
55+
56+
session: AsyncSession
57+
sandbox_service: SandboxService
58+
59+
async def search_sandboxed_conversations(
60+
self,
61+
title__contains: str | None = None,
62+
created_at__gte: datetime | None = None,
63+
created_at__lt: datetime | None = None,
64+
updated_at__gte: datetime | None = None,
65+
updated_at__lt: datetime | None = None,
66+
page_id: str | None = None,
67+
limit: int = 100,
68+
) -> SandboxedConversationResponsePage:
69+
"""Search for sandboxed conversations without permission checks."""
70+
query = select(StoredConversationInfo)
71+
72+
# Apply filters
73+
conditions = []
74+
if title__contains is not None:
75+
conditions.append(StoredConversationInfo.title.like(f"%{title__contains}%"))
76+
77+
if created_at__gte is not None:
78+
conditions.append(StoredConversationInfo.created_at >= created_at__gte)
79+
80+
if created_at__lt is not None:
81+
conditions.append(StoredConversationInfo.created_at < created_at__lt)
82+
83+
if updated_at__gte is not None:
84+
conditions.append(StoredConversationInfo.updated_at >= updated_at__gte)
85+
86+
if updated_at__lt is not None:
87+
conditions.append(StoredConversationInfo.updated_at < updated_at__lt)
88+
89+
if conditions:
90+
query = query.where(*conditions)
91+
92+
# Apply pagination
93+
if page_id is not None:
94+
try:
95+
offset = int(page_id)
96+
query = query.offset(offset)
97+
except ValueError:
98+
# If page_id is not a valid integer, start from beginning
99+
offset = 0
100+
else:
101+
offset = 0
102+
103+
# Apply sorting (default to created_at desc)
104+
query = query.order_by(StoredConversationInfo.created_at.desc())
105+
106+
# Apply limit and get one extra to check if there are more results
107+
query = query.limit(limit + 1)
108+
109+
result = await self.session.execute(query)
110+
stored_conversations = list(result.scalars().all())
111+
112+
# Check if there are more results
113+
has_more = len(stored_conversations) > limit
114+
if has_more:
115+
stored_conversations = stored_conversations[:limit]
116+
117+
# Calculate next page ID
118+
next_page_id = None
119+
if has_more:
120+
next_page_id = str(offset + limit)
121+
122+
# Build responses with sandbox and agent status
123+
responses = await self._build_conversation_responses(stored_conversations)
124+
125+
return SandboxedConversationResponsePage(items=responses, next_page_id=next_page_id)
126+
127+
async def count_sandboxed_conversations(
128+
self,
129+
title__contains: str | None = None,
130+
created_at__gte: datetime | None = None,
131+
created_at__lt: datetime | None = None,
132+
updated_at__gte: datetime | None = None,
133+
updated_at__lt: datetime | None = None,
134+
) -> int:
135+
"""Count sandboxed conversations matching the given filters."""
136+
query = select(func.count(StoredConversationInfo.id))
137+
138+
# Apply the same filters as search_sandboxed_conversations
139+
conditions = []
140+
if title__contains is not None:
141+
conditions.append(StoredConversationInfo.title.like(f"%{title__contains}%"))
142+
143+
if created_at__gte is not None:
144+
conditions.append(StoredConversationInfo.created_at >= created_at__gte)
145+
146+
if created_at__lt is not None:
147+
conditions.append(StoredConversationInfo.created_at < created_at__lt)
148+
149+
if updated_at__gte is not None:
150+
conditions.append(StoredConversationInfo.updated_at >= updated_at__gte)
151+
152+
if updated_at__lt is not None:
153+
conditions.append(StoredConversationInfo.updated_at < updated_at__lt)
154+
155+
if conditions:
156+
query = query.where(*conditions)
157+
158+
result = await self.session.execute(query)
159+
count = result.scalar()
160+
return count or 0
161+
162+
async def get_sandboxed_conversation(
163+
self, conversation_id: UUID
164+
) -> SandboxedConversationResponse | None:
165+
"""Get a single sandboxed conversation info. Return None if the conversation was not found."""
166+
query = select(StoredConversationInfo).where(StoredConversationInfo.id == conversation_id)
167+
result = await self.session.execute(query)
168+
stored_conversation = result.scalar_one_or_none()
169+
170+
if stored_conversation is None:
171+
return None
172+
173+
# Build response with sandbox and agent status
174+
responses = await self._build_conversation_responses([stored_conversation])
175+
return responses[0] if responses else None
176+
177+
async def start_sandboxed_conversation(
178+
self, request: StartSandboxedConversationRequest
179+
) -> SandboxedConversationResponse:
180+
"""Start a conversation, optionally specifying a sandbox in which to start."""
181+
# For now, this is a placeholder implementation
182+
# In a real implementation, this would:
183+
# 1. Create or get a sandbox
184+
# 2. Start a conversation in that sandbox
185+
# 3. Set up event callbacks
186+
# 4. Return the conversation response
187+
raise NotImplementedError("start_sandboxed_conversation not yet implemented")
188+
189+
async def _build_conversation_responses(
190+
self, stored_conversations: list[StoredConversationInfo]
191+
) -> list[SandboxedConversationResponse]:
192+
"""Build conversation responses with sandbox and agent status information."""
193+
if not stored_conversations:
194+
return []
195+
196+
# Extract unique sandbox IDs
197+
sandbox_ids = list(set(conv.sandbox_id for conv in stored_conversations))
198+
199+
# Batch get sandbox information
200+
sandbox_infos = await self.sandbox_service.batch_get_sandboxes(sandbox_ids)
201+
sandbox_info_map = {
202+
info.id: info for info in sandbox_infos if info is not None
203+
}
204+
205+
# Group conversations by sandbox for efficient agent status retrieval
206+
conversations_by_sandbox = {}
207+
for conv in stored_conversations:
208+
if conv.sandbox_id not in conversations_by_sandbox:
209+
conversations_by_sandbox[conv.sandbox_id] = []
210+
conversations_by_sandbox[conv.sandbox_id].append(conv)
211+
212+
# Batch get agent status for running sandboxes
213+
agent_status_tasks = []
214+
sandbox_to_task_map = {}
215+
216+
for sandbox_id, conversations in conversations_by_sandbox.items():
217+
sandbox_info = sandbox_info_map.get(sandbox_id)
218+
if sandbox_info and sandbox_info.status == SandboxStatus.RUNNING:
219+
# Find the AGENT_SERVER URL
220+
agent_server_url = None
221+
if sandbox_info.exposed_urls:
222+
for exposed_url in sandbox_info.exposed_urls:
223+
if exposed_url.name == AGENT_SERVER:
224+
agent_server_url = exposed_url.url
225+
break
226+
227+
if agent_server_url:
228+
conversation_ids = [str(conv.id) for conv in conversations]
229+
task = self._get_agent_status_for_conversations(
230+
agent_server_url, conversation_ids, sandbox_info.session_api_key
231+
)
232+
agent_status_tasks.append(task)
233+
sandbox_to_task_map[sandbox_id] = len(agent_status_tasks) - 1
234+
235+
# Execute all agent status requests in parallel
236+
agent_status_results = []
237+
if agent_status_tasks:
238+
agent_status_results = await asyncio.gather(*agent_status_tasks, return_exceptions=True)
239+
240+
# Build the final responses
241+
responses = []
242+
for conv in stored_conversations:
243+
sandbox_info = sandbox_info_map.get(conv.sandbox_id)
244+
sandbox_status = sandbox_info.status if sandbox_info else SandboxStatus.ERROR
245+
246+
# Determine agent status
247+
agent_status = None
248+
if (sandbox_info and
249+
sandbox_info.status == SandboxStatus.RUNNING and
250+
conv.sandbox_id in sandbox_to_task_map):
251+
252+
task_index = sandbox_to_task_map[conv.sandbox_id]
253+
if task_index < len(agent_status_results):
254+
result = agent_status_results[task_index]
255+
if not isinstance(result, Exception):
256+
agent_status = result.get(str(conv.id))
257+
258+
response = SandboxedConversationResponse(
259+
id=conv.id,
260+
title=conv.title,
261+
sandbox_id=conv.sandbox_id,
262+
created_at=conv.created_at,
263+
updated_at=conv.updated_at,
264+
sandbox_status=sandbox_status,
265+
agent_status=agent_status,
266+
)
267+
responses.append(response)
268+
269+
return responses
270+
271+
async def _get_agent_status_for_conversations(
272+
self, agent_server_url: str, conversation_ids: list[str], session_api_key: str | None
273+
) -> dict[str, AgentState]:
274+
"""Get agent status for multiple conversations from the OpenHands Agent Server."""
275+
try:
276+
# Build the URL with query parameters
277+
url = f"{agent_server_url.rstrip('/')}/conversations"
278+
params = {"ids": conversation_ids}
279+
280+
# Set up headers
281+
headers = {}
282+
if session_api_key:
283+
headers["X-Session-API-Key"] = session_api_key
284+
285+
async with httpx.AsyncClient(timeout=10.0) as client:
286+
response = await client.get(url, params=params, headers=headers)
287+
response.raise_for_status()
288+
289+
data = response.json()
290+
291+
# Extract agent status for each conversation
292+
agent_statuses = {}
293+
if isinstance(data, list):
294+
for conversation_data in data:
295+
if isinstance(conversation_data, dict):
296+
conv_id = conversation_data.get("id")
297+
status_str = conversation_data.get("agent_status")
298+
if conv_id and status_str:
299+
try:
300+
agent_status = AgentState(status_str)
301+
agent_statuses[conv_id] = agent_status
302+
except ValueError:
303+
logger.warning(f"Invalid agent status: {status_str}")
304+
305+
return agent_statuses
306+
307+
except Exception as e:
308+
logger.warning(f"Failed to get agent status from {agent_server_url}: {e}")
309+
return {}
310+
311+
312+
class SQLSandboxedConversationServiceResolver(SandboxedConversationServiceResolver):
313+
def get_unsecured_resolver(self) -> Callable:
314+
from openhands_server.dependency import get_dependency_resolver
315+
316+
sandbox_service_resolver = (
317+
get_dependency_resolver().sandbox.get_unsecured_resolver()
318+
)
319+
320+
# Define inline to prevent circular lookup
321+
def resolve_sandboxed_conversation_service(
322+
session: AsyncSession = Depends(async_session_dependency),
323+
sandbox_service: SandboxService = Depends(sandbox_service_resolver),
324+
) -> SandboxedConversationService:
325+
return SQLSandboxedConversationService(session, sandbox_service)
326+
327+
return resolve_sandboxed_conversation_service
328+
329+
def get_resolver_for_user(self) -> Callable:
330+
from openhands_server.dependency import get_dependency_resolver
331+
332+
sandbox_service_resolver = (
333+
get_dependency_resolver().sandbox.get_resolver_for_user()
334+
)
335+
336+
# Define inline to prevent circular lookup
337+
def resolve_sandboxed_conversation_service(
338+
session: AsyncSession = Depends(async_session_dependency),
339+
sandbox_service: SandboxService = Depends(sandbox_service_resolver),
340+
) -> SandboxedConversationService:
341+
service = SQLSandboxedConversationService(session, sandbox_service)
342+
# TODO: Add auth and fix
343+
logger.warning("⚠️ Using Unsecured SandboxedConversationService!!!")
344+
# service = ConstrainedSandboxedConversationService(service, self.current_user_id)
345+
return service
346+
347+
return resolve_sandboxed_conversation_service

0 commit comments

Comments
 (0)