"""LLM call cache.

Caches LLM calls in a local SQLite database and optionally syncs the cache to a
Hugging Face dataset repo.
"""
import hashlib
import json
import os
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any

from datasets import Dataset, load_dataset
from huggingface_hub import snapshot_download
from loguru import logger
def load_dataset_from_hf(repo_id: str, local_dir: Path) -> Dataset:
    """Download a snapshot of the dataset repo and load it as a `Dataset`."""
    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        repo_type="dataset",
        etag_timeout=30,
        token=os.environ.get("HF_TOKEN"),
    )
    return load_dataset(str(local_dir), split="train")
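

# Illustrative usage of the loader above (the repo id is hypothetical; any
# dataset repo previously pushed by `LLMCache.sync_to_hf` would work):
#
#     ds = load_dataset_from_hf("my-org/llm-cache", Path("/tmp/hf_cache"))
#     for item in ds:
#         print(item["key"], item["model"])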
class CacheDB:
    """Handles database operations for storing and retrieving cache entries."""

    def __init__(self, db_path: Path):
        """Initialize the database connection.

        Args:
            db_path: Path to the SQLite database file
        """
        self.db_path = db_path
        self.lock = threading.Lock()

        # Initialize the database
        try:
            self.initialize_db()
        except Exception as e:
            logger.exception(f"Failed to initialize database: {e}")
            logger.warning(f"Please provide a different filepath or remove the file at {self.db_path}")
            raise
    def initialize_db(self) -> None:
        """Initialize the SQLite database with the required table."""
        # Check if the database file already exists
        if self.db_path.exists():
            self._verify_existing_db()
        else:
            self._create_new_db()
    def _verify_existing_db(self) -> None:
        """Verify and repair an existing database if needed."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                self._ensure_table_exists(cursor)
                self._verify_schema(cursor)
                self._ensure_index_exists(cursor)
                conn.commit()
            logger.info(f"Using existing SQLite database at {self.db_path}")
        except Exception as e:
            logger.exception(f"Database corruption detected: {e}")
            raise ValueError(f"Corrupted database at {self.db_path}: {e}") from e
    def _create_new_db(self) -> None:
        """Create a new database with the required schema."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                self._create_table(cursor)
                self._ensure_index_exists(cursor)
                conn.commit()
            logger.info(f"Initialized new SQLite database at {self.db_path}")
        except Exception as e:
            logger.exception(f"Failed to initialize SQLite database: {e}")
            raise
    def _ensure_table_exists(self, cursor: sqlite3.Cursor) -> None:
        """Check if the llm_cache table exists and create it if not."""
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='llm_cache'")
        if not cursor.fetchone():
            self._create_table(cursor)
            logger.info("Created missing llm_cache table")

    def _create_table(self, cursor: sqlite3.Cursor) -> None:
        """Create the llm_cache table with the required schema."""
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS llm_cache (
                key TEXT PRIMARY KEY,
                request_json TEXT,
                response_json TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

    def _verify_schema(self, cursor: sqlite3.Cursor) -> None:
        """Verify that the table schema has all required columns."""
        cursor.execute("PRAGMA table_info(llm_cache)")
        columns = {row[1] for row in cursor.fetchall()}
        required_columns = {"key", "request_json", "response_json", "created_at"}
        if not required_columns.issubset(columns):
            missing = required_columns - columns
            raise ValueError(f"Database schema is corrupted. Missing columns: {missing}")

    def _ensure_index_exists(self, cursor: sqlite3.Cursor) -> None:
        """Create an index on the key column for faster lookups."""
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_llm_cache_key ON llm_cache (key)")
    def get(self, key: str) -> dict[str, Any] | None:
        """Get a cached entry by key.

        Args:
            key: Cache key to look up

        Returns:
            Dict containing the request and response, or None if not found
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("SELECT request_json, response_json FROM llm_cache WHERE key = ?", (key,))
                result = cursor.fetchone()

            if result:
                logger.debug(f"Cache hit for key: {key}. Response: {result['response_json']}")
                return {
                    "request": result["request_json"],
                    "response": result["response_json"],
                }

            logger.debug(f"Cache miss for key: {key}")
            return None
        except Exception as e:
            logger.error(f"Error retrieving from cache: {e}")
            return None
    def set(self, key: str, request_json: str, response_json: str) -> bool:
        """Set an entry in the cache.

        Args:
            key: Cache key
            request_json: JSON string of the request parameters
            response_json: JSON string of the response

        Returns:
            True if successful, False otherwise
        """
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.execute(
                        "INSERT OR REPLACE INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
                        (key, request_json, response_json),
                    )
                    conn.commit()
                logger.debug(f"Saved response to cache with key: {key}, response: {response_json}")
                return True
            except Exception as e:
                logger.error(f"Failed to save to SQLite cache: {e}")
                return False
    def get_all_entries(self) -> dict[str, dict[str, Any]]:
        """Get all cache entries from the database."""
        cache = {}
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("SELECT key, request_json, response_json FROM llm_cache ORDER BY created_at")
                for row in cursor.fetchall():
                    cache[row["key"]] = {
                        "request": row["request_json"],
                        "response": row["response_json"],
                    }
            logger.debug(f"Retrieved {len(cache)} entries from cache database")
            return cache
        except Exception as e:
            logger.error(f"Error retrieving all cache entries: {e}")
            return {}
    def clear(self) -> bool:
        """Clear all cache entries.

        Returns:
            True if successful, False otherwise
        """
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.execute("DELETE FROM llm_cache")
                    conn.commit()
                logger.info("Cache cleared")
                return True
            except Exception as e:
                logger.error(f"Failed to clear cache: {e}")
                return False
    def get_existing_keys(self) -> set[str]:
        """Get all existing keys in the database.

        Returns:
            Set of keys
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT key FROM llm_cache")
                return {row[0] for row in cursor.fetchall()}
        except Exception as e:
            logger.error(f"Error retrieving existing keys: {e}")
            return set()
    def bulk_insert(self, items: list[tuple[str, str, str]], update: bool = False) -> int:
        """Insert multiple items into the cache.

        Args:
            items: List of (key, request_json, response_json) tuples
            update: Whether to overwrite existing entries

        Returns:
            Number of rows inserted or replaced
        """
        insert_stmt = "INSERT OR REPLACE" if update else "INSERT OR IGNORE"
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.executemany(
                        f"{insert_stmt} INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
                        items,
                    )
                    count = cursor.rowcount
                    conn.commit()
                return count
            except Exception as e:
                logger.error(f"Error during bulk insert: {e}")
                return 0
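

# A minimal sketch of using CacheDB on its own (the path, key, and payloads
# below are made up for illustration; in this module CacheDB is normally
# driven by LLMCache):
#
#     db = CacheDB(Path("/tmp/llm_cache.db"))
#     db.set("abc123", json.dumps({"model": "m"}), json.dumps({"answer": 42}))
#     db.get("abc123")              # -> {"request": "...", "response": "..."}
#     db.bulk_insert([("k1", "{}", "{}")])  # existing keys are skipped by default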
class LLMCache:
    def __init__(
        self, cache_dir: str = ".", hf_repo: str | None = None, cache_sync_interval: int = 3600, reset: bool = False
    ):
        self.cache_dir = Path(cache_dir)
        self.db_path = self.cache_dir / "llm_cache.db"
        self.hf_repo_id = hf_repo
        self.cache_sync_interval = cache_sync_interval
        self.last_sync_time = time.time()

        # Create the cache directory if it doesn't exist
        self.cache_dir.mkdir(exist_ok=True, parents=True)

        # Initialize CacheDB
        self.db = CacheDB(self.db_path)
        if reset:
            self.db.clear()

        # Try to load from the HF dataset if available
        try:
            self._load_cache_from_hf()
        except Exception as e:
            logger.warning(f"Failed to load cache from HF dataset: {e}")
    def response_format_to_dict(self, response_format: Any) -> dict[str, Any]:
        """Convert a response format to a dict."""
        # Pydantic models (classes and instances) expose a JSON schema
        if hasattr(response_format, "model_json_schema"):
            response_format = response_format.model_json_schema()
        # Objects that only expose a dump method
        elif hasattr(response_format, "model_dump"):
            response_format = response_format.model_dump()

        # Fall back to a string representation for anything else
        if not isinstance(response_format, dict):
            response_format = {"value": str(response_format)}
        return response_format
    def _generate_key(
        self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None = None
    ) -> str:
        """Generate a unique key for caching based on the inputs."""
        response_format_dict = self.response_format_to_dict(response_format)
        response_format_str = json.dumps(response_format_dict, sort_keys=True)

        key_content = f"{model}:{system}:{prompt}:{response_format_str}"
        # Include temperature in the key only when it is set
        if temperature is not None:
            key_content += f":{temperature:.2f}"
        # MD5 is fine here: the key only needs to be stable, not cryptographically secure
        return hashlib.md5(key_content.encode()).hexdigest()
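

# Sketch of how keys behave (values are hypothetical): two calls differing only
# in temperature hash to different keys, so they are cached independently.
#
#     cache = LLMCache(cache_dir="/tmp/cache")
#     k1 = cache._generate_key("gpt-x", "You are terse.", "Hi", {"type": "text"}, 0.0)
#     k2 = cache._generate_key("gpt-x", "You are terse.", "Hi", {"type": "text"}, 0.7)
#     assert k1 != k2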
    def _create_request_json(
        self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None
    ) -> str:
        """Create a JSON string from the request parameters."""
        logger.debug(f"Creating request JSON with temperature: {temperature}")
        request_data = {
            "model": model,
            "system": system,
            "prompt": prompt,
            "response_format": self.response_format_to_dict(response_format),
            "temperature": temperature,
        }
        return json.dumps(request_data)
    def _check_request_match(
        self,
        cached_request: dict[str, Any],
        model: str,
        system: str,
        prompt: str,
        response_format: Any,
        temperature: float | None,
    ) -> bool:
        """Check if the cached request matches the new request.

        A key match alone is not trusted: the full request must also agree,
        which guards against hash collisions and stale entries.
        """
        # Check each field and log any mismatch
        if cached_request["model"] != model:
            logger.debug(f"Cache mismatch: model - cached: {cached_request['model']}, new: {model}")
            return False
        if cached_request["system"] != system:
            logger.debug(f"Cache mismatch: system - cached: {cached_request['system']}, new: {system}")
            return False
        if cached_request["prompt"] != prompt:
            logger.debug(f"Cache mismatch: prompt - cached: {cached_request['prompt']}, new: {prompt}")
            return False

        response_format_dict = self.response_format_to_dict(response_format)
        if cached_request["response_format"] != response_format_dict:
            logger.debug(
                f"Cache mismatch: response_format - cached: {cached_request['response_format']}, "
                f"new: {response_format_dict}"
            )
            return False
        if cached_request["temperature"] != temperature:
            logger.debug(f"Cache mismatch: temperature - cached: {cached_request['temperature']}, new: {temperature}")
            return False
        return True
    def get(
        self, model: str, system: str, prompt: str, response_format: dict[str, Any], temperature: float | None = None
    ) -> dict[str, Any] | None:
        """Get the cached response if it exists."""
        key = self._generate_key(model, system, prompt, response_format, temperature)
        result = self.db.get(key)
        if not result:
            return None

        request_dict = json.loads(result["request"])
        if not self._check_request_match(request_dict, model, system, prompt, response_format, temperature):
            logger.warning(f"Cached request does not match new request for key: {key}")
            return None
        return json.loads(result["response"])
    def set(
        self,
        model: str,
        system: str,
        prompt: str,
        response_format: dict[str, Any],
        temperature: float | None,
        response: dict[str, Any],
    ) -> None:
        """Set the response in the cache and sync to HF if the sync interval has elapsed."""
        key = self._generate_key(model, system, prompt, response_format, temperature)
        request_json = self._create_request_json(model, system, prompt, response_format, temperature)
        response_json = json.dumps(response)

        success = self.db.set(key, request_json, response_json)

        # Check if we should sync to HF
        if success and self.hf_repo_id and (time.time() - self.last_sync_time > self.cache_sync_interval):
            try:
                self.sync_to_hf()
                self.last_sync_time = time.time()
            except Exception as e:
                logger.error(f"Failed to sync cache to HF dataset: {e}")
    def _load_cache_from_hf(self) -> None:
        """Load the cache from the HF dataset if it exists and merge it with the local cache."""
        if not self.hf_repo_id:
            return

        try:
            # Pull the latest snapshot of the dataset repo
            dataset = load_dataset_from_hf(self.hf_repo_id, self.cache_dir / "hf_cache")
            if dataset:
                existing_keys = self.db.get_existing_keys()

                # Prepare batch items for insertion
                items_to_insert = []
                for item in dataset:
                    key = item["key"]
                    # Skip keys already in the local cache to prioritize local changes
                    if key in existing_keys:
                        continue

                    # Reconstruct the request JSON from the dataset columns
                    request_data = {
                        "model": item["model"],
                        "system": item["system"],
                        "prompt": item["prompt"],
                        "temperature": item["temperature"],
                        "response_format": None,  # We can't fully reconstruct this
                    }
                    items_to_insert.append(
                        (
                            key,
                            json.dumps(request_data),
                            item["response"],  # This is already a JSON string
                        )
                    )
                    logger.debug(
                        f"Inserting item: {key} with temperature: {item['temperature']} and response: {item['response']}"
                    )

                # Bulk insert the new items
                if items_to_insert:
                    inserted_count = self.db.bulk_insert(items_to_insert)
                    logger.info(f"Merged {inserted_count} items from HF dataset into SQLite cache")
                else:
                    logger.info("No new items to merge from HF dataset")
        except Exception as e:
            logger.warning(f"Could not load cache from HF dataset: {e}")
    def get_all_entries(self) -> dict[str, dict[str, Any]]:
        """Get all cache entries from the database, with request/response JSON decoded."""
        cache = self.db.get_all_entries()
        entries = {}
        for key, entry in cache.items():
            request = json.loads(entry["request"])
            response = json.loads(entry["response"])
            entries[key] = {"request": request, "response": response}
        return entries
    def sync_to_hf(self) -> None:
        """Sync the cache to the HF dataset."""
        if not self.hf_repo_id:
            return

        # Get all entries from the database
        cache = self.db.get_all_entries()

        # Convert the cache to the dataset format
        entries = []
        for key, entry in cache.items():
            request = json.loads(entry["request"])
            response_str = entry["response"]
            entries.append(
                {
                    "key": key,
                    "model": request["model"],
                    "system": request["system"],
                    "prompt": request["prompt"],
                    "response_format": request["response_format"],
                    "temperature": request["temperature"],
                    "response": response_str,
                }
            )

        if not entries:
            logger.info("Cache is empty; nothing to sync to the HF dataset")
            return

        # Create and push the dataset
        dataset = Dataset.from_list(entries)
        dataset.push_to_hub(self.hf_repo_id, private=True)
        logger.info(f"Synced {len(entries)} cached items to HF dataset {self.hf_repo_id}")
    def clear(self) -> None:
        """Clear all cache entries."""
        self.db.clear()
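

if __name__ == "__main__":
    # Minimal end-to-end sketch (model name, prompts, and response payload are
    # hypothetical; no HF repo is configured, so nothing is pushed anywhere).
    cache = LLMCache(cache_dir="/tmp/llm_cache_demo")
    cache.set(
        model="example-model",
        system="You are a helpful assistant.",
        prompt="What is 2 + 2?",
        response_format={"type": "text"},
        temperature=0.0,
        response={"content": "4"},
    )
    # The same arguments hit the cache; changing any of them is a miss.
    print(cache.get("example-model", "You are a helpful assistant.", "What is 2 + 2?", {"type": "text"}, 0.0))
    print(cache.get("example-model", "You are a helpful assistant.", "What is 2 + 2?", {"type": "text"}, 0.7))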