Skip to content

Commit bc92875

Browse files
authored
Merge pull request #9 from dingzhaohan/main
fix job add, tiefblue upload interface
2 parents 37bbd97 + 3b01383 commit bc92875

File tree

29 files changed

+2776
-53
lines changed

29 files changed

+2776
-53
lines changed

pkg/sumdb/sum.golang.org/latest

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
go.sum database tree
2+
39073249
3+
CHoGLc3xiTLWJIkUvqwtoUC8uUi8tsi/AMKPaA1nDxo=
4+
5+
— sum.golang.org Az3grieYv97xelotCMbg3+3zXebCIZkZfor40fTCTWo4rNjv+mQ8eEJFgiyyjPkL1h6GWziyoBzovPY6pFcfbssUNQU=

setup.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,8 @@ def read_file(filename):
1010

1111

1212
setup(
13-
name="bohrium",
14-
version="0.1.0",
13+
name="bohrium-sdk",
14+
version="0.13.0",
1515
author="dingzhaohan",
1616
author_email="[email protected]",
1717
url="https://github.com/dingzhaohan",
@@ -24,7 +24,15 @@ def read_file(filename):
2424
# include .txt all of them
2525
"": ["*.txt"]
2626
},
27-
install_requires=[],
27+
install_requires=[
28+
"distro",
29+
"httpx",
30+
"typing_extensions",
31+
"anyio",
32+
"pyhumps",
33+
"requests",
34+
"tqdm",
35+
],
2836
python_requires=">=3.7",
2937
entry_points={},
3038
)

src/bohrium/_base_client.py

Lines changed: 31 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,8 @@ def _build_headers(self, custom_headers) -> httpx.Headers:
143143
headers_dict = _merge_mappings(
144144
self.default_headers, self._custom_headers, custom_headers
145145
)
146+
# 过滤掉 value 为 None 的 header
147+
headers_dict = {k: v for k, v in headers_dict.items() if v is not None}
146148
headers = httpx.Headers(headers_dict)
147149
return headers or dict()
148150

@@ -177,17 +179,45 @@ def platform_headers(self) -> Dict[str, str]:
177179
exceptions=(httpx.RequestError,),
178180
)
179181
def _request(
180-
self, method: str, path: str, json=None, headers=None, **kwargs
182+
self, method: str, path: str, json=None, headers=None, data=None, **kwargs
181183
) -> httpx.Response:
182184
url = urljoin(str(self._base_url), path)
183185
logger.info(f"Requesting {method} {url}")
184186
merged_headers = self._build_headers(headers)
185187
merged_params = self._build_params(kwargs.get("params"))
188+
189+
# 处理文件上传
190+
request_kwargs = {
191+
"method": method.upper(),
192+
"url": url,
193+
"params": merged_params,
194+
}
195+
196+
# 处理超时参数
197+
if "timeout" in kwargs:
198+
request_kwargs["timeout"] = kwargs["timeout"]
199+
200+
if json is not None:
201+
request_kwargs["json"] = json
202+
request_kwargs["headers"] = merged_headers
203+
elif "files" in kwargs:
204+
# 当有files参数时,不使用json参数,而是使用files和data
205+
# 不设置headers,让httpx自动处理multipart/form-data
206+
request_kwargs["files"] = kwargs["files"]
207+
if "data" in kwargs:
208+
request_kwargs["data"] = kwargs["data"]
209+
elif "data" in kwargs:
210+
request_kwargs["data"] = kwargs["data"]
211+
request_kwargs["headers"] = merged_headers
212+
else:
213+
request_kwargs["headers"] = merged_headers
214+
186215
try:
187216
return self._client.request(
188217
method.upper(),
189218
url,
190219
json=json,
220+
data=data,
191221
headers=merged_headers,
192222
params=merged_params,
193223
)

src/bohrium/_client.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,14 +14,19 @@
1414

1515
class Bohrium(SyncAPIClient):
1616
job: resources.Job
17+
sigma_search: resources.SigmaSearch
18+
uni_parser: resources.UniParser
19+
knowledge_base: resources.KnowledgeBase
20+
paper: resources.Paper
1721

