import threading
import os
import hashlib
import requests
from tqdm import tqdm
from typing import Optional
import tempfile
from skllm.config import SKLLMConfig
from warnings import warn


# llama_cpp is an optional dependency; a missing install is reported lazily in
# LlamaHandler.__init__ rather than at import time.
try:
    from llama_cpp import Llama as _Llama

    _llama_imported = True
except (ImportError, ModuleNotFoundError):
    _llama_imported = False


# Registry of supported GGUF models: where to download each file, its expected
# SHA-256 checksum, the context window to use, and whether the model's chat
# template accepts a system message.
supported_models = {
    "llama3-8b-q4": {
        "download_url": "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf",
        "sha256": "c57380038ea85d8bec586ec2af9c91abc2f2b332d41d6cf180581d7bdffb93c1",
        "n_ctx": 8192,
        "supports_system_message": True,
    },
    "gemma2-9b-q4": {
        "download_url": "https://huggingface.co/bartowski/gemma-2-9b-it-GGUF/resolve/main/gemma-2-9b-it-Q4_K_M.gguf",
        "sha256": "13b2a7b4115bbd0900162edcebe476da1ba1fc24e718e8b40d32f6e300f56dfe",
        "n_ctx": 8192,
        "supports_system_message": False,
    },
    "phi3-mini-q4": {
        "download_url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf",
        "sha256": "8a83c7fb9049a9b2e92266fa7ad04933bb53aa1e85136b7b30f1b8000ff2edef",
        "n_ctx": 4096,
        "supports_system_message": False,
    },
    "mistral0.3-7b-q4": {
        "download_url": "https://huggingface.co/lmstudio-community/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3-Q4_K_M.gguf",
        "sha256": "1270d22c0fbb3d092fb725d4d96c457b7b687a5f5a715abe1e818da303e562b6",
        "n_ctx": 32768,
        "supports_system_message": False,
    },
    "gemma2-2b-q6": {
        "download_url": "https://huggingface.co/bartowski/gemma-2-2b-it-GGUF/resolve/main/gemma-2-2b-it-Q6_K_L.gguf",
        "sha256": "b2ef9f67b38c6e246e593cdb9739e34043d84549755a1057d402563a78ff2254",
        "n_ctx": 8192,
        "supports_system_message": False,
    },
}
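
# The SHA-256 values above pin the exact GGUF files published on Hugging Face;
# a checksum mismatch after download raises a ValueError in
# LlamaHandler._download_model instead of keeping the corrupted file.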


class LlamaHandler:
    """Thin wrapper around ``llama_cpp.Llama`` that downloads, verifies and loads
    a supported GGUF model on first use."""

    def maybe_download_model(
        self, model_name: str, download_url: str, sha256: str
    ) -> str:
        """Return the local path of the model, downloading it first if it is not present."""
        download_folder = SKLLMConfig.get_gguf_download_path()
        os.makedirs(download_folder, exist_ok=True)
        model_name = model_name + ".gguf"
        model_path = os.path.join(download_folder, model_name)
        if not os.path.exists(model_path):
            print("The model `{0}` is not found locally.".format(model_name))
            self._download_model(model_name, download_folder, download_url, sha256)
        return model_path

    def _download_model(
        self, model_filename: str, model_path: str, url: str, expected_sha256: str
    ) -> None:
        """Stream the model into a temporary file, verify its SHA-256 checksum,
        and move it into place only if the checksum matches."""
        full_path = os.path.join(model_path, model_filename)
        temp_file = tempfile.NamedTemporaryFile(delete=False, dir=model_path)
        temp_path = temp_file.name
        temp_file.close()

        response = requests.get(url, stream=True)

        if response.status_code != 200:
            os.remove(temp_path)
            raise ValueError(
                f"Request failed: HTTP {response.status_code} {response.reason}"
            )

        total_size_in_bytes = int(response.headers.get("content-length", 0))
        block_size = 1024 * 1024 * 4  # 4 MiB chunks

        sha256 = hashlib.sha256()

        with (
            open(temp_path, "wb") as file,
            tqdm(
                desc="Downloading {0}: ".format(model_filename),
                total=total_size_in_bytes,
                unit="iB",
                unit_scale=True,
            ) as progress_bar,
        ):
            for data in response.iter_content(block_size):
                file.write(data)
                sha256.update(data)
                progress_bar.update(len(data))

        downloaded_sha256 = sha256.hexdigest()
        if downloaded_sha256 != expected_sha256:
            os.remove(temp_path)
            raise ValueError(
                f"Expected SHA-256 hash {expected_sha256}, but got {downloaded_sha256}"
            )

        os.rename(temp_path, full_path)

    def __init__(self, model: str):
        if not _llama_imported:
            raise ImportError(
                "llama_cpp is not installed, try `pip install scikit-llm[llama_cpp]`"
            )
        self.lock = threading.Lock()
        if model not in supported_models:
            raise ValueError(f"Model {model} is not supported.")
        download_url = supported_models[model]["download_url"]
        sha256 = supported_models[model]["sha256"]
        n_ctx = supported_models[model]["n_ctx"]
        self.supports_system_message = supported_models[model][
            "supports_system_message"
        ]
        if not self.supports_system_message:
            warn(
                f"The model {model} does not support system messages. This may cause issues with some estimators."
            )
        # The checksum prefix keeps differently-pinned files from colliding on disk.
        extended_model_name = model + "-" + sha256[:8]
        model_path = self.maybe_download_model(
            extended_model_name, download_url, sha256
        )
        max_gpu_layers = SKLLMConfig.get_gguf_max_gpu_layers()
        verbose = SKLLMConfig.get_gguf_verbose()
        self.model = _Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            verbose=verbose,
            n_gpu_layers=max_gpu_layers,
        )

    def get_chat_completion(self, messages: list, **kwargs):
        """Run a deterministic (temperature 0) chat completion, dropping system
        messages for models whose chat template does not support them."""
        if not self.supports_system_message:
            messages = [m for m in messages if m["role"] != "system"]
        with self.lock:
            return self.model.create_chat_completion(
                messages, temperature=0.0, **kwargs
            )


class ModelCache:
    """Process-wide store of loaded ``LlamaHandler`` instances keyed by model name.

    Cache accesses are not synchronized internally; ``lock`` is provided for
    callers that need to guard check-and-store sequences across threads.
    """

    lock = threading.Lock()
    cache: dict[str, LlamaHandler] = {}

    @classmethod
    def get(cls, key: str) -> Optional[LlamaHandler]:
        return cls.cache.get(key, None)

    @classmethod
    def store(cls, key: str, value: LlamaHandler):
        cls.cache[key] = value

    @classmethod
    def clear(cls):
        cls.cache = {}
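

# A minimal usage sketch, not part of the library API: it assumes llama_cpp is
# installed and that downloading the selected GGUF weights on first run is
# acceptable. The response is indexed using the OpenAI-style structure that
# llama-cpp-python returns from create_chat_completion.
if __name__ == "__main__":
    model_name = "gemma2-2b-q6"
    # Reuse a cached handler if one exists; otherwise load the model and cache it.
    with ModelCache.lock:
        handler = ModelCache.get(model_name)
        if handler is None:
            handler = LlamaHandler(model_name)
            ModelCache.store(model_name, handler)
    completion = handler.get_chat_completion(
        [{"role": "user", "content": "Reply with a single word: hello."}]
    )
    print(completion["choices"][0]["message"]["content"])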