Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,4 @@ refactored/code/set_keys.sh
old_code/set_keys.sh
.vscode/settings.json
/static/ta/ba
static/.DS_Store
25 changes: 4 additions & 21 deletions code/python/app-aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,10 @@ async def main():
import core.retriever as retriever
retriever.init()

# Determine which server to use
use_aiohttp = os.environ.get('USE_AIOHTTP', 'true').lower() == 'true'

if use_aiohttp:
print("Starting aiohttp server...")
from webserver.aiohttp_server import AioHTTPServer
server = AioHTTPServer()
await server.start()
else:
print("Starting legacy server...")
from webserver.WebServer import fulfill_request, start_server

# Get port from Azure environment or use default
port = int(os.environ.get('PORT', 8000))

# Start the server
await start_server(
host='0.0.0.0',
port=port,
fulfill_request=fulfill_request
)
print("Starting aiohttp server...")
from webserver.aiohttp_server import AioHTTPServer
server = AioHTTPServer()
await server.start()


if __name__ == "__main__":
Expand Down
22 changes: 10 additions & 12 deletions code/python/app-file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@

import asyncio
import os
from webserver.WebServer import fulfill_request, start_server
from dotenv import load_dotenv


def main():
async def main():
# Load environment variables from .env file
load_dotenv()

Expand All @@ -28,6 +27,10 @@ def main():
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
logging.getLogger("azure").setLevel(logging.WARNING)

# Suppress webserver middleware INFO logs
logging.getLogger("webserver.middleware.logging_middleware").setLevel(logging.WARNING)
logging.getLogger("aiohttp.access").setLevel(logging.WARNING)

# Initialize router
import core.router as router
router.init()
Expand All @@ -40,15 +43,10 @@ def main():
import core.retriever as retriever
retriever.init()

# Get port from Azure environment or use default
port = int(os.environ.get('PORT', 8000))

# Start the server
asyncio.run(start_server(
host='0.0.0.0',
port=port,
fulfill_request=fulfill_request
))
print("Starting aiohttp server...")
from webserver.aiohttp_server import AioHTTPServer
server = AioHTTPServer()
await server.start()