1822
# client options
1923
access_key: str
20-
project_id: Union[str, None]
24+
project_id: Optional[str]
2125

2226
def __init__(
2327
self,
2428
access_key: Optional[str] = None,
29+
app_key: Optional[str] = None,
2530
base_url: Optional[Union[str, URL]] = None,
2631
project_id: Optional[str] = None,
2732
timeout: Optional[Union[float, Timeout]] = 30.0,
@@ -36,15 +41,11 @@ def __init__(
3641
"The api_key client option must be set either by passing api_key to the client or by setting the ACCESS_KEY environment variable"
3742
)
3843
self.access_key = access_key
39-
44+
self.app_key = app_key or os.environ.get("BOHRIUM_APP_KEY")
45+
self.params = {"accessKey": self.access_key}
4046
if project_id is None:
4147
project_id = os.environ.get("BOHRIUM_PROJECT_ID")
4248

43-
if project_id is None:
44-
raise BohriumError(
45-
"The project_id client option must be set either by passing project_id to the client or by setting the BOHRIUM_PROJECT_ID environment variable"
46-
)
47-
4849
self.project_id = project_id
4950

5051
if base_url is None:
@@ -63,6 +64,10 @@ def __init__(
6364
)
6465

6566
self.job = resources.Job(self)
67+
self.sigma_search = resources.SigmaSearch(self)
68+
self.uni_parser = resources.UniParser(self)
69+
self.knowledge_base = resources.KnowledgeBase(self)
70+
self.paper = resources.Paper(self)
6671

6772
@property
6873
@override
@@ -71,6 +76,7 @@ def default_headers(self) -> dict[str, str]:
7176
"Accept": "application/json",
7277
"Content-Type": "application/json",
7378
"Authorization": f"Bearer {self.access_key}",
79+
"x-app-key": self.app_key,
7480
}
7581

