Skip to content

Commit 919378f

Browse files
committed
feat: Vector retrieval matches datasource
1 parent a426c39 commit 919378f

File tree

2 files changed

+57
-51
lines changed

2 files changed

+57
-51
lines changed

backend/apps/chat/task/llm.py

Lines changed: 54 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
from apps.data_training.curd.data_training import get_training_template
3333
from apps.datasource.crud.datasource import get_table_schema
3434
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
35+
from apps.datasource.embedding.ds_embedding import get_ds_embedding
3536
from apps.datasource.models.datasource import CoreDatasource
3637
from apps.db.db import exec_sql, get_version, check_connection
3738
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
@@ -423,61 +424,65 @@ def select_datasource(self):
423424

424425
full_thinking_text = ''
425426
full_text = ''
426-
json_str: Optional[str] = None
427427
if not ignore_auto_select:
428-
_ds_list_dict = []
429-
for _ds in _ds_list:
430-
_ds_list_dict.append(_ds)
431-
datasource_msg.append(
432-
HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))
433-
434-
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session,
435-
ai_modal_id=self.chat_question.ai_modal_id,
436-
ai_modal_name=self.chat_question.ai_modal_name,
437-
operate=OperationEnum.CHOOSE_DATASOURCE,
438-
record_id=self.record.id,
439-
full_message=[{'type': msg.type,
440-
'content': msg.content}
441-
for
442-
msg in datasource_msg])
443-
444-
token_usage = {}
445-
res = self.llm.stream(datasource_msg)
446-
for chunk in res:
447-
SQLBotLogUtil.info(chunk)
448-
reasoning_content_chunk = ''
449-
if 'reasoning_content' in chunk.additional_kwargs:
450-
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
451-
# else:
452-
# reasoning_content_chunk = chunk.get('reasoning_content')
453-
if reasoning_content_chunk is None:
428+
if settings.EMBEDDING_ENABLED:
429+
ds = get_ds_embedding(self.session, self.current_user, _ds_list, self.chat_question.question)
430+
yield {'content': '{"id":' + str(ds.get('id')) + '}'}
431+
else:
432+
_ds_list_dict = []
433+
for _ds in _ds_list:
434+
_ds_list_dict.append(_ds)
435+
datasource_msg.append(
436+
HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))
437+
438+
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session,
439+
ai_modal_id=self.chat_question.ai_modal_id,
440+
ai_modal_name=self.chat_question.ai_modal_name,
441+
operate=OperationEnum.CHOOSE_DATASOURCE,
442+
record_id=self.record.id,
443+
full_message=[{'type': msg.type,
444+
'content': msg.content}
445+
for
446+
msg in datasource_msg])
447+
448+
token_usage = {}
449+
res = self.llm.stream(datasource_msg)
450+
for chunk in res:
451+
SQLBotLogUtil.info(chunk)
454452
reasoning_content_chunk = ''
455-
full_thinking_text += reasoning_content_chunk
456-
457-
full_text += chunk.content
458-
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
459-
get_token_usage(chunk, token_usage)
460-
datasource_msg.append(AIMessage(full_text))
461-
462-
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
463-
log=self.current_logs[
464-
OperationEnum.CHOOSE_DATASOURCE],
465-
full_message=[
466-
{'type': msg.type,
467-
'content': msg.content}
468-
for msg in datasource_msg],
469-
reasoning_content=full_thinking_text,
470-
token_usage=token_usage)
471-
472-
json_str = extract_nested_json(full_text)
473-
if json_str is None:
474-
raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}')
453+
if 'reasoning_content' in chunk.additional_kwargs:
454+
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
455+
# else:
456+
# reasoning_content_chunk = chunk.get('reasoning_content')
457+
if reasoning_content_chunk is None:
458+
reasoning_content_chunk = ''
459+
full_thinking_text += reasoning_content_chunk
460+
461+
full_text += chunk.content
462+
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
463+
get_token_usage(chunk, token_usage)
464+
datasource_msg.append(AIMessage(full_text))
465+
466+
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
467+
log=self.current_logs[
468+
OperationEnum.CHOOSE_DATASOURCE],
469+
full_message=[
470+
{'type': msg.type,
471+
'content': msg.content}
472+
for msg in datasource_msg],
473+
reasoning_content=full_thinking_text,
474+
token_usage=token_usage)
475+
476+
json_str = extract_nested_json(full_text)
477+
if json_str is None:
478+
raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}')
479+
ds = orjson.loads(json_str)
475480

476481
_error: Exception | None = None
477482
_datasource: int | None = None
478483
_engine_type: str | None = None
479484
try:
480-
data: dict = _ds_list[0] if ignore_auto_select else orjson.loads(json_str)
485+
data: dict = _ds_list[0] if ignore_auto_select else ds
481486

482487
if data.get('id') and data.get('id') != 0:
483488
_datasource = data['id']
@@ -516,7 +521,7 @@ def select_datasource(self):
516521
except Exception as e:
517522
_error = e
518523

519-
if not ignore_auto_select:
524+
if not ignore_auto_select and not settings.EMBEDDING_ENABLED:
520525
self.record = save_select_datasource_answer(session=self.session, record_id=self.record.id,
521526
answer=orjson.dumps({'content': full_text}).decode(),
522527
datasource=_datasource,

backend/apps/datasource/embedding/ds_embedding.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import traceback
66

77
from apps.ai_model.embedding import EmbeddingModelCache
8-
from apps.datasource.crud.datasource import get_table_schema, get_ds
8+
from apps.datasource.crud.datasource import get_table_schema
9+
from apps.datasource.models.datasource import CoreDatasource
910
from common.core.deps import SessionDep, CurrentUser
1011

1112

@@ -28,7 +29,7 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, q
2829
_list = []
2930
for _ds in _ds_list:
3031
if _ds.get('id'):
31-
ds = get_ds(session, _ds.get('id'))
32+
ds = session.get(CoreDatasource, _ds.get('id'))
3233

3334
table_schema = get_table_schema(session, current_user, ds)
3435
ds_info = f"{ds.name}, {ds.description}\n"

0 commit comments

Comments
 (0)