|
32 | 32 | from apps.data_training.curd.data_training import get_training_template |
33 | 33 | from apps.datasource.crud.datasource import get_table_schema |
34 | 34 | from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user |
| 35 | +from apps.datasource.embedding.ds_embedding import get_ds_embedding |
35 | 36 | from apps.datasource.models.datasource import CoreDatasource |
36 | 37 | from apps.db.db import exec_sql, get_version, check_connection |
37 | 38 | from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds |
@@ -423,61 +424,65 @@ def select_datasource(self): |
423 | 424 |
|
424 | 425 | full_thinking_text = '' |
425 | 426 | full_text = '' |
426 | | - json_str: Optional[str] = None |
427 | 427 | if not ignore_auto_select: |
428 | | - _ds_list_dict = [] |
429 | | - for _ds in _ds_list: |
430 | | - _ds_list_dict.append(_ds) |
431 | | - datasource_msg.append( |
432 | | - HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) |
433 | | - |
434 | | - self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session, |
435 | | - ai_modal_id=self.chat_question.ai_modal_id, |
436 | | - ai_modal_name=self.chat_question.ai_modal_name, |
437 | | - operate=OperationEnum.CHOOSE_DATASOURCE, |
438 | | - record_id=self.record.id, |
439 | | - full_message=[{'type': msg.type, |
440 | | - 'content': msg.content} |
441 | | - for |
442 | | - msg in datasource_msg]) |
443 | | - |
444 | | - token_usage = {} |
445 | | - res = self.llm.stream(datasource_msg) |
446 | | - for chunk in res: |
447 | | - SQLBotLogUtil.info(chunk) |
448 | | - reasoning_content_chunk = '' |
449 | | - if 'reasoning_content' in chunk.additional_kwargs: |
450 | | - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
451 | | - # else: |
452 | | - # reasoning_content_chunk = chunk.get('reasoning_content') |
453 | | - if reasoning_content_chunk is None: |
| 428 | + if settings.EMBEDDING_ENABLED: |
| 429 | + ds = get_ds_embedding(self.session, self.current_user, _ds_list, self.chat_question.question) |
| 430 | + yield {'content': '{"id":' + str(ds.get('id')) + '}'} |
| 431 | + else: |
| 432 | + _ds_list_dict = [] |
| 433 | + for _ds in _ds_list: |
| 434 | + _ds_list_dict.append(_ds) |
| 435 | + datasource_msg.append( |
| 436 | + HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) |
| 437 | + |
| 438 | + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session, |
| 439 | + ai_modal_id=self.chat_question.ai_modal_id, |
| 440 | + ai_modal_name=self.chat_question.ai_modal_name, |
| 441 | + operate=OperationEnum.CHOOSE_DATASOURCE, |
| 442 | + record_id=self.record.id, |
| 443 | + full_message=[{'type': msg.type, |
| 444 | + 'content': msg.content} |
| 445 | + for |
| 446 | + msg in datasource_msg]) |
| 447 | + |
| 448 | + token_usage = {} |
| 449 | + res = self.llm.stream(datasource_msg) |
| 450 | + for chunk in res: |
| 451 | + SQLBotLogUtil.info(chunk) |
454 | 452 | reasoning_content_chunk = '' |
455 | | - full_thinking_text += reasoning_content_chunk |
456 | | - |
457 | | - full_text += chunk.content |
458 | | - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} |
459 | | - get_token_usage(chunk, token_usage) |
460 | | - datasource_msg.append(AIMessage(full_text)) |
461 | | - |
462 | | - self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session, |
463 | | - log=self.current_logs[ |
464 | | - OperationEnum.CHOOSE_DATASOURCE], |
465 | | - full_message=[ |
466 | | - {'type': msg.type, |
467 | | - 'content': msg.content} |
468 | | - for msg in datasource_msg], |
469 | | - reasoning_content=full_thinking_text, |
470 | | - token_usage=token_usage) |
471 | | - |
472 | | - json_str = extract_nested_json(full_text) |
473 | | - if json_str is None: |
474 | | - raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}') |
| 453 | + if 'reasoning_content' in chunk.additional_kwargs: |
| 454 | + reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
| 455 | + # else: |
| 456 | + # reasoning_content_chunk = chunk.get('reasoning_content') |
| 457 | + if reasoning_content_chunk is None: |
| 458 | + reasoning_content_chunk = '' |
| 459 | + full_thinking_text += reasoning_content_chunk |
| 460 | + |
| 461 | + full_text += chunk.content |
| 462 | + yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} |
| 463 | + get_token_usage(chunk, token_usage) |
| 464 | + datasource_msg.append(AIMessage(full_text)) |
| 465 | + |
| 466 | + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session, |
| 467 | + log=self.current_logs[ |
| 468 | + OperationEnum.CHOOSE_DATASOURCE], |
| 469 | + full_message=[ |
| 470 | + {'type': msg.type, |
| 471 | + 'content': msg.content} |
| 472 | + for msg in datasource_msg], |
| 473 | + reasoning_content=full_thinking_text, |
| 474 | + token_usage=token_usage) |
| 475 | + |
| 476 | + json_str = extract_nested_json(full_text) |
| 477 | + if json_str is None: |
| 478 | + raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}') |
| 479 | + ds = orjson.loads(json_str) |
475 | 480 |
|
476 | 481 | _error: Exception | None = None |
477 | 482 | _datasource: int | None = None |
478 | 483 | _engine_type: str | None = None |
479 | 484 | try: |
480 | | - data: dict = _ds_list[0] if ignore_auto_select else orjson.loads(json_str) |
| 485 | + data: dict = _ds_list[0] if ignore_auto_select else ds |
481 | 486 |
|
482 | 487 | if data.get('id') and data.get('id') != 0: |
483 | 488 | _datasource = data['id'] |
@@ -516,7 +521,7 @@ def select_datasource(self): |
516 | 521 | except Exception as e: |
517 | 522 | _error = e |
518 | 523 |
|
519 | | - if not ignore_auto_select: |
| 524 | + if not ignore_auto_select and not settings.EMBEDDING_ENABLED: |
520 | 525 | self.record = save_select_datasource_answer(session=self.session, record_id=self.record.id, |
521 | 526 | answer=orjson.dumps({'content': full_text}).decode(), |
522 | 527 | datasource=_datasource, |
|
0 commit comments