7682
def _make_status_error(

src/bohrium/resources/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,16 @@
11

22
from .job import Job, AsyncJob
3+
from .sigma_search import SigmaSearch, AsyncSigmaSearch
4+
from .uni_parser import UniParser, AsyncUniParser
5+
from .knowledge_base import KnowledgeBase, AsyncKnowledgeBase
6+
from .paper import Paper, AsyncPaper
7+
from .tiefblue import Tiefblue
8+
9+
# Public names re-exported by the resources package.
# One name per line with a trailing comma: a missing comma between two
# adjacent string literals silently concatenates them into a single bogus
# entry (the original list fused "Tiefblue" and "SigmaSearch" this way).
__all__ = [
    "Job",
    "AsyncJob",
    "Tiefblue",
    "SigmaSearch",
    "AsyncSigmaSearch",
    "UniParser",
    "AsyncUniParser",
    "KnowledgeBase",
    "AsyncKnowledgeBase",
    "Paper",
    "AsyncPaper",
]
316

4-
__all__ = ["Job", "AsyncJob"]

src/bohrium/resources/job/job.py

Lines changed: 33 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import os
33
import uuid
44
from pathlib import Path
5-
5+
import humps
66
# from ..._resource import BaseClient
77
from pprint import pprint
88
from typing import Optional
@@ -18,6 +18,21 @@
1818

1919
class Job(SyncAPIResource):
2020

21+
def create(self, project_id, name='', group_id=0):
    """Create a job record via ``POST /openapi/v1/job/create``.

    Args:
        project_id: Sent as ``projectId`` in the JSON body.
        name: Optional job name; only sent (as ``name``) when truthy.
        group_id: Optional job group id; only sent (as ``bohrGroupId``)
            when truthy.

    Returns:
        The ``data`` member of the decoded JSON response, or an empty dict
        when the response has no ``data`` key.

    Any exception raised by the HTTP call or by JSON decoding propagates
    unchanged to the caller.
    """
    payload = {'projectId': project_id}
    if name:
        payload['name'] = name
    if group_id:
        payload['bohrGroupId'] = group_id

    # The client's ``params`` mapping is forwarded as the query string.
    response = self._client.post(
        '/openapi/v1/job/create',
        json=payload,
        params=self._client.params,
    )
    return response.json().get("data", {})
35+
2136
def detail(self, job_id):
2237
log.info(f"detail job {job_id}")
2338
response = self._client.get(f"/openapi/v1/job/{job_id}")
@@ -43,7 +58,6 @@ def submit(
4358
):
4459
# log.info(f"submit job {name},project_id:{project_id}")
4560
data = self.create_job(project_id, job_name, job_group_id)
46-
print(data)
4761
if work_dir != "":
4862
if not os.path.exists(work_dir):
4963
raise FileNotFoundError
@@ -73,38 +87,39 @@ def submit(
7387
)
7488
return self.insert(job_add_request.to_dict())
7589

76-
def insert(self, data):
77-
# log.info(f"insert job {data}")
78-
response = self._client.post("/openapi/v2/job/add", json=data)
79-
pprint(response.request)
80-
print(response.json())
90+
def insert(self, data=None, **kwargs):
    """Submit a job description via ``POST /openapi/v2/job/add``.

    Fields may be given as a single mapping (``data``), as keyword
    arguments, or both; keyword arguments win on conflicts. Accepting the
    positional mapping keeps ``submit`` working, which calls
    ``self.insert(job_add_request.to_dict())``.

    Args:
        data: Optional mapping of job fields.
        **kwargs: Job fields; keys are converted with ``humps.camelize``
            before sending.

    Returns:
        The ``data`` member of the decoded JSON response (``None`` when
        absent).
    """
    fields = dict(data or {})
    fields.update(kwargs)
    camel_data = {humps.camelize(k): v for k, v in fields.items()}

    # A scalar ossPath is wrapped in a list; a missing one is left to the
    # server to reject rather than raising a bare KeyError here.
    if 'ossPath' in camel_data and not isinstance(camel_data['ossPath'], list):
        camel_data['ossPath'] = [camel_data['ossPath']]

    # ``logFile`` is accepted as an alias and copied into ``logFiles``.
    if 'logFile' in camel_data:
        camel_data['logFiles'] = camel_data['logFile']
    if 'logFiles' in camel_data and not isinstance(camel_data['logFiles'], list):
        camel_data['logFiles'] = [camel_data['logFiles']]

    response = self._client.post("/openapi/v2/job/add", json=camel_data)
    return response.json().get("data")
81100

82101
def delete(self, job_id):
83102
# log.info(f"delete job {job_id}")
84103
response = self._client.post(f"/openapi/v1/job/del/{job_id}")
85-
pprint(response.request)
86-
print(response.json())
104+
87105

88106
def terminate(self, job_id):
89107
# log.info(f"terminate job {job_id}")
90108
response = self._client.post(f"/openapi/v1/job/terminate/{job_id}")
91-
pprint(response.request)
92-
print(response.json())
109+
93110

94111
def kill(self, job_id):
95112
# log.info(f"kill job {job_id}")
96113
response = self._client.post(f"/openapi/v1/job/kill/{job_id}")
97-
pprint(response.request)
98-
print(response.json())
114+
99115

100116
def log(self, job_id, log_file="STDOUTERR", page=-1, page_size=8192):
101117
# log.info(f"log job {job_id}")
102118
response = self._client.get(
103119
f"/openapi/v1/job/{job_id}/log",
104120
params={"logFile": log_file, "page": page, "pageSize": page_size},
105121
)
106-
pprint(response.request)
107-
print(response.json().get("data")["log"])
122+
108123
return response.json().get("data")["log"]
109124

110125
def create_job(
@@ -132,8 +147,7 @@ def create_job(
132147
"bohrGroupId": group_id,
133148
}
134149
response = self._client.post(f"/openapi/v1/job/create", json=data)
135-
pprint(response.request)
136-
print(response.json())
150+
137151
return response.json().get("data")
138152

139153
def create_job_group(self, project_id, job_group_name):
@@ -142,9 +156,8 @@ def create_job_group(self, project_id, job_group_name):
142156
"/openapi/v1/job_group/add",
143157
json={"name": job_group_name, "projectId": project_id},
144158
)
145-
pprint(response.request)
146-
print(response.json())
147-
159+
return response.json().get("data")
160+
148161
def upload(
149162
self,
150163
file_path: str,
Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
from .knowledge_base import KnowledgeBase, AsyncKnowledgeBase
2+
3+
__all__ = ["KnowledgeBase", "AsyncKnowledgeBase"]
Lines changed: 114 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,114 @@
1+
import logging
2+
from typing import Optional, List, Dict, Any, Union
3+
from pprint import pprint
4+
5+
from ..._resource import AsyncAPIResource, SyncAPIResource
6+
from ..._response import APIResponse
7+
from ...types.knowledge_base.knowledge_base import (
8+
HybridRecallRequest,
9+
PaperRecallRequest,
10+
PaperInfo,
11+
ChunkSearchRequest
12+
)
13+
14+
log = logging.getLogger(__name__)
15+
16+
17+
class KnowledgeBase(SyncAPIResource):
18+
"""知识库相关接口"""
19+
20+
def hybrid_recall(
21+
self,
22+
knowledge_base_id: int,
23+
text: str,
24+
k: int = 200,
25+
keywords: Optional[Dict[str, float]] = None,
26+
**kwargs
27+
):
28+
"""知识库混合召回"""
29+
log.info(f"hybrid recall from knowledge base: {knowledge_base_id}")
30+
31+
data = {
32+
"knowledge_base_id": knowledge_base_id,
33+
"text": text,
34+
"k": k
35+
}
36+
37+
if keywords:
38+
data["keywords"] = keywords
39+
if kwargs:
40+
data.update(kwargs)
41+
42+
response = self._client.post("/openapi/v1/knowledge/recall/hybrid", json=data)
43+
log.info(response.json())
44+
return APIResponse(response).json.get("data")
45+
46+
def paper_recall(
47+
self,
48+
text: str,
49+
k: int,
50+
papers: List[Dict[str, str]],
51+
**kwargs
52+
):
53+
"""单篇论文召回"""
54+
log.info(f"paper recall: {len(papers)} papers")
55+
56+
data = {
57+
"text": text,
58+
"k": k,
59+
"papers": papers
60+
}
61+
62+
if kwargs:
63+
data.update(kwargs)
64+
65+
response = self._client.post("/openapi/v1/knowledge/recall/papers", json=data)
66+
log.info(response.json())
67+
return APIResponse(response).json.get("data")
68+
69+
def get_file_tree(
70+
self,
71+
folder_id: str,
72+
**kwargs
73+
):
74+
"""获取单篇切片文件树"""
75+
log.info(f"get file tree for folder: {folder_id}")
76+
77+
params = {"folderId": folder_id}
78+
if kwargs:
79+
params.update(kwargs)
80+
81+
response = self._client.get("/openapi/v1/knowledge/folder/file_tree", params=params)
82+
log.info(response.json())
83+
return APIResponse(response).json.get("data")
84+
85+
def search_by_md5_paper_id(
86+
self,
87+
md5: str,
88+
paper_id: str = "",
89+
page_num: int = 0,
90+
page_size: int = 9999,
91+
**kwargs
92+
):
93+
"""根据md5和paper_id搜索chunk信息"""
94+
log.info(f"search chunk by md5: {md5}, paper_id: {paper_id}")
95+
96+
data = {
97+
"md5": md5,
98+
"paper_id": paper_id,
99+
"page_num": page_num,
100+
"page_size": page_size
101+
}
102+
103+
if kwargs:
104+
data.update(kwargs)
105+
106+
response = self._client.post("/openapi/v1/knowledge/box/search_by_md5_paper_id", json=data)
107+
log.info(response.json())
108+
return APIResponse(response).json.get("data")
109+
110+
111+
112+
class AsyncKnowledgeBase(AsyncAPIResource):
    """Asynchronous knowledge-base interface.

    Placeholder: no endpoints are defined on this class yet.
    """

0 commit comments

Comments
 (0)