Maharshi Gor committed · Commit 3a1af80 · Parent(s): 4f5d1cb

Adds support for caching LLM calls to a SQLite DB and an HF dataset. Refactors repo creation logic and fixes the unused temperature param.

Browse files:
- check_repos.py (+14 -15)
- src/envs.py (+3 -0)
- src/workflows/executors.py (+1 -0)
- src/workflows/llmcache.py (+479 -0)
- src/workflows/llms.py (+123 -16)
check_repos.py  CHANGED
@@ -1,26 +1,25 @@
 from huggingface_hub import HfApi
 
-from src.envs import QUEUE_REPO, RESULTS_REPO, TOKEN
+from src.envs import LLM_CACHE_REPO, QUEUE_REPO, RESULTS_REPO, TOKEN
 
 
-def
+def check_and_create_dataset_repo(repo_id: str):
     api = HfApi(token=TOKEN)
-
-    # Check and create queue repo
     try:
-        api.repo_info(repo_id=
-        print(f"
+        api.repo_info(repo_id=repo_id, repo_type="dataset")
+        print(f"{repo_id} exists")
     except Exception:
-        print(f"Creating
-        api.create_repo(repo_id=
+        print(f"Creating {repo_id}")
+        api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)
 
-
-
-
-
-
-
-
+
+def check_and_create_repos():
+    print("1. QUEUE Repository")
+    check_and_create_dataset_repo(QUEUE_REPO)
+    print("2. RESULTS Repository")
+    check_and_create_dataset_repo(RESULTS_REPO)
+    print("3. LLM Cache Repository")
+    check_and_create_dataset_repo(LLM_CACHE_REPO)
 
 
 if __name__ == "__main__":
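The body of the `__main__` block is cut off in this view; judging from the refactor, it presumably just calls the new entry point. A minimal, hypothetical usage sketch:

# Sketch only: the __main__ body is not visible in the diff, so this call is an assumption.
from check_repos import check_and_create_repos  # hypothetical import path when run from the repo root

check_and_create_repos()  # checks/creates advcal-requests, model-results, and advcal-llm-cache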
src/envs.py  CHANGED
@@ -15,6 +15,7 @@ OWNER = "umdclip"
 REPO_ID = f"{OWNER}/quizbowl-submission"
 QUEUE_REPO = f"{OWNER}/advcal-requests"
 RESULTS_REPO = f"{OWNER}/model-results"  # TODO: change to advcal-results after testing is done
+LLM_CACHE_REPO = f"{OWNER}/advcal-llm-cache"
 
 EXAMPLES_PATH = "examples"
 
@@ -29,12 +30,14 @@ PLAYGROUND_DATASET_NAMES = {
 CACHE_PATH = os.getenv("HF_HOME", ".")
 
 # Local caches
+LLM_CACHE_PATH = os.path.join(CACHE_PATH, "llm-cache")
 EVAL_REQUESTS_PATH = os.path.join(CACHE_PATH, "eval-queue")
 EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval-results")
 EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
 EVAL_RESULTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-results-bk")
 
 
+LLM_CACHE_REFRESH_INTERVAL = 600  # seconds (10 minutes)
 SERVER_REFRESH_INTERVAL = 86400  # seconds (one day)
 LEADERBOARD_REFRESH_INTERVAL = 600  # seconds (10 minutes)
 
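Note that this commit's llms.py still hardcodes cache_dir="." and the repo id when it builds the global cache; a sketch of how the three new constants could be wired in instead (an assumption, not part of the commit):

# Sketch only: wiring the new env constants into the cache initialization.
from src.envs import LLM_CACHE_PATH, LLM_CACHE_REFRESH_INTERVAL, LLM_CACHE_REPO
from src.workflows.llmcache import LLMCache

llm_cache = LLMCache(
    cache_dir=LLM_CACHE_PATH,                        # local SQLite DB + HF snapshot directory
    hf_repo=LLM_CACHE_REPO,                          # umdclip/advcal-llm-cache
    cache_sync_interval=LLM_CACHE_REFRESH_INTERVAL,  # push to the HF dataset at most every 600 s
)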
src/workflows/executors.py  CHANGED
@@ -221,6 +221,7 @@ def execute_model_step(
         system=model_step.system_prompt,
         prompt=step_result,
         response_format=ModelResponse,
+        temperature=model_step.temperature,
         logprobs=logprobs,
     )
 
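This one-line fix only takes effect together with the new temperature parameter added to the completion helpers in src/workflows/llms.py below. The surrounding call is outside the hunk, so the callee name and the model argument in this reconstruction are assumptions:

# Illustrative only: presumably the cache-aware completion() helper; names outside the hunk are assumed.
response = completion(
    model_step.model,                    # assumed "provider/model" identifier on the step
    system=model_step.system_prompt,
    prompt=step_result,
    response_format=ModelResponse,
    temperature=model_step.temperature,  # previously dropped, now forwarded to the provider
    logprobs=logprobs,
)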
src/workflows/llmcache.py  ADDED
@@ -0,0 +1,479 @@
import hashlib
import json
import os
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Optional

from datasets import Dataset, Features, Value
from huggingface_hub import snapshot_download
from loguru import logger


def load_dataset_from_hf(repo_id, local_dir):
    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
        token=os.environ["HF_TOKEN"],
    )


class CacheDB:
    """Handles database operations for storing and retrieving cache entries."""

    def __init__(self, db_path: Path):
        """Initialize database connection.

        Args:
            db_path: Path to SQLite database file
        """
        self.db_path = db_path
        self.lock = threading.Lock()

        # Initialize the database
        try:
            self.initialize_db()
        except Exception as e:
            logger.exception(f"Failed to initialize database: {e}")
            logger.warning(f"Please provide a different filepath or remove the file at {self.db_path}")
            raise

    def initialize_db(self) -> None:
        """Initialize SQLite database with the required table."""
        # Check if database file already exists
        if self.db_path.exists():
            self._verify_existing_db()
        else:
            self._create_new_db()

    def _verify_existing_db(self) -> None:
        """Verify and repair an existing database if needed."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                self._ensure_table_exists(cursor)
                self._verify_schema(cursor)
                self._ensure_index_exists(cursor)
                conn.commit()
            logger.info(f"Using existing SQLite database at {self.db_path}")
        except Exception as e:
            logger.exception(f"Database corruption detected: {e}")
            raise ValueError(f"Corrupted database at {self.db_path}: {str(e)}")

    def _create_new_db(self) -> None:
        """Create a new database with the required schema."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                self._create_table(cursor)
                self._ensure_index_exists(cursor)
                conn.commit()
            logger.info(f"Initialized new SQLite database at {self.db_path}")
        except Exception as e:
            logger.exception(f"Failed to initialize SQLite database: {e}")
            raise

    def _ensure_table_exists(self, cursor) -> None:
        """Check if the llm_cache table exists and create it if not."""
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='llm_cache'")
        if not cursor.fetchone():
            self._create_table(cursor)
            logger.info("Created missing llm_cache table")

    def _create_table(self, cursor) -> None:
        """Create the llm_cache table with the required schema."""
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS llm_cache (
                key TEXT PRIMARY KEY,
                request_json TEXT,
                response_json TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

    def _verify_schema(self, cursor) -> None:
        """Verify that the table schema has all required columns."""
        cursor.execute("PRAGMA table_info(llm_cache)")
        columns = {row[1] for row in cursor.fetchall()}
        required_columns = {"key", "request_json", "response_json", "created_at"}

        if not required_columns.issubset(columns):
            missing = required_columns - columns
            raise ValueError(f"Database schema is corrupted. Missing columns: {missing}")

    def _ensure_index_exists(self, cursor) -> None:
        """Create an index on the key column for faster lookups."""
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_llm_cache_key ON llm_cache (key)")

    def get(self, key: str) -> Optional[dict[str, Any]]:
        """Get cached entry by key.

        Args:
            key: Cache key to look up

        Returns:
            Dict containing the request and response or None if not found
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("SELECT request_json, response_json FROM llm_cache WHERE key = ?", (key,))
                result = cursor.fetchone()

                if result:
                    logger.debug(f"Cache hit for key: {key}. Response: {result['response_json']}")
                    return {
                        "request": result["request_json"],
                        "response": result["response_json"],
                    }

                logger.debug(f"Cache miss for key: {key}")
                return None
        except Exception as e:
            logger.error(f"Error retrieving from cache: {e}")
            return None

    def set(self, key: str, request_json: str, response_json: str) -> bool:
        """Set entry in cache.

        Args:
            key: Cache key
            request_json: JSON string of request parameters
            response_json: JSON string of response

        Returns:
            True if successful, False otherwise
        """
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.execute(
                        "INSERT OR REPLACE INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
                        (key, request_json, response_json),
                    )
                    conn.commit()
                    logger.debug(f"Saved response to cache with key: {key}, response: {response_json}")
                    return True
            except Exception as e:
                logger.error(f"Failed to save to SQLite cache: {e}")
                return False

    def get_all_entries(self) -> dict[str, dict[str, Any]]:
        """Get all cache entries from the database."""
        cache = {}
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("SELECT key, request_json, response_json FROM llm_cache ORDER BY created_at")

                for row in cursor.fetchall():
                    cache[row["key"]] = {
                        "request": row["request_json"],
                        "response": row["response_json"],
                    }

            logger.debug(f"Retrieved {len(cache)} entries from cache database")
            return cache
        except Exception as e:
            logger.error(f"Error retrieving all cache entries: {e}")
            return {}

    def clear(self) -> bool:
        """Clear all cache entries.

        Returns:
            True if successful, False otherwise
        """
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.execute("DELETE FROM llm_cache")
                    conn.commit()
                    logger.info("Cache cleared")
                    return True
            except Exception as e:
                logger.error(f"Failed to clear cache: {e}")
                return False

    def get_existing_keys(self) -> set:
        """Get all existing keys in the database.

        Returns:
            Set of keys
        """
        existing_keys = set()
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT key FROM llm_cache")
                for row in cursor.fetchall():
                    existing_keys.add(row[0])
            return existing_keys
        except Exception as e:
            logger.error(f"Error retrieving existing keys: {e}")
            return set()

    def bulk_insert(self, items: list, update: bool = False) -> int:
        """Insert multiple items into the cache.

        Args:
            items: List of (key, request_json, response_json) tuples
            update: Whether to update existing entries

        Returns:
            Number of items inserted
        """
        count = 0
        UPDATE_OR_IGNORE = "INSERT OR REPLACE" if update else "INSERT OR IGNORE"
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.executemany(
                        f"{UPDATE_OR_IGNORE} INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
                        items,
                    )
                    count = cursor.rowcount
                    conn.commit()
                return count
            except Exception as e:
                logger.error(f"Error during bulk insert: {e}")
                return 0


class LLMCache:
    def __init__(
        self, cache_dir: str = ".", hf_repo: str | None = None, cache_sync_interval: int = 3600, reset: bool = False
    ):
        self.cache_dir = Path(cache_dir)
        self.db_path = self.cache_dir / "llm_cache.db"
        self.hf_repo_id = hf_repo
        self.cache_sync_interval = cache_sync_interval
        self.last_sync_time = time.time()

        # Create cache directory if it doesn't exist
        self.cache_dir.mkdir(exist_ok=True, parents=True)

        # Initialize CacheDB
        self.db = CacheDB(self.db_path)
        if reset:
            self.db.clear()

        # Try to load from HF dataset if available
        try:
            self._load_cache_from_hf()
        except Exception as e:
            logger.warning(f"Failed to load cache from HF dataset: {e}")

    def response_format_to_dict(self, response_format: Any) -> dict[str, Any]:
        """Convert a response format to a dict."""
        # If it's a Pydantic model, use its schema
        if hasattr(response_format, "model_json_schema"):
            response_format = response_format.model_json_schema()

        # If it's a Pydantic model, use its dump
        elif hasattr(response_format, "model_dump"):
            response_format = response_format.model_dump()

        if not isinstance(response_format, dict):
            response_format = {"value": str(response_format)}

        return response_format

    def _generate_key(
        self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None = None
    ) -> str:
        """Generate a unique key for caching based on inputs."""
        response_format_dict = self.response_format_to_dict(response_format)
        response_format_str = json.dumps(response_format_dict, sort_keys=True)
        # Include temperature in the key
        key_content = f"{model}:{system}:{prompt}:{response_format_str}"
        if temperature is not None:
            key_content += f":{temperature:.2f}"
        return hashlib.md5(key_content.encode()).hexdigest()

    def _create_request_json(
        self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None
    ) -> str:
        """Create JSON string from request parameters."""
        logger.info(f"Creating request JSON with temperature: {temperature}")
        request_data = {
            "model": model,
            "system": system,
            "prompt": prompt,
            "response_format": self.response_format_to_dict(response_format),
            "temperature": temperature,
        }
        return json.dumps(request_data)

    def _check_request_match(
        self,
        cached_request: dict[str, Any],
        model: str,
        system: str,
        prompt: str,
        response_format: Any,
        temperature: float | None,
    ) -> bool:
        """Check if the cached request matches the new request."""
        # Check each field and log any mismatches
        if cached_request["model"] != model:
            logger.debug(f"Cache mismatch: model - cached: {cached_request['model']}, new: {model}")
            return False
        if cached_request["system"] != system:
            logger.debug(f"Cache mismatch: system - cached: {cached_request['system']}, new: {system}")
            return False
        if cached_request["prompt"] != prompt:
            logger.debug(f"Cache mismatch: prompt - cached: {cached_request['prompt']}, new: {prompt}")
            return False
        response_format_dict = self.response_format_to_dict(response_format)
        if cached_request["response_format"] != response_format_dict:
            logger.debug(
                f"Cache mismatch: response_format - cached: {cached_request['response_format']}, new: {response_format_dict}"
            )
            return False
        if cached_request["temperature"] != temperature:
            logger.debug(f"Cache mismatch: temperature - cached: {cached_request['temperature']}, new: {temperature}")
            return False

        return True

    def get(
        self, model: str, system: str, prompt: str, response_format: dict[str, Any], temperature: float | None = None
    ) -> Optional[dict[str, Any]]:
        """Get cached response if it exists."""
        key = self._generate_key(model, system, prompt, response_format, temperature)
        result = self.db.get(key)

        if not result:
            return None
        request_dict = json.loads(result["request"])
        if not self._check_request_match(request_dict, model, system, prompt, response_format, temperature):
            logger.warning(f"Cached request does not match new request for key: {key}")
            return None

        return json.loads(result["response"])

    def set(
        self,
        model: str,
        system: str,
        prompt: str,
        response_format: dict[str, Any],
        temperature: float | None,
        response: dict[str, Any],
    ) -> None:
        """Set response in cache and sync if needed."""
        key = self._generate_key(model, system, prompt, response_format, temperature)
        request_json = self._create_request_json(model, system, prompt, response_format, temperature)
        response_json = json.dumps(response)

        success = self.db.set(key, request_json, response_json)

        # Check if we should sync to HF
        if success and self.hf_repo_id and (time.time() - self.last_sync_time > self.cache_sync_interval):
            try:
                self.sync_to_hf()
                self.last_sync_time = time.time()
            except Exception as e:
                logger.error(f"Failed to sync cache to HF dataset: {e}")

    def _load_cache_from_hf(self) -> None:
        """Load cache from HF dataset if it exists and merge with local cache."""
        if not self.hf_repo_id:
            return

        try:
            # Check for new commits before loading the dataset
            dataset = load_dataset_from_hf(self.hf_repo_id, self.cache_dir / "hf_cache")
            if dataset:
                existing_keys = self.db.get_existing_keys()

                # Prepare batch items for insertion
                items_to_insert = []
                for item in dataset:
                    key = item["key"]
                    # Only update if not in local cache to prioritize local changes
                    if key in existing_keys:
                        continue
                    # Create request JSON
                    request_data = {
                        "model": item["model"],
                        "system": item["system"],
                        "prompt": item["prompt"],
                        "temperature": item["temperature"],
                        "response_format": None,  # We can't fully reconstruct this
                    }

                    items_to_insert.append(
                        (
                            key,
                            json.dumps(request_data),
                            item["response"],  # This is already a JSON string
                        )
                    )
                    logger.info(
                        f"Inserting item: {key} with temperature: {item['temperature']} and response: {item['response']}"
                    )

                # Bulk insert new items
                if items_to_insert:
                    inserted_count = self.db.bulk_insert(items_to_insert)
                    logger.info(f"Merged {inserted_count} items from HF dataset into SQLite cache")
                else:
                    logger.info("No new items to merge from HF dataset")
        except Exception as e:
            logger.warning(f"Could not load cache from HF dataset: {e}")

    def get_all_entries(self) -> dict[str, dict[str, Any]]:
        """Get all cache entries from the database."""
        cache = self.db.get_all_entries()
        entries = {}
        for key, entry in cache.items():
            request = json.loads(entry["request"])
            response = json.loads(entry["response"])
            entries[key] = {"request": request, "response": response}
        return entries

    def sync_to_hf(self) -> None:
        """Sync cache to HF dataset."""
        if not self.hf_repo_id:
            return

        # Get all entries from the database
        cache = self.db.get_all_entries()

        # Convert cache to dataset format
        entries = []
        for key, entry in cache.items():
            request = json.loads(entry["request"])
            response_str = entry["response"]
            entries.append(
                {
                    "key": key,
                    "model": request["model"],
                    "system": request["system"],
                    "prompt": request["prompt"],
                    "response_format": request["response_format"],
                    "temperature": request["temperature"],
                    "response": response_str,
                }
            )

        # Create and push dataset
        dataset = Dataset.from_list(entries)
        dataset.push_to_hub(self.hf_repo_id, private=True)
        logger.info(f"Synced {len(cache)} cached items to HF dataset {self.hf_repo_id}")

    def clear(self) -> None:
        """Clear all cache entries."""
        self.db.clear()
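Cache keys are MD5 hashes of model:system:prompt:<sorted response-format JSON>, with the temperature appended when it is set, so any change to the prompt, the schema, or the temperature yields a separate entry. A small standalone sketch of using the class directly, with no HF repo configured (purely local SQLite, nothing is downloaded or pushed):

# Standalone usage sketch for LLMCache; the model id and paths are placeholders.
from pydantic import BaseModel

from src.workflows.llmcache import LLMCache


class Answer(BaseModel):
    text: str


cache = LLMCache(cache_dir="/tmp/llm-cache")  # hf_repo=None: local-only cache
cache.set(
    model="OpenAI/gpt-4o-mini",               # placeholder model id
    system="You are terse.",
    prompt="Capital of France?",
    response_format=Answer,
    temperature=0.0,
    response={"output": {"text": "Paris"}},
)
# Same model/system/prompt/schema/temperature -> same key -> cache hit.
assert cache.get("OpenAI/gpt-4o-mini", "You are terse.", "Capital of France?", Answer, 0.0) is not None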
src/workflows/llms.py  CHANGED
@@ -1,12 +1,14 @@
 # %%
+
 import json
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import cohere
 import numpy as np
 from langchain_anthropic import ChatAnthropic
 from langchain_cohere import ChatCohere
+from langchain_core.language_models import BaseChatModel
 from langchain_openai import ChatOpenAI
 from loguru import logger
 from openai import OpenAI
@@ -14,6 +16,10 @@ from pydantic import BaseModel, Field
 from rich import print as rprint
 
 from .configs import AVAILABLE_MODELS
+from .llmcache import LLMCache
+
+# Initialize global cache
+llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache")
 
 
 def _openai_is_json_mode_supported(model_name: str) -> bool:
@@ -30,7 +36,7 @@ class LLMOutput(BaseModel):
     logprob: Optional[float] = Field(None, description="The log probability of the response")
 
 
-def _get_langchain_chat_output(llm, system: str, prompt: str) -> str:
+def _get_langchain_chat_output(llm: BaseChatModel, system: str, prompt: str) -> str:
     output = llm.invoke([("system", system), ("human", prompt)])
     ai_message = output["raw"]
     content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
@@ -38,7 +44,9 @@ def _get_langchain_chat_output(llm, system: str, prompt: str) -> str:
     return {"content": content_str, "output": output["parsed"].model_dump()}
 
 
-def _cohere_completion(
+def _cohere_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
+) -> str:
     messages = [
         {"role": "system", "content": system},
         {"role": "user", "content": prompt},
@@ -49,6 +57,7 @@ def _cohere_completion(model: str, system: str, prompt: str, response_model, log
         messages=messages,
         response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
         logprobs=logprobs,
+        temperature=temperature,
     )
     output = {}
     output["content"] = response.message.content[0].text
@@ -59,12 +68,16 @@
     return output
 
 
-def _openai_langchain_completion(
-
+def _openai_langchain_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None
+) -> str:
+    llm = ChatOpenAI(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
     return _get_langchain_chat_output(llm, system, prompt)
 
 
-def _openai_completion(
+def _openai_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
+) -> str:
     messages = [
         {"role": "system", "content": system},
         {"role": "user", "content": prompt},
@@ -75,6 +88,7 @@ def _openai_completion(model: str, system: str, prompt: str, response_model, log
         messages=messages,
         response_format=response_model,
         logprobs=logprobs,
+        temperature=temperature,
     )
     output = {}
     output["content"] = response.choices[0].message.content
@@ -85,14 +99,18 @@ def _openai_completion(model: str, system: str, prompt: str, response_model, log
     return output
 
 
-def _anthropic_completion(
-
+def _anthropic_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None
+) -> str:
+    llm = ChatAnthropic(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)
 
 
-def
+def _llm_completion(
+    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
+) -> dict[str, Any]:
     """
-    Generate a completion from an LLM provider with structured output.
+    Generate a completion from an LLM provider with structured output without caching.
 
     Args:
         model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
@@ -116,20 +134,69 @@ def completion(model: str, system: str, prompt: str, response_format, logprobs:
     model_name = AVAILABLE_MODELS[model]["model"]
     provider = model.split("/")[0]
     if provider == "Cohere":
-        return _cohere_completion(model_name, system, prompt, response_format, logprobs)
+        return _cohere_completion(model_name, system, prompt, response_format, temperature, logprobs)
     elif provider == "OpenAI":
         if _openai_is_json_mode_supported(model_name):
-            return _openai_completion(model_name, system, prompt, response_format, logprobs)
+            return _openai_completion(model_name, system, prompt, response_format, temperature, logprobs)
+        elif logprobs:
+            raise ValueError(f"{model} does not support logprobs feature.")
         else:
-            return _openai_langchain_completion(model_name, system, prompt, response_format,
+            return _openai_langchain_completion(model_name, system, prompt, response_format, temperature)
     elif provider == "Anthropic":
         if logprobs:
-            raise ValueError("Anthropic
-        return _anthropic_completion(model_name, system, prompt, response_format)
+            raise ValueError("Anthropic models do not support logprobs")
+        return _anthropic_completion(model_name, system, prompt, response_format, temperature)
     else:
         raise ValueError(f"Provider {provider} not supported")
 
 
+def completion(
+    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
+) -> dict[str, Any]:
+    """
+    Generate a completion from an LLM provider with structured output with caching.
+
+    Args:
+        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
+        system (str): System prompt/instructions for the model
+        prompt (str): User prompt/input
+        response_format: Pydantic model defining the expected response structure
+        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
+            Note: Not supported by Anthropic models.
+
+    Returns:
+        dict: Contains:
+            - output: The structured response matching response_format
+            - logprob: (optional) Sum of log probabilities if logprobs=True
+            - prob: (optional) Exponential of logprob if logprobs=True
+
+    Raises:
+        ValueError: If logprobs=True with Anthropic models
+    """
+    # Check cache first
+    cached_response = llm_cache.get(model, system, prompt, response_format, temperature)
+    if cached_response is not None:
+        logger.info(f"Cache hit for model {model}")
+        return cached_response
+
+    logger.info(f"Cache miss for model {model}, calling API")
+
+    # Continue with the original implementation for cache miss
+    response = _llm_completion(model, system, prompt, response_format, temperature, logprobs)
+
+    # Update cache with the new response
+    llm_cache.set(
+        model,
+        system,
+        prompt,
+        response_format,
+        temperature,
+        response,
+    )
+
+    return response
+
+
 # %%
 if __name__ == "__main__":
     from tqdm import tqdm
@@ -142,12 +209,52 @@ if __name__ == "__main__":
         answer: str = Field(description="The short answer to the question")
         explanation: str = Field(description="5 words terse best explanation of the answer.")
 
-    models = AVAILABLE_MODELS.keys()
+    models = list(AVAILABLE_MODELS.keys())[:1]  # Just use the first model for testing
     system = "You are an accurate and concise explainer of scientific concepts."
     prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."
 
+    llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache", reset=True)
+
+    # First call - should be a cache miss
+    logger.info("First call - should be a cache miss")
+    for model in tqdm(models):
+        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
+        rprint(response)
+
+    # Second call - should be a cache hit
+    logger.info("Second call - should be a cache hit")
     for model in tqdm(models):
         response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
         rprint(response)
 
+    # Slightly different prompt - should be a cache miss
+    logger.info("Different prompt - should be a cache miss")
+    prompt2 = "Which planet is closest to the sun? Answer directly."
+    for model in tqdm(models):
+        response = completion(model, system, prompt2, ExplainedAnswer, logprobs=False)
+        rprint(response)
+
+    # Get cache entries count from SQLite
+    try:
+        cache_entries = llm_cache.get_all_entries()
+        logger.info(f"Cache now has {len(cache_entries)} items")
+    except Exception as e:
+        logger.error(f"Failed to get cache entries: {e}")
+
+    # Test adding entry with temperature parameter
+    logger.info("Testing with temperature parameter")
+    response = completion(models[0], system, "What is Mars?", ExplainedAnswer, temperature=0.7, logprobs=False)
+    rprint(response)
+
+    # Demonstrate forced sync to HF if repo is configured
+    if llm_cache.hf_repo_id:
+        logger.info("Forcing sync to HF dataset")
+        try:
+            llm_cache.sync_to_hf()
+            logger.info("Successfully synced to HF dataset")
+        except Exception as e:
+            logger.exception(f"Failed to sync to HF: {e}")
+    else:
+        logger.info("HF repo not configured, skipping sync test")
+
 # %%
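From the caller's side the only visible change is the optional temperature argument; repeated calls with identical arguments are now served from the SQLite cache and periodically mirrored to the HF dataset. A minimal caller-side sketch (the model id is a placeholder and must be a key of AVAILABLE_MODELS):

# Caller-side sketch; "Cohere/command-r" is a placeholder model id.
from pydantic import BaseModel, Field

from src.workflows.llms import completion


class ShortAnswer(BaseModel):
    answer: str = Field(description="The short answer to the question")


# First call reaches the provider and is written to the cache;
# the second identical call is answered from SQLite without an API call.
first = completion("Cohere/command-r", "Be terse.", "Closest planet to the sun?", ShortAnswer, temperature=0.0)
second = completion("Cohere/command-r", "Be terse.", "Closest planet to the sun?", ShortAnswer, temperature=0.0)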