1+ # pyright: reportArgumentType=false, reportAttributeAccessIssue=false, reportOptionalMemberAccess=false
2+ """SQL implementation of SandboxedConversationService.
3+
4+ This implementation provides CRUD operations for sandboxed conversations focused purely on SQL operations:
5+ - Direct database access without permission checks
6+ - Batch operations for efficient data retrieval
7+ - Integration with SandboxService for sandbox information
8+ - HTTP client integration for agent status retrieval
9+ - Full async/await support using SQL async sessions
10+
11+ Security and permission checks are handled by wrapper services.
12+
13+ Key components:
14+ - SQLSandboxedConversationService: Main service class implementing all operations
15+ - SQLSandboxedConversationServiceResolver: Dependency injection resolver for FastAPI
16+ """
17+
18+ from __future__ import annotations
19+
20+ import asyncio
21+ import logging
22+ from dataclasses import dataclass
23+ from datetime import datetime
24+ from typing import Callable
25+ from uuid import UUID
26+
27+ import httpx
28+ from fastapi import Depends
29+ from sqlalchemy import func , select
30+ from sqlalchemy .ext .asyncio import AsyncSession
31+
32+ from openhands .core .schema import AgentState
33+ from openhands_server .database import async_session_dependency
34+ from openhands_server .sandbox .sandbox_models import AGENT_SERVER , SandboxStatus
35+ from openhands_server .sandbox .sandbox_service import SandboxService , SandboxServiceResolver
36+ from openhands_server .sandboxed_conversation .sandboxed_conversation_models import (
37+ SandboxedConversationResponse ,
38+ SandboxedConversationResponsePage ,
39+ StartSandboxedConversationRequest ,
40+ StoredConversationInfo ,
41+ )
42+ from openhands_server .sandboxed_conversation .sandboxed_conversation_service import (
43+ SandboxedConversationService ,
44+ SandboxedConversationServiceResolver ,
45+ )
46+ from openhands_server .utils .date_utils import utc_now
47+
48+
49+ logger = logging .getLogger (__name__ )
50+
51+
52+ @dataclass
53+ class SQLSandboxedConversationService (SandboxedConversationService ):
54+ """SQL implementation of SandboxedConversationService focused on database operations."""
55+
56+ session : AsyncSession
57+ sandbox_service : SandboxService
58+
59+ async def search_sandboxed_conversations (
60+ self ,
61+ title__contains : str | None = None ,
62+ created_at__gte : datetime | None = None ,
63+ created_at__lt : datetime | None = None ,
64+ updated_at__gte : datetime | None = None ,
65+ updated_at__lt : datetime | None = None ,
66+ page_id : str | None = None ,
67+ limit : int = 100 ,
68+ ) -> SandboxedConversationResponsePage :
69+ """Search for sandboxed conversations without permission checks."""
70+ query = select (StoredConversationInfo )
71+
72+ # Apply filters
73+ conditions = []
74+ if title__contains is not None :
75+ conditions .append (StoredConversationInfo .title .like (f"%{ title__contains } %" ))
76+
77+ if created_at__gte is not None :
78+ conditions .append (StoredConversationInfo .created_at >= created_at__gte )
79+
80+ if created_at__lt is not None :
81+ conditions .append (StoredConversationInfo .created_at < created_at__lt )
82+
83+ if updated_at__gte is not None :
84+ conditions .append (StoredConversationInfo .updated_at >= updated_at__gte )
85+
86+ if updated_at__lt is not None :
87+ conditions .append (StoredConversationInfo .updated_at < updated_at__lt )
88+
89+ if conditions :
90+ query = query .where (* conditions )
91+
92+ # Apply pagination
93+ if page_id is not None :
94+ try :
95+ offset = int (page_id )
96+ query = query .offset (offset )
97+ except ValueError :
98+ # If page_id is not a valid integer, start from beginning
99+ offset = 0
100+ else :
101+ offset = 0
102+
103+ # Apply sorting (default to created_at desc)
104+ query = query .order_by (StoredConversationInfo .created_at .desc ())
105+
106+ # Apply limit and get one extra to check if there are more results
107+ query = query .limit (limit + 1 )
108+
109+ result = await self .session .execute (query )
110+ stored_conversations = list (result .scalars ().all ())
111+
112+ # Check if there are more results
113+ has_more = len (stored_conversations ) > limit
114+ if has_more :
115+ stored_conversations = stored_conversations [:limit ]
116+
117+ # Calculate next page ID
118+ next_page_id = None
119+ if has_more :
120+ next_page_id = str (offset + limit )
121+
122+ # Build responses with sandbox and agent status
123+ responses = await self ._build_conversation_responses (stored_conversations )
124+
125+ return SandboxedConversationResponsePage (items = responses , next_page_id = next_page_id )
126+
127+ async def count_sandboxed_conversations (
128+ self ,
129+ title__contains : str | None = None ,
130+ created_at__gte : datetime | None = None ,
131+ created_at__lt : datetime | None = None ,
132+ updated_at__gte : datetime | None = None ,
133+ updated_at__lt : datetime | None = None ,
134+ ) -> int :
135+ """Count sandboxed conversations matching the given filters."""
136+ query = select (func .count (StoredConversationInfo .id ))
137+
138+ # Apply the same filters as search_sandboxed_conversations
139+ conditions = []
140+ if title__contains is not None :
141+ conditions .append (StoredConversationInfo .title .like (f"%{ title__contains } %" ))
142+
143+ if created_at__gte is not None :
144+ conditions .append (StoredConversationInfo .created_at >= created_at__gte )
145+
146+ if created_at__lt is not None :
147+ conditions .append (StoredConversationInfo .created_at < created_at__lt )
148+
149+ if updated_at__gte is not None :
150+ conditions .append (StoredConversationInfo .updated_at >= updated_at__gte )
151+
152+ if updated_at__lt is not None :
153+ conditions .append (StoredConversationInfo .updated_at < updated_at__lt )
154+
155+ if conditions :
156+ query = query .where (* conditions )
157+
158+ result = await self .session .execute (query )
159+ count = result .scalar ()
160+ return count or 0
161+
162+ async def get_sandboxed_conversation (
163+ self , conversation_id : UUID
164+ ) -> SandboxedConversationResponse | None :
165+ """Get a single sandboxed conversation info. Return None if the conversation was not found."""
166+ query = select (StoredConversationInfo ).where (StoredConversationInfo .id == conversation_id )
167+ result = await self .session .execute (query )
168+ stored_conversation = result .scalar_one_or_none ()
169+
170+ if stored_conversation is None :
171+ return None
172+
173+ # Build response with sandbox and agent status
174+ responses = await self ._build_conversation_responses ([stored_conversation ])
175+ return responses [0 ] if responses else None
176+
177+ async def start_sandboxed_conversation (
178+ self , request : StartSandboxedConversationRequest
179+ ) -> SandboxedConversationResponse :
180+ """Start a conversation, optionally specifying a sandbox in which to start."""
181+ # For now, this is a placeholder implementation
182+ # In a real implementation, this would:
183+ # 1. Create or get a sandbox
184+ # 2. Start a conversation in that sandbox
185+ # 3. Set up event callbacks
186+ # 4. Return the conversation response
187+ raise NotImplementedError ("start_sandboxed_conversation not yet implemented" )
188+
189+ async def _build_conversation_responses (
190+ self , stored_conversations : list [StoredConversationInfo ]
191+ ) -> list [SandboxedConversationResponse ]:
192+ """Build conversation responses with sandbox and agent status information."""
193+ if not stored_conversations :
194+ return []
195+
196+ # Extract unique sandbox IDs
197+ sandbox_ids = list (set (conv .sandbox_id for conv in stored_conversations ))
198+
199+ # Batch get sandbox information
200+ sandbox_infos = await self .sandbox_service .batch_get_sandboxes (sandbox_ids )
201+ sandbox_info_map = {
202+ info .id : info for info in sandbox_infos if info is not None
203+ }
204+
205+ # Group conversations by sandbox for efficient agent status retrieval
206+ conversations_by_sandbox = {}
207+ for conv in stored_conversations :
208+ if conv .sandbox_id not in conversations_by_sandbox :
209+ conversations_by_sandbox [conv .sandbox_id ] = []
210+ conversations_by_sandbox [conv .sandbox_id ].append (conv )
211+
212+ # Batch get agent status for running sandboxes
213+ agent_status_tasks = []
214+ sandbox_to_task_map = {}
215+
216+ for sandbox_id , conversations in conversations_by_sandbox .items ():
217+ sandbox_info = sandbox_info_map .get (sandbox_id )
218+ if sandbox_info and sandbox_info .status == SandboxStatus .RUNNING :
219+ # Find the AGENT_SERVER URL
220+ agent_server_url = None
221+ if sandbox_info .exposed_urls :
222+ for exposed_url in sandbox_info .exposed_urls :
223+ if exposed_url .name == AGENT_SERVER :
224+ agent_server_url = exposed_url .url
225+ break
226+
227+ if agent_server_url :
228+ conversation_ids = [str (conv .id ) for conv in conversations ]
229+ task = self ._get_agent_status_for_conversations (
230+ agent_server_url , conversation_ids , sandbox_info .session_api_key
231+ )
232+ agent_status_tasks .append (task )
233+ sandbox_to_task_map [sandbox_id ] = len (agent_status_tasks ) - 1
234+
235+ # Execute all agent status requests in parallel
236+ agent_status_results = []
237+ if agent_status_tasks :
238+ agent_status_results = await asyncio .gather (* agent_status_tasks , return_exceptions = True )
239+
240+ # Build the final responses
241+ responses = []
242+ for conv in stored_conversations :
243+ sandbox_info = sandbox_info_map .get (conv .sandbox_id )
244+ sandbox_status = sandbox_info .status if sandbox_info else SandboxStatus .ERROR
245+
246+ # Determine agent status
247+ agent_status = None
248+ if (sandbox_info and
249+ sandbox_info .status == SandboxStatus .RUNNING and
250+ conv .sandbox_id in sandbox_to_task_map ):
251+
252+ task_index = sandbox_to_task_map [conv .sandbox_id ]
253+ if task_index < len (agent_status_results ):
254+ result = agent_status_results [task_index ]
255+ if not isinstance (result , Exception ):
256+ agent_status = result .get (str (conv .id ))
257+
258+ response = SandboxedConversationResponse (
259+ id = conv .id ,
260+ title = conv .title ,
261+ sandbox_id = conv .sandbox_id ,
262+ created_at = conv .created_at ,
263+ updated_at = conv .updated_at ,
264+ sandbox_status = sandbox_status ,
265+ agent_status = agent_status ,
266+ )
267+ responses .append (response )
268+
269+ return responses
270+
271+ async def _get_agent_status_for_conversations (
272+ self , agent_server_url : str , conversation_ids : list [str ], session_api_key : str | None
273+ ) -> dict [str , AgentState ]:
274+ """Get agent status for multiple conversations from the OpenHands Agent Server."""
275+ try :
276+ # Build the URL with query parameters
277+ url = f"{ agent_server_url .rstrip ('/' )} /conversations"
278+ params = {"ids" : conversation_ids }
279+
280+ # Set up headers
281+ headers = {}
282+ if session_api_key :
283+ headers ["X-Session-API-Key" ] = session_api_key
284+
285+ async with httpx .AsyncClient (timeout = 10.0 ) as client :
286+ response = await client .get (url , params = params , headers = headers )
287+ response .raise_for_status ()
288+
289+ data = response .json ()
290+
291+ # Extract agent status for each conversation
292+ agent_statuses = {}
293+ if isinstance (data , list ):
294+ for conversation_data in data :
295+ if isinstance (conversation_data , dict ):
296+ conv_id = conversation_data .get ("id" )
297+ status_str = conversation_data .get ("agent_status" )
298+ if conv_id and status_str :
299+ try :
300+ agent_status = AgentState (status_str )
301+ agent_statuses [conv_id ] = agent_status
302+ except ValueError :
303+ logger .warning (f"Invalid agent status: { status_str } " )
304+
305+ return agent_statuses
306+
307+ except Exception as e :
308+ logger .warning (f"Failed to get agent status from { agent_server_url } : { e } " )
309+ return {}
310+
311+
312+ class SQLSandboxedConversationServiceResolver (SandboxedConversationServiceResolver ):
313+ def get_unsecured_resolver (self ) -> Callable :
314+ from openhands_server .dependency import get_dependency_resolver
315+
316+ sandbox_service_resolver = (
317+ get_dependency_resolver ().sandbox .get_unsecured_resolver ()
318+ )
319+
320+ # Define inline to prevent circular lookup
321+ def resolve_sandboxed_conversation_service (
322+ session : AsyncSession = Depends (async_session_dependency ),
323+ sandbox_service : SandboxService = Depends (sandbox_service_resolver ),
324+ ) -> SandboxedConversationService :
325+ return SQLSandboxedConversationService (session , sandbox_service )
326+
327+ return resolve_sandboxed_conversation_service
328+
329+ def get_resolver_for_user (self ) -> Callable :
330+ from openhands_server .dependency import get_dependency_resolver
331+
332+ sandbox_service_resolver = (
333+ get_dependency_resolver ().sandbox .get_resolver_for_user ()
334+ )
335+
336+ # Define inline to prevent circular lookup
337+ def resolve_sandboxed_conversation_service (
338+ session : AsyncSession = Depends (async_session_dependency ),
339+ sandbox_service : SandboxService = Depends (sandbox_service_resolver ),
340+ ) -> SandboxedConversationService :
341+ service = SQLSandboxedConversationService (session , sandbox_service )
342+ # TODO: Add auth and fix
343+ logger .warning ("⚠️ Using Unsecured SandboxedConversationService!!!" )
344+ # service = ConstrainedSandboxedConversationService(service, self.current_user_id)
345+ return service
346+
347+ return resolve_sandboxed_conversation_service
0 commit comments