if __name__ == "__main__":
main()
asyncio.run(main())
3 changes: 2 additions & 1 deletion code/python/core/baseHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ async def prepare(self):
items = await search(
self.decontextualized_query,
self.site,
query_params=self.query_params
query_params=self.query_params,
handler=self
)
self.final_retrieved_items = items
logger.debug(f"Retrieved {len(items)} items from database")
Expand Down
4 changes: 4 additions & 0 deletions code/python/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ class EmbeddingProviderConfig:
@dataclass
class RetrievalProviderConfig:
api_key: Optional[str] = None
api_key_env: Optional[str] = None # Environment variable name for API key
api_endpoint: Optional[str] = None
api_endpoint_env: Optional[str] = None # Environment variable name for endpoint
database_path: Optional[str] = None
index_name: Optional[str] = None
db_type: Optional[str] = None
Expand Down Expand Up @@ -363,7 +365,9 @@ def load_retrieval_config(self, path: str = "config_retrieval.yaml"):
# Use the new method for all configuration values
self.retrieval_endpoints[name] = RetrievalProviderConfig(
api_key=self._get_config_value(cfg.get("api_key_env")),
api_key_env=cfg.get("api_key_env"), # Store the env var name
api_endpoint=self._get_config_value(cfg.get("api_endpoint_env")),
api_endpoint_env=cfg.get("api_endpoint_env"), # Store the env var name
database_path=self._get_config_value(cfg.get("database_path")),
index_name=self._get_config_value(cfg.get("index_name")),
db_type=self._get_config_value(cfg.get("db_type")), # Add db_type
Expand Down
3 changes: 2 additions & 1 deletion code/python/core/fastTrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ async def do(self):
items = await search(
self.handler.query,
self.handler.site,
query_params=self.handler.query_params
query_params=self.handler.query_params,
handler=self.handler
)
self.handler.final_retrieved_items = items
logger.info(f"Fast track retrieved {len(items)} items")
Expand Down
11 changes: 2 additions & 9 deletions code/python/core/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,16 @@

def init():
"""Initialize LLM providers based on configuration."""
print("=== LLM initialization starting ===")

# Get all configured LLM endpoints
for endpoint_name, endpoint_config in CONFIG.llm_endpoints.items():
llm_type = endpoint_config.llm_type
if llm_type and endpoint_name == CONFIG.preferred_llm_endpoint:
print(f"Preloading preferred LLM provider: {endpoint_name} (type: {llm_type})")
try:
# Use _get_provider which will load and cache the provider
_get_provider(llm_type)
print(f"Successfully loaded {llm_type} provider")
logger.info(f"Successfully loaded {llm_type} provider")
except Exception as e:
print(f"Failed to load {llm_type} provider: {e}")

print("=== LLM initialization complete ===")
logger.warning(f"Failed to load {llm_type} provider: {e}")

# Mapping of LLM types to their required pip packages
_llm_type_packages = {
Expand Down Expand Up @@ -210,7 +205,6 @@ async def ask_llm(
if provider_name not in CONFIG.llm_endpoints:
error_msg = f"Unknown provider '{provider_name}'"
logger.error(error_msg)
print(f"Unknown provider '{provider_name}'")
return {}

# Get provider config using the helper method
Expand Down Expand Up @@ -256,7 +250,6 @@ async def ask_llm(
except Exception as e:
error_msg = f"LLM call failed: {type(e).__name__}: {str(e)}"
logger.error(f"Error with provider {provider_name}: {error_msg}")
print(f"LLM Error ({provider_name}): {type(e).__name__}: {str(e)}")

logger.log_with_context(
LogLevel.ERROR,
Expand Down
82 changes: 82 additions & 0 deletions code/python/core/query_analysis/query_rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""
This file is used to rewrite complex queries into simpler keyword queries
for traditional keyword-based search engines.

WARNING: This code is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""

from core.prompts import PromptRunner
import asyncio
from misc.logger.logging_config_helper import get_configured_logger

logger = get_configured_logger("query_rewrite")


class QueryRewrite(PromptRunner):

QUERY_REWRITE_PROMPT_NAME = "QueryRewrite"
STEP_NAME = "QueryRewrite"

def __init__(self, handler):
super().__init__(handler)
self.handler.state.start_precheck_step(self.STEP_NAME)

async def do(self):
"""
Rewrite the decontextualized query into simpler keyword queries.
The results are stored in handler.rewritten_queries.
"""
logger.info(f"Starting query rewrite for: {self.handler.decontextualized_query}")

try:
# Run the query rewrite prompt
response = await self.run_prompt(self.QUERY_REWRITE_PROMPT_NAME, level="high")

if not response:
logger.warning("No response from QueryRewrite prompt, using original query")
self.handler.rewritten_queries = [self.handler.decontextualized_query]
await self.handler.state.precheck_step_done(self.STEP_NAME)
return

# Extract the rewritten queries from the response
rewritten_queries = response.get("rewritten_queries", [])
query_count = response.get("query_count", 0)

# Validate the response
if not rewritten_queries or not isinstance(rewritten_queries, list):
logger.warning("Invalid response from QueryRewrite prompt, using original query")
self.handler.rewritten_queries = [self.handler.decontextualized_query]
else:
# Filter out any empty queries and ensure they are strings
valid_queries = [q for q in rewritten_queries if q and isinstance(q, str) and q.strip()]

if not valid_queries:
logger.warning("No valid rewritten queries, using original query")
self.handler.rewritten_queries = [self.handler.decontextualized_query]
else:
# Limit to 5 queries maximum
self.handler.rewritten_queries = valid_queries[:5]
logger.info(f"Generated {len(self.handler.rewritten_queries)} rewritten queries: {self.handler.rewritten_queries}")

# Send a message to the client about the rewritten queries
if hasattr(self.handler, 'rewritten_queries') and len(self.handler.rewritten_queries) > 1:
message = {
"message_type": "query_rewrite",
"original_query": self.handler.decontextualized_query,
"rewritten_queries": self.handler.rewritten_queries,
"query_id": self.handler.query_id
}
await self.handler.send_message(message)

except Exception as e:
logger.error(f"Error during query rewrite: {e}")
# On error, fall back to using the original query
self.handler.rewritten_queries = [self.handler.decontextualized_query]

finally:
# Always mark the step as done
await self.handler.state.precheck_step_done(self.STEP_NAME)
Loading