Maharshi Gor committed
Commit 3a1af80 · 1 Parent(s): 4f5d1cb

Adds support for caching LLM calls to a SQLite DB and an HF dataset. Refactors the repo creation logic and fixes the unused temperature parameter.

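For reference, a minimal usage sketch of the new cached completion path, assuming the package is importable as src.workflows, the model key below exists in AVAILABLE_MODELS, and HF_TOKEN is set so the cache can sync to the HF dataset (the model name and prompts are illustrative only):

from pydantic import BaseModel, Field

from src.workflows.llms import completion


class Answer(BaseModel):
    answer: str = Field(description="The short answer to the question")


# The first call goes to the provider and is written to the local SQLite cache
# (llm_cache.db); the identical second call is served from that cache. The cache
# is pushed to the configured HF dataset repo at the sync interval.
first = completion("OpenAI/gpt-4", "Be terse.", "Which planet is closest to the sun?", Answer, temperature=0.0)
second = completion("OpenAI/gpt-4", "Be terse.", "Which planet is closest to the sun?", Answer, temperature=0.0)

The diffs below add the LLM_CACHE_* settings, a shared repo-creation helper, and the CacheDB/LLMCache implementation that backs completion().
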
check_repos.py CHANGED
@@ -1,26 +1,25 @@
 from huggingface_hub import HfApi
 
-from src.envs import QUEUE_REPO, RESULTS_REPO, TOKEN
+from src.envs import LLM_CACHE_REPO, QUEUE_REPO, RESULTS_REPO, TOKEN
 
 
-def check_and_create_repos():
+def check_and_create_dataset_repo(repo_id: str):
     api = HfApi(token=TOKEN)
-
-    # Check and create queue repo
     try:
-        api.repo_info(repo_id=QUEUE_REPO, repo_type="dataset")
-        print(f"Queue repository {QUEUE_REPO} exists")
+        api.repo_info(repo_id=repo_id, repo_type="dataset")
+        print(f"{repo_id} exists")
     except Exception:
-        print(f"Creating queue repository {QUEUE_REPO}")
-        api.create_repo(repo_id=QUEUE_REPO, repo_type="dataset", exist_ok=True, private=False)
+        print(f"Creating {repo_id}")
+        api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)
 
-    # Check and create results repo
-    try:
-        api.repo_info(repo_id=RESULTS_REPO, repo_type="dataset")
-        print(f"Results repository {RESULTS_REPO} exists")
-    except Exception:
-        print(f"Creating results repository {RESULTS_REPO}")
-        api.create_repo(repo_id=RESULTS_REPO, repo_type="dataset", exist_ok=True, private=False)
+
+def check_and_create_repos():
+    print("1. QUEUE Repository")
+    check_and_create_dataset_repo(QUEUE_REPO)
+    print("2. RESULTS Repository")
+    check_and_create_dataset_repo(RESULTS_REPO)
+    print("3. LLM Cache Repository")
+    check_and_create_dataset_repo(LLM_CACHE_REPO)
 
 
 if __name__ == "__main__":

src/envs.py CHANGED
@@ -15,6 +15,7 @@ OWNER = "umdclip"
 REPO_ID = f"{OWNER}/quizbowl-submission"
 QUEUE_REPO = f"{OWNER}/advcal-requests"
 RESULTS_REPO = f"{OWNER}/model-results"  # TODO: change to advcal-results after testing is done
+LLM_CACHE_REPO = f"{OWNER}/advcal-llm-cache"
 
 EXAMPLES_PATH = "examples"
 
@@ -29,12 +30,14 @@ PLAYGROUND_DATASET_NAMES = {
 CACHE_PATH = os.getenv("HF_HOME", ".")
 
 # Local caches
+LLM_CACHE_PATH = os.path.join(CACHE_PATH, "llm-cache")
 EVAL_REQUESTS_PATH = os.path.join(CACHE_PATH, "eval-queue")
 EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval-results")
 EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
 EVAL_RESULTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-results-bk")
 
 
+LLM_CACHE_REFRESH_INTERVAL = 600  # seconds (10 minutes)
 SERVER_REFRESH_INTERVAL = 86400  # seconds (one day)
 LEADERBOARD_REFRESH_INTERVAL = 600  # seconds (10 minutes)
 

src/workflows/executors.py CHANGED
@@ -221,6 +221,7 @@ def execute_model_step(
         system=model_step.system_prompt,
         prompt=step_result,
         response_format=ModelResponse,
+        temperature=model_step.temperature,
         logprobs=logprobs,
     )
 

src/workflows/llmcache.py ADDED
@@ -0,0 +1,479 @@
+import hashlib
+import json
+import os
+import sqlite3
+import threading
+import time
+from pathlib import Path
+from typing import Any, Optional
+
+from datasets import Dataset, Features, Value
+from huggingface_hub import snapshot_download
+from loguru import logger
+
+
+def load_dataset_from_hf(repo_id, local_dir):
+    snapshot_download(
+        repo_id=repo_id,
+        local_dir=local_dir,
+        repo_type="dataset",
+        tqdm_class=None,
+        etag_timeout=30,
+        token=os.environ["HF_TOKEN"],
+    )
+
+
+class CacheDB:
+    """Handles database operations for storing and retrieving cache entries."""
+
+    def __init__(self, db_path: Path):
+        """Initialize database connection.
+
+        Args:
+            db_path: Path to SQLite database file
+        """
+        self.db_path = db_path
+        self.lock = threading.Lock()
+
+        # Initialize the database
+        try:
+            self.initialize_db()
+        except Exception as e:
+            logger.exception(f"Failed to initialize database: {e}")
+            logger.warning(f"Please provide a different filepath or remove the file at {self.db_path}")
+            raise
+
+    def initialize_db(self) -> None:
+        """Initialize SQLite database with the required table."""
+        # Check if database file already exists
+        if self.db_path.exists():
+            self._verify_existing_db()
+        else:
+            self._create_new_db()
+
+    def _verify_existing_db(self) -> None:
+        """Verify and repair an existing database if needed."""
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                cursor = conn.cursor()
+                self._ensure_table_exists(cursor)
+                self._verify_schema(cursor)
+                self._ensure_index_exists(cursor)
+                conn.commit()
+            logger.info(f"Using existing SQLite database at {self.db_path}")
+        except Exception as e:
+            logger.exception(f"Database corruption detected: {e}")
+            raise ValueError(f"Corrupted database at {self.db_path}: {str(e)}")
+
+    def _create_new_db(self) -> None:
+        """Create a new database with the required schema."""
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                cursor = conn.cursor()
+                self._create_table(cursor)
+                self._ensure_index_exists(cursor)
+                conn.commit()
+            logger.info(f"Initialized new SQLite database at {self.db_path}")
+        except Exception as e:
+            logger.exception(f"Failed to initialize SQLite database: {e}")
+            raise
+
+    def _ensure_table_exists(self, cursor) -> None:
+        """Check if the llm_cache table exists and create it if not."""
+        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='llm_cache'")
+        if not cursor.fetchone():
+            self._create_table(cursor)
+            logger.info("Created missing llm_cache table")
+
+    def _create_table(self, cursor) -> None:
+        """Create the llm_cache table with the required schema."""
+        cursor.execute("""
+            CREATE TABLE IF NOT EXISTS llm_cache (
+                key TEXT PRIMARY KEY,
+                request_json TEXT,
+                response_json TEXT,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            )
+        """)
+
+    def _verify_schema(self, cursor) -> None:
+        """Verify that the table schema has all required columns."""
+        cursor.execute("PRAGMA table_info(llm_cache)")
+        columns = {row[1] for row in cursor.fetchall()}
+        required_columns = {"key", "request_json", "response_json", "created_at"}
+
+        if not required_columns.issubset(columns):
+            missing = required_columns - columns
+            raise ValueError(f"Database schema is corrupted. Missing columns: {missing}")
+
+    def _ensure_index_exists(self, cursor) -> None:
+        """Create an index on the key column for faster lookups."""
+        cursor.execute("CREATE INDEX IF NOT EXISTS idx_llm_cache_key ON llm_cache (key)")
+
+    def get(self, key: str) -> Optional[dict[str, Any]]:
+        """Get cached entry by key.
+
+        Args:
+            key: Cache key to look up
+
+        Returns:
+            Dict containing the request and response or None if not found
+        """
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                conn.row_factory = sqlite3.Row
+                cursor = conn.cursor()
+                cursor.execute("SELECT request_json, response_json FROM llm_cache WHERE key = ?", (key,))
+                result = cursor.fetchone()
+
+                if result:
+                    logger.debug(f"Cache hit for key: {key}. Response: {result['response_json']}")
+                    return {
+                        "request": result["request_json"],
+                        "response": result["response_json"],
+                    }
+
+                logger.debug(f"Cache miss for key: {key}")
+                return None
+        except Exception as e:
+            logger.error(f"Error retrieving from cache: {e}")
+            return None
+
+    def set(self, key: str, request_json: str, response_json: str) -> bool:
+        """Set entry in cache.
+
+        Args:
+            key: Cache key
+            request_json: JSON string of request parameters
+            response_json: JSON string of response
+
+        Returns:
+            True if successful, False otherwise
+        """
+        with self.lock:
+            try:
+                with sqlite3.connect(self.db_path) as conn:
+                    cursor = conn.cursor()
+                    cursor.execute(
+                        "INSERT OR REPLACE INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
+                        (key, request_json, response_json),
+                    )
+                    conn.commit()
+                    logger.debug(f"Saved response to cache with key: {key}, response: {response_json}")
+                    return True
+            except Exception as e:
+                logger.error(f"Failed to save to SQLite cache: {e}")
+                return False
+
+    def get_all_entries(self) -> dict[str, dict[str, Any]]:
+        """Get all cache entries from the database."""
+        cache = {}
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                conn.row_factory = sqlite3.Row
+                cursor = conn.cursor()
+                cursor.execute("SELECT key, request_json, response_json FROM llm_cache ORDER BY created_at")
+
+                for row in cursor.fetchall():
+                    cache[row["key"]] = {
+                        "request": row["request_json"],
+                        "response": row["response_json"],
+                    }
+
+            logger.debug(f"Retrieved {len(cache)} entries from cache database")
+            return cache
+        except Exception as e:
+            logger.error(f"Error retrieving all cache entries: {e}")
+            return {}
+
+    def clear(self) -> bool:
+        """Clear all cache entries.
+
+        Returns:
+            True if successful, False otherwise
+        """
+        with self.lock:
+            try:
+                with sqlite3.connect(self.db_path) as conn:
+                    cursor = conn.cursor()
+                    cursor.execute("DELETE FROM llm_cache")
+                    conn.commit()
+                    logger.info("Cache cleared")
+                    return True
+            except Exception as e:
+                logger.error(f"Failed to clear cache: {e}")
+                return False
+
+    def get_existing_keys(self) -> set:
+        """Get all existing keys in the database.
+
+        Returns:
+            Set of keys
+        """
+        existing_keys = set()
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                cursor = conn.cursor()
+                cursor.execute("SELECT key FROM llm_cache")
+                for row in cursor.fetchall():
+                    existing_keys.add(row[0])
+            return existing_keys
+        except Exception as e:
+            logger.error(f"Error retrieving existing keys: {e}")
+            return set()
+
+    def bulk_insert(self, items: list, update: bool = False) -> int:
+        """Insert multiple items into the cache.
+
+        Args:
+            items: List of (key, request_json, response_json) tuples
+            update: Whether to update existing entries
+
+        Returns:
+            Number of items inserted
+        """
+        count = 0
+        UPDATE_OR_IGNORE = "INSERT OR REPLACE" if update else "INSERT OR IGNORE"
+        with self.lock:
+            try:
+                with sqlite3.connect(self.db_path) as conn:
+                    cursor = conn.cursor()
+                    cursor.executemany(
+                        f"{UPDATE_OR_IGNORE} INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
+                        items,
+                    )
+                    count = cursor.rowcount
+                    conn.commit()
+                    return count
+            except Exception as e:
+                logger.error(f"Error during bulk insert: {e}")
+                return 0
+
+
+class LLMCache:
+    def __init__(
+        self, cache_dir: str = ".", hf_repo: str | None = None, cache_sync_interval: int = 3600, reset: bool = False
+    ):
+        self.cache_dir = Path(cache_dir)
+        self.db_path = self.cache_dir / "llm_cache.db"
+        self.hf_repo_id = hf_repo
+        self.cache_sync_interval = cache_sync_interval
+        self.last_sync_time = time.time()
+
+        # Create cache directory if it doesn't exist
+        self.cache_dir.mkdir(exist_ok=True, parents=True)
+
+        # Initialize CacheDB
+        self.db = CacheDB(self.db_path)
+        if reset:
+            self.db.clear()
+
+        # Try to load from HF dataset if available
+        try:
+            self._load_cache_from_hf()
+        except Exception as e:
+            logger.warning(f"Failed to load cache from HF dataset: {e}")
+
+    def response_format_to_dict(self, response_format: Any) -> dict[str, Any]:
+        """Convert a response format to a dict."""
+        # If it's a Pydantic model, use its schema
+        if hasattr(response_format, "model_json_schema"):
+            response_format = response_format.model_json_schema()
+
+        # If it's a Pydantic model, use its dump
+        elif hasattr(response_format, "model_dump"):
+            response_format = response_format.model_dump()
+
+        if not isinstance(response_format, dict):
+            response_format = {"value": str(response_format)}
+
+        return response_format
+
+    def _generate_key(
+        self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None = None
+    ) -> str:
+        """Generate a unique key for caching based on inputs."""
+        response_format_dict = self.response_format_to_dict(response_format)
+        response_format_str = json.dumps(response_format_dict, sort_keys=True)
+        # Include temperature in the key
+        key_content = f"{model}:{system}:{prompt}:{response_format_str}"
+        if temperature is not None:
+            key_content += f":{temperature:.2f}"
+        return hashlib.md5(key_content.encode()).hexdigest()
+
+    def _create_request_json(
+        self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None
+    ) -> str:
+        """Create JSON string from request parameters."""
+        logger.info(f"Creating request JSON with temperature: {temperature}")
+        request_data = {
+            "model": model,
+            "system": system,
+            "prompt": prompt,
+            "response_format": self.response_format_to_dict(response_format),
+            "temperature": temperature,
+        }
+        return json.dumps(request_data)
+
+    def _check_request_match(
+        self,
+        cached_request: dict[str, Any],
+        model: str,
+        system: str,
+        prompt: str,
+        response_format: Any,
+        temperature: float | None,
+    ) -> bool:
+        """Check if the cached request matches the new request."""
+        # Check each field and log any mismatches
+        if cached_request["model"] != model:
+            logger.debug(f"Cache mismatch: model - cached: {cached_request['model']}, new: {model}")
+            return False
+        if cached_request["system"] != system:
+            logger.debug(f"Cache mismatch: system - cached: {cached_request['system']}, new: {system}")
+            return False
+        if cached_request["prompt"] != prompt:
+            logger.debug(f"Cache mismatch: prompt - cached: {cached_request['prompt']}, new: {prompt}")
+            return False
+        response_format_dict = self.response_format_to_dict(response_format)
+        if cached_request["response_format"] != response_format_dict:
+            logger.debug(
+                f"Cache mismatch: response_format - cached: {cached_request['response_format']}, new: {response_format_dict}"
+            )
+            return False
+        if cached_request["temperature"] != temperature:
+            logger.debug(f"Cache mismatch: temperature - cached: {cached_request['temperature']}, new: {temperature}")
+            return False
+
+        return True
+
+    def get(
+        self, model: str, system: str, prompt: str, response_format: dict[str, Any], temperature: float | None = None
+    ) -> Optional[dict[str, Any]]:
+        """Get cached response if it exists."""
+        key = self._generate_key(model, system, prompt, response_format, temperature)
+        result = self.db.get(key)
+
+        if not result:
+            return None
+        request_dict = json.loads(result["request"])
+        if not self._check_request_match(request_dict, model, system, prompt, response_format, temperature):
+            logger.warning(f"Cached request does not match new request for key: {key}")
+            return None
+
+        return json.loads(result["response"])
+
+    def set(
+        self,
+        model: str,
+        system: str,
+        prompt: str,
+        response_format: dict[str, Any],
+        temperature: float | None,
+        response: dict[str, Any],
+    ) -> None:
+        """Set response in cache and sync if needed."""
+        key = self._generate_key(model, system, prompt, response_format, temperature)
+        request_json = self._create_request_json(model, system, prompt, response_format, temperature)
+        response_json = json.dumps(response)
+
+        success = self.db.set(key, request_json, response_json)
+
+        # Check if we should sync to HF
+        if success and self.hf_repo_id and (time.time() - self.last_sync_time > self.cache_sync_interval):
+            try:
+                self.sync_to_hf()
+                self.last_sync_time = time.time()
+            except Exception as e:
+                logger.error(f"Failed to sync cache to HF dataset: {e}")
+
+    def _load_cache_from_hf(self) -> None:
+        """Load cache from HF dataset if it exists and merge with local cache."""
+        if not self.hf_repo_id:
+            return
+
+        try:
+            # Check for new commits before loading the dataset
+            dataset = load_dataset_from_hf(self.hf_repo_id, self.cache_dir / "hf_cache")
+            if dataset:
+                existing_keys = self.db.get_existing_keys()
+
+                # Prepare batch items for insertion
+                items_to_insert = []
+                for item in dataset:
+                    key = item["key"]
+                    # Only update if not in local cache to prioritize local changes
+                    if key in existing_keys:
+                        continue
+                    # Create request JSON
+                    request_data = {
+                        "model": item["model"],
+                        "system": item["system"],
+                        "prompt": item["prompt"],
+                        "temperature": item["temperature"],
+                        "response_format": None,  # We can't fully reconstruct this
+                    }
+
+                    items_to_insert.append(
+                        (
+                            key,
+                            json.dumps(request_data),
+                            item["response"],  # This is already a JSON string
+                        )
+                    )
+                    logger.info(
+                        f"Inserting item: {key} with temperature: {item['temperature']} and response: {item['response']}"
+                    )
+
+                # Bulk insert new items
+                if items_to_insert:
+                    inserted_count = self.db.bulk_insert(items_to_insert)
+                    logger.info(f"Merged {inserted_count} items from HF dataset into SQLite cache")
+                else:
+                    logger.info("No new items to merge from HF dataset")
+        except Exception as e:
+            logger.warning(f"Could not load cache from HF dataset: {e}")
+
+    def get_all_entries(self) -> dict[str, dict[str, Any]]:
+        """Get all cache entries from the database."""
+        cache = self.db.get_all_entries()
+        entries = {}
+        for key, entry in cache.items():
+            request = json.loads(entry["request"])
+            response = json.loads(entry["response"])
+            entries[key] = {"request": request, "response": response}
+        return entries
+
+    def sync_to_hf(self) -> None:
+        """Sync cache to HF dataset."""
+        if not self.hf_repo_id:
+            return
+
+        # Get all entries from the database
+        cache = self.db.get_all_entries()
+
+        # Convert cache to dataset format
+        entries = []
+        for key, entry in cache.items():
+            request = json.loads(entry["request"])
+            response_str = entry["response"]
+            entries.append(
+                {
+                    "key": key,
+                    "model": request["model"],
+                    "system": request["system"],
+                    "prompt": request["prompt"],
+                    "response_format": request["response_format"],
+                    "temperature": request["temperature"],
+                    "response": response_str,
+                }
+            )
+
+        # Create and push dataset
+        dataset = Dataset.from_list(entries)
+        dataset.push_to_hub(self.hf_repo_id, private=True)
+        logger.info(f"Synced {len(cache)} cached items to HF dataset {self.hf_repo_id}")
+
+    def clear(self) -> None:
+        """Clear all cache entries."""
+        self.db.clear()

src/workflows/llms.py CHANGED
@@ -1,12 +1,14 @@
 # %%
+
 import json
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import cohere
 import numpy as np
 from langchain_anthropic import ChatAnthropic
 from langchain_cohere import ChatCohere
+from langchain_core.language_models import BaseChatModel
 from langchain_openai import ChatOpenAI
 from loguru import logger
 from openai import OpenAI
@@ -14,6 +16,10 @@ from pydantic import BaseModel, Field
 from rich import print as rprint
 
 from .configs import AVAILABLE_MODELS
+from .llmcache import LLMCache
+
+# Initialize global cache
+llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache")
 
 
 def _openai_is_json_mode_supported(model_name: str) -> bool:
@@ -30,7 +36,7 @@ class LLMOutput(BaseModel):
     logprob: Optional[float] = Field(None, description="The log probability of the response")
 
 
-def _get_langchain_chat_output(llm, system: str, prompt: str) -> str:
+def _get_langchain_chat_output(llm: BaseChatModel, system: str, prompt: str) -> str:
     output = llm.invoke([("system", system), ("human", prompt)])
     ai_message = output["raw"]
     content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
@@ -38,7 +44,9 @@ def _get_langchain_chat_output(llm, system: str, prompt: str) -> str:
     return {"content": content_str, "output": output["parsed"].model_dump()}
 
 
-def _cohere_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> str:
+def _cohere_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
+) -> str:
     messages = [
         {"role": "system", "content": system},
         {"role": "user", "content": prompt},
@@ -49,6 +57,7 @@ def _cohere_completion(model: str, system: str, prompt: str, response_model, log
         messages=messages,
        response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
         logprobs=logprobs,
+        temperature=temperature,
     )
     output = {}
     output["content"] = response.message.content[0].text
@@ -59,12 +68,16 @@ def _cohere_completion(model: str, system: str, prompt: str, response_model, log
     return output
 
 
-def _openai_langchain_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> str:
-    llm = ChatOpenAI(model=model).with_structured_output(response_model, include_raw=True)
+def _openai_langchain_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None
+) -> str:
+    llm = ChatOpenAI(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
     return _get_langchain_chat_output(llm, system, prompt)
 
 
-def _openai_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> str:
+def _openai_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
+) -> str:
     messages = [
         {"role": "system", "content": system},
         {"role": "user", "content": prompt},
@@ -75,6 +88,7 @@ def _openai_completion(model: str, system: str, prompt: str, response_model, log
         messages=messages,
         response_format=response_model,
         logprobs=logprobs,
+        temperature=temperature,
    )
     output = {}
     output["content"] = response.choices[0].message.content
@@ -85,14 +99,18 @@ def _openai_completion(model: str, system: str, prompt: str, response_model, log
     return output
 
 
-def _anthropic_completion(model: str, system: str, prompt: str, response_model) -> str:
-    llm = ChatAnthropic(model=model).with_structured_output(response_model, include_raw=True)
+def _anthropic_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None
+) -> str:
+    llm = ChatAnthropic(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
     return _get_langchain_chat_output(llm, system, prompt)
 
 
-def completion(model: str, system: str, prompt: str, response_format, logprobs: bool = False) -> str:
+def _llm_completion(
+    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
+) -> dict[str, Any]:
     """
-    Generate a completion from an LLM provider with structured output.
+    Generate a completion from an LLM provider with structured output without caching.
 
     Args:
         model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
@@ -116,20 +134,69 @@ def completion(model: str, system: str, prompt: str, response_format, logprobs:
     model_name = AVAILABLE_MODELS[model]["model"]
     provider = model.split("/")[0]
     if provider == "Cohere":
-        return _cohere_completion(model_name, system, prompt, response_format, logprobs)
+        return _cohere_completion(model_name, system, prompt, response_format, temperature, logprobs)
     elif provider == "OpenAI":
         if _openai_is_json_mode_supported(model_name):
-            return _openai_completion(model_name, system, prompt, response_format, logprobs)
+            return _openai_completion(model_name, system, prompt, response_format, temperature, logprobs)
+        elif logprobs:
+            raise ValueError(f"{model} does not support logprobs feature.")
        else:
-            return _openai_langchain_completion(model_name, system, prompt, response_format, logprobs)
+            return _openai_langchain_completion(model_name, system, prompt, response_format, temperature)
     elif provider == "Anthropic":
         if logprobs:
-            raise ValueError("Anthropic does not support logprobs")
-        return _anthropic_completion(model_name, system, prompt, response_format)
+            raise ValueError("Anthropic models do not support logprobs")
+        return _anthropic_completion(model_name, system, prompt, response_format, temperature)
     else:
         raise ValueError(f"Provider {provider} not supported")
 
 
+def completion(
+    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
+) -> dict[str, Any]:
+    """
+    Generate a completion from an LLM provider with structured output with caching.
+
+    Args:
+        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
+        system (str): System prompt/instructions for the model
+        prompt (str): User prompt/input
+        response_format: Pydantic model defining the expected response structure
+        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
+            Note: Not supported by Anthropic models.
+
+    Returns:
+        dict: Contains:
+            - output: The structured response matching response_format
+            - logprob: (optional) Sum of log probabilities if logprobs=True
+            - prob: (optional) Exponential of logprob if logprobs=True
+
+    Raises:
+        ValueError: If logprobs=True with Anthropic models
+    """
+    # Check cache first
+    cached_response = llm_cache.get(model, system, prompt, response_format, temperature)
+    if cached_response is not None:
+        logger.info(f"Cache hit for model {model}")
+        return cached_response
+
+    logger.info(f"Cache miss for model {model}, calling API")
+
+    # Continue with the original implementation for cache miss
+    response = _llm_completion(model, system, prompt, response_format, temperature, logprobs)
+
+    # Update cache with the new response
+    llm_cache.set(
+        model,
+        system,
+        prompt,
+        response_format,
+        temperature,
+        response,
+    )
+
+    return response
+
+
 # %%
 if __name__ == "__main__":
     from tqdm import tqdm
@@ -142,12 +209,52 @@ if __name__ == "__main__":
         answer: str = Field(description="The short answer to the question")
         explanation: str = Field(description="5 words terse best explanation of the answer.")
 
-    models = AVAILABLE_MODELS.keys()
+    models = list(AVAILABLE_MODELS.keys())[:1]  # Just use the first model for testing
     system = "You are an accurate and concise explainer of scientific concepts."
     prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."
 
+    llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache", reset=True)
+
+    # First call - should be a cache miss
+    logger.info("First call - should be a cache miss")
+    for model in tqdm(models):
+        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
+        rprint(response)
+
+    # Second call - should be a cache hit
+    logger.info("Second call - should be a cache hit")
     for model in tqdm(models):
         response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
         rprint(response)
 
+    # Slightly different prompt - should be a cache miss
+    logger.info("Different prompt - should be a cache miss")
+    prompt2 = "Which planet is closest to the sun? Answer directly."
+    for model in tqdm(models):
+        response = completion(model, system, prompt2, ExplainedAnswer, logprobs=False)
+        rprint(response)
+
+    # Get cache entries count from SQLite
+    try:
+        cache_entries = llm_cache.get_all_entries()
+        logger.info(f"Cache now has {len(cache_entries)} items")
+    except Exception as e:
+        logger.error(f"Failed to get cache entries: {e}")
+
+    # Test adding entry with temperature parameter
+    logger.info("Testing with temperature parameter")
+    response = completion(models[0], system, "What is Mars?", ExplainedAnswer, temperature=0.7, logprobs=False)
+    rprint(response)
+
+    # Demonstrate forced sync to HF if repo is configured
+    if llm_cache.hf_repo_id:
+        logger.info("Forcing sync to HF dataset")
+        try:
+            llm_cache.sync_to_hf()
+            logger.info("Successfully synced to HF dataset")
+        except Exception as e:
+            logger.exception(f"Failed to sync to HF: {e}")
+    else:
+        logger.info("HF repo not configured, skipping sync test")
+
 # %%