
Commit 261f501

Merge pull request #110 from iryna-kondr/gguf
replaced gpt4all with llama-cpp-python
2 parents a9c29c4 + 10c09d2 commit 261f501

10 files changed: +242 / -70 lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@ dependencies = [
     "google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0"
 ]
 name = "scikit-llm"
-version = "1.3.1"
+version = "1.4.0"
 authors = [
     { name="Oleh Kostromin", email="[email protected]" },
     { name="Iryna Kondrashchenko", email="[email protected]" },
@@ -27,7 +27,7 @@ classifiers = [
 ]

 [project.optional-dependencies]
-gpt4all = ["gpt4all>=2.0.0,<3.0.0"]
+gguf = ["llama-cpp-python>=0.2.82,<0.2.83"]
 annoy = ["annoy>=1.17.2,<2.0.0"]

 [tool.ruff]
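
Note (not part of the diff): the local-backend extra is renamed from gpt4all to gguf, so the local models are now installed with `pip install scikit-llm[gguf]`, which pulls in the pinned llama-cpp-python build.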

skllm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-__version__ = '1.3.1'
+__version__ = '1.4.0'
 __author__ = 'Iryna Kondrashchenko, Oleh Kostromin'

skllm/config.py

Lines changed: 36 additions & 1 deletion
@@ -7,6 +7,9 @@
 _AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION"
 _GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT"
 _GPT_URL_VAR = "SKLLM_CONFIG_GPT_URL"
+_GGUF_DOWNLOAD_PATH = "SKLLM_CONFIG_GGUF_DOWNLOAD_PATH"
+_GGUF_MAX_GPU_LAYERS = "SKLLM_CONFIG_GGUF_MAX_GPU_LAYERS"
+_GGUF_VERBOSE = "SKLLM_CONFIG_GGUF_VERBOSE"


 class SKLLMConfig:
@@ -169,4 +172,36 @@ def get_gpt_url() -> Optional[str]:
     @staticmethod
     def reset_gpt_url():
         """Resets the GPT URL."""
-        os.environ.pop(_GPT_URL_VAR, None)
+        os.environ.pop(_GPT_URL_VAR, None)
+
+    @staticmethod
+    def get_gguf_download_path() -> str:
+        """Gets the path to store the downloaded GGUF files."""
+        default_path = os.path.join(os.path.expanduser("~"), ".skllm", "gguf")
+        return os.environ.get(_GGUF_DOWNLOAD_PATH, default_path)
+
+    @staticmethod
+    def get_gguf_max_gpu_layers() -> int:
+        """Gets the maximum number of layers to use for the GGUF model."""
+        return int(os.environ.get(_GGUF_MAX_GPU_LAYERS, 0))
+
+    @staticmethod
+    def set_gguf_max_gpu_layers(n_layers: int):
+        """Sets the maximum number of layers to use for the GGUF model."""
+        if not isinstance(n_layers, int):
+            raise ValueError("n_layers must be an integer")
+        if n_layers < -1:
+            n_layers = -1
+        os.environ[_GGUF_MAX_GPU_LAYERS] = str(n_layers)
+
+    @staticmethod
+    def set_gguf_verbose(verbose: bool):
+        """Sets the verbosity of the GGUF model."""
+        if not isinstance(verbose, bool):
+            raise ValueError("verbose must be a boolean")
+        os.environ[_GGUF_VERBOSE] = str(verbose)
+
+    @staticmethod
+    def get_gguf_verbose() -> bool:
+        """Gets the verbosity of the GGUF model."""
+        return os.environ.get(_GGUF_VERBOSE, "False").lower() == "true"
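
For reference (not part of the commit), a minimal usage sketch of the new configuration helpers; the values shown are illustrative:

from skllm.config import SKLLMConfig

# Offload all layers to the GPU; llama-cpp-python treats -1 as "all layers",
# and set_gguf_max_gpu_layers clamps anything below -1 back to -1.
SKLLMConfig.set_gguf_max_gpu_layers(-1)

# Enable llama.cpp's own logging while debugging.
SKLLMConfig.set_gguf_verbose(True)

# GGUF files are cached under ~/.skllm/gguf unless the
# SKLLM_CONFIG_GGUF_DOWNLOAD_PATH environment variable overrides it.
print(SKLLMConfig.get_gguf_download_path())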

skllm/llm/gpt/clients/gpt4all/completion.py

Lines changed: 0 additions & 54 deletions
This file was deleted.
skllm/llm/gpt/clients/llama_cpp/completion.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+from skllm.llm.gpt.clients.llama_cpp.handler import ModelCache, LlamaHandler
+
+
+def get_chat_completion(messages: dict, model: str, **kwargs):
+
+    with ModelCache.lock:
+        handler = ModelCache.get(model)
+        if handler is None:
+            handler = LlamaHandler(model)
+            ModelCache.store(model, handler)
+
+    return handler.get_chat_completion(messages, **kwargs)
skllm/llm/gpt/clients/llama_cpp/handler.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+import threading
+import os
+import hashlib
+import requests
+from tqdm import tqdm
+import hashlib
+from typing import Optional
+import tempfile
+from skllm.config import SKLLMConfig
+from warnings import warn
+
+
+try:
+    from llama_cpp import Llama as _Llama
+
+    _llama_imported = True
+except (ImportError, ModuleNotFoundError):
+    _llama_imported = False
+
+
+supported_models = {
+    "llama3-8b-q4": {
+        "download_url": "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf",
+        "sha256": "c57380038ea85d8bec586ec2af9c91abc2f2b332d41d6cf180581d7bdffb93c1",
+        "n_ctx": 8192,
+        "supports_system_message": True,
+    },
+    "gemma2-9b-q4": {
+        "download_url": "https://huggingface.co/bartowski/gemma-2-9b-it-GGUF/resolve/main/gemma-2-9b-it-Q4_K_M.gguf",
+        "sha256": "13b2a7b4115bbd0900162edcebe476da1ba1fc24e718e8b40d32f6e300f56dfe",
+        "n_ctx": 8192,
+        "supports_system_message": False,
+    },
+    "phi3-mini-q4": {
+        "download_url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf",
+        "sha256": "8a83c7fb9049a9b2e92266fa7ad04933bb53aa1e85136b7b30f1b8000ff2edef",
+        "n_ctx": 4096,
+        "supports_system_message": False,
+    },
+    "mistral0.3-7b-q4": {
+        "download_url": "https://huggingface.co/lmstudio-community/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3-Q4_K_M.gguf",
+        "sha256": "1270d22c0fbb3d092fb725d4d96c457b7b687a5f5a715abe1e818da303e562b6",
+        "n_ctx": 32768,
+        "supports_system_message": False,
+    },
+    "gemma2-2b-q6": {
+        "download_url": "https://huggingface.co/bartowski/gemma-2-2b-it-GGUF/resolve/main/gemma-2-2b-it-Q6_K_L.gguf",
+        "sha256": "b2ef9f67b38c6e246e593cdb9739e34043d84549755a1057d402563a78ff2254",
+        "n_ctx": 8192,
+        "supports_system_message": False,
+    },
+}
+
+
+class LlamaHandler:
+
+    def maybe_download_model(self, model_name, download_url, sha256) -> str:
+        download_folder = SKLLMConfig.get_gguf_download_path()
+        os.makedirs(download_folder, exist_ok=True)
+        model_name = model_name + ".gguf"
+        model_path = os.path.join(download_folder, model_name)
+        if not os.path.exists(model_path):
+            print("The model `{0}` is not found locally.".format(model_name))
+            self._download_model(model_name, download_folder, download_url, sha256)
+        return model_path
+
+    def _download_model(
+        self, model_filename: str, model_path: str, url: str, expected_sha256: str
+    ) -> str:
+        full_path = os.path.join(model_path, model_filename)
+        temp_file = tempfile.NamedTemporaryFile(delete=False, dir=model_path)
+        temp_path = temp_file.name
+        temp_file.close()
+
+        response = requests.get(url, stream=True)
+
+        if response.status_code != 200:
+            os.remove(temp_path)
+            raise ValueError(
+                f"Request failed: HTTP {response.status_code} {response.reason}"
+            )
+
+        total_size_in_bytes = int(response.headers.get("content-length", 0))
+        block_size = 1024 * 1024 * 4
+
+        sha256 = hashlib.sha256()
+
+        with (
+            open(temp_path, "wb") as file,
+            tqdm(
+                desc="Downloading {0}: ".format(model_filename),
+                total=total_size_in_bytes,
+                unit="iB",
+                unit_scale=True,
+            ) as progress_bar,
+        ):
+            for data in response.iter_content(block_size):
+                file.write(data)
+                sha256.update(data)
+                progress_bar.update(len(data))
+
+        downloaded_sha256 = sha256.hexdigest()
+        if downloaded_sha256 != expected_sha256:
+            raise ValueError(
+                f"Expected SHA-256 hash {expected_sha256}, but got {downloaded_sha256}"
+            )
+
+        os.rename(temp_path, full_path)
+
+    def __init__(self, model: str):
+        if not _llama_imported:
+            raise ImportError(
+                "llama_cpp is not installed, try `pip install scikit-llm[llama_cpp]`"
+            )
+        self.lock = threading.Lock()
+        if model not in supported_models:
+            raise ValueError(f"Model {model} is not supported.")
+        download_url = supported_models[model]["download_url"]
+        sha256 = supported_models[model]["sha256"]
+        n_ctx = supported_models[model]["n_ctx"]
+        self.supports_system_message = supported_models[model][
+            "supports_system_message"
+        ]
+        if not self.supports_system_message:
+            warn(
+                f"The model {model} does not support system messages. This may cause issues with some estimators."
+            )
+        extended_model_name = model + "-" + sha256[:8]
+        model_path = self.maybe_download_model(
+            extended_model_name, download_url, sha256
+        )
+        max_gpu_layers = SKLLMConfig.get_gguf_max_gpu_layers()
+        verbose = SKLLMConfig.get_gguf_verbose()
+        self.model = _Llama(
+            model_path=model_path,
+            n_ctx=n_ctx,
+            verbose=verbose,
+            n_gpu_layers=max_gpu_layers,
+        )
+
+    def get_chat_completion(self, messages: dict, **kwargs):
+        if not self.supports_system_message:
+            messages = [m for m in messages if m["role"] != "system"]
+        with self.lock:
+            return self.model.create_chat_completion(
+                messages, temperature=0.0, **kwargs
+            )
+
+
+class ModelCache:
+    lock = threading.Lock()
+    cache: dict[str, LlamaHandler] = {}
+
+    @classmethod
+    def get(cls, key) -> Optional[LlamaHandler]:
+        return cls.cache.get(key, None)
+
+    @classmethod
+    def store(cls, key, value):
+        cls.cache[key] = value
+
+    @classmethod
+    def clear(cls):
+        cls.cache = {}
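
For context (not part of the commit), a minimal sketch of how the new client is exercised: get_chat_completion lazily downloads the GGUF file, verifies its SHA-256, caches one LlamaHandler per model in ModelCache, and returns an OpenAI-style completion dict from llama-cpp-python. The model name must be one of the supported_models keys.

from skllm.llm.gpt.clients.llama_cpp.completion import get_chat_completion

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Say hello."},
]

# The first call downloads and hash-checks the model, then caches the handler;
# subsequent calls for the same model reuse the cached LlamaHandler.
response = get_chat_completion(messages, model="llama3-8b-q4")
print(response["choices"][0]["message"]["content"])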

skllm/llm/gpt/completion.py

Lines changed: 11 additions & 6 deletions
@@ -2,12 +2,13 @@
 from skllm.llm.gpt.clients.openai.completion import (
     get_chat_completion as _oai_get_chat_completion,
 )
-from skllm.llm.gpt.clients.gpt4all.completion import (
-    get_chat_completion as _g4a_get_chat_completion,
+from skllm.llm.gpt.clients.llama_cpp.completion import (
+    get_chat_completion as _llamacpp_get_chat_completion,
 )
 from skllm.llm.gpt.utils import split_to_api_and_model
 from skllm.config import SKLLMConfig as _Config

+
 def get_chat_completion(
     messages: dict,
     openai_key: str = None,
@@ -17,14 +18,18 @@ def get_chat_completion(
 ):
     """Gets a chat completion from the OpenAI compatible API."""
     api, model = split_to_api_and_model(model)
-    if api == "gpt4all":
-        return _g4a_get_chat_completion(messages, model)
+    if api == "gguf":
+        return _llamacpp_get_chat_completion(messages, model)
     else:
         url = _Config.get_gpt_url()
         if api == "openai" and url is not None:
-            warnings.warn(f"You are using the OpenAI backend with a custom URL: {url}; did you mean to use the `custom_url` backend?\nTo use the OpenAI backend, please remove the custom URL using `SKLLMConfig.reset_gpt_url()`.")
+            warnings.warn(
+                f"You are using the OpenAI backend with a custom URL: {url}; did you mean to use the `custom_url` backend?\nTo use the OpenAI backend, please remove the custom URL using `SKLLMConfig.reset_gpt_url()`."
+            )
         elif api == "custom_url" and url is None:
-            raise ValueError("You are using the `custom_url` backend but no custom URL was provided. Please set it using `SKLLMConfig.set_gpt_url(<url>)`.")
+            raise ValueError(
+                "You are using the `custom_url` backend but no custom URL was provided. Please set it using `SKLLMConfig.set_gpt_url(<url>)`."
+            )
         return _oai_get_chat_completion(
             messages,
             openai_key,

skllm/llm/gpt/mixin.py

Lines changed: 6 additions & 0 deletions
@@ -81,6 +81,12 @@ def _get_openai_key(self) -> str:
         key = self.key
         if key is None:
             key = _Config.get_openai_key()
+        if (
+            hasattr(self, "model")
+            and isinstance(self.model, str)
+            and self.model.startswith("gguf::")
+        ):
+            key = "gguf"
         if key is None:
             raise RuntimeError("OpenAI key was not found")
         return key

skllm/llm/gpt/utils.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 from typing import Tuple

-SUPPORTED_APIS = ["openai", "azure", "gpt4all", "custom_url"]
+SUPPORTED_APIS = ["openai", "azure", "gguf", "custom_url"]


 def split_to_api_and_model(model: str) -> Tuple[str, str]:
@@ -9,4 +9,4 @@ def split_to_api_and_model(model: str) -> Tuple[str, str]:
     for api in SUPPORTED_APIS:
         if model.startswith(f"{api}::"):
             return api, model[len(api) + 2 :]
-    raise ValueError(f"Unsupported API: {model.split('::')[0]}")
+    raise ValueError(f"Unsupported API: {model.split('::')[0]}")

skllm/models/_base/classifier.py

Lines changed: 8 additions & 4 deletions
@@ -239,10 +239,14 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int =
             warnings.warn(
                 "Passing num_workers to predict is temporary and will be removed in the future."
             )
-        with ThreadPoolExecutor(max_workers=num_workers) as executor:
-            predictions = list(
-                tqdm(executor.map(self._predict_single, X), total=len(X))
-            )
+            with ThreadPoolExecutor(max_workers=num_workers) as executor:
+                predictions = list(
+                    tqdm(executor.map(self._predict_single, X), total=len(X))
+                )
+        else:
+            predictions = []
+            for x in tqdm(X):
+                predictions.append(self._predict_single(x))

         return np.array(predictions)
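
For context (not part of the commit): predict now takes the threaded path only in the branch that warns about num_workers, and otherwise falls back to a plain sequential tqdm loop. A short sketch, reusing clf and X from the previous sketch:

preds = clf.predict(X)                  # sequential loop with a progress bar
preds = clf.predict(X, num_workers=4)   # thread pool; warns that num_workers is temporary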
