# File size: 18,458 Bytes
# 3a1af80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
import hashlib
import json
import os
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Optional

from datasets import Dataset, Features, Value
from huggingface_hub import snapshot_download
from loguru import logger


def load_dataset_from_hf(repo_id, local_dir):
    """Download a dataset repository snapshot from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub dataset repository id to download
        local_dir: Directory the snapshot files are downloaded into

    Returns:
        The "train" split loaded from the downloaded files. Callers iterate it
        as rows; previously this function returned ``None``, so the caller's
        ``if dataset:`` guard never passed and nothing was ever merged.

    Raises:
        Whatever ``snapshot_download`` / ``load_dataset`` raise (missing repo,
        network failure, no loadable data files).
    """
    # Imported locally because the module-level `datasets` import does not include it.
    from datasets import load_dataset

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
        # .get() instead of [...]: with HF_TOKEN unset, token=None lets
        # huggingface_hub use its own credential lookup rather than raising KeyError.
        token=os.environ.get("HF_TOKEN"),
    )
    # NOTE(review): assumes the repo was written by Dataset.push_to_hub with its
    # default split name ("train") — confirm if other splits are ever pushed.
    return load_dataset(str(local_dir), split="train")


class CacheDB:
    """Handles database operations for storing and retrieving cache entries."""

    def __init__(self, db_path: Path):
        """Initialize database connection.

        Args:
            db_path: Path to SQLite database file
        """
        self.db_path = db_path
        self.lock = threading.Lock()

        # Initialize the database
        try:
            self.initialize_db()
        except Exception as e:
            logger.exception(f"Failed to initialize database: {e}")
            logger.warning(f"Please provide a different filepath or remove the file at {self.db_path}")
            raise

    def initialize_db(self) -> None:
        """Initialize SQLite database with the required table."""
        # Check if database file already exists
        if self.db_path.exists():
            self._verify_existing_db()
        else:
            self._create_new_db()

    def _verify_existing_db(self) -> None:
        """Verify and repair an existing database if needed."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                self._ensure_table_exists(cursor)
                self._verify_schema(cursor)
                self._ensure_index_exists(cursor)
                conn.commit()
            logger.info(f"Using existing SQLite database at {self.db_path}")
        except Exception as e:
            logger.exception(f"Database corruption detected: {e}")
            raise ValueError(f"Corrupted database at {self.db_path}: {str(e)}")

    def _create_new_db(self) -> None:
        """Create a new database with the required schema."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                self._create_table(cursor)
                self._ensure_index_exists(cursor)
                conn.commit()
                logger.info(f"Initialized new SQLite database at {self.db_path}")
        except Exception as e:
            logger.exception(f"Failed to initialize SQLite database: {e}")
            raise

    def _ensure_table_exists(self, cursor) -> None:
        """Check if the llm_cache table exists and create it if not."""
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='llm_cache'")
        if not cursor.fetchone():
            self._create_table(cursor)
            logger.info("Created missing llm_cache table")

    def _create_table(self, cursor) -> None:
        """Create the llm_cache table with the required schema."""
        cursor.execute("""
        CREATE TABLE IF NOT EXISTS llm_cache (
            key TEXT PRIMARY KEY,
            request_json TEXT,
            response_json TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """)

    def _verify_schema(self, cursor) -> None:
        """Verify that the table schema has all required columns."""
        cursor.execute("PRAGMA table_info(llm_cache)")
        columns = {row[1] for row in cursor.fetchall()}
        required_columns = {"key", "request_json", "response_json", "created_at"}

        if not required_columns.issubset(columns):
            missing = required_columns - columns
            raise ValueError(f"Database schema is corrupted. Missing columns: {missing}")

    def _ensure_index_exists(self, cursor) -> None:
        """Create an index on the key column for faster lookups."""
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_llm_cache_key ON llm_cache (key)")

    def get(self, key: str) -> Optional[dict[str, Any]]:
        """Get cached entry by key.

        Args:
            key: Cache key to look up

        Returns:
            Dict containing the request and response or None if not found
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("SELECT request_json, response_json FROM llm_cache WHERE key = ?", (key,))
                result = cursor.fetchone()

                if result:
                    logger.debug(f"Cache hit for key: {key}. Response: {result['response_json']}")
                    return {
                        "request": result["request_json"],
                        "response": result["response_json"],
                    }

                logger.debug(f"Cache miss for key: {key}")
                return None
        except Exception as e:
            logger.error(f"Error retrieving from cache: {e}")
            return None

    def set(self, key: str, request_json: str, response_json: str) -> bool:
        """Set entry in cache.

        Args:
            key: Cache key
            request_json: JSON string of request parameters
            response_json: JSON string of response

        Returns:
            True if successful, False otherwise
        """
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.execute(
                        "INSERT OR REPLACE INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
                        (key, request_json, response_json),
                    )
                    conn.commit()
                    logger.debug(f"Saved response to cache with key: {key}, response: {response_json}")
                    return True
            except Exception as e:
                logger.error(f"Failed to save to SQLite cache: {e}")
                return False

    def get_all_entries(self) -> dict[str, dict[str, Any]]:
        """Get all cache entries from the database."""
        cache = {}
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("SELECT key, request_json, response_json FROM llm_cache ORDER BY created_at")

                for row in cursor.fetchall():
                    cache[row["key"]] = {
                        "request": row["request_json"],
                        "response": row["response_json"],
                    }

                logger.debug(f"Retrieved {len(cache)} entries from cache database")
                return cache
        except Exception as e:
            logger.error(f"Error retrieving all cache entries: {e}")
            return {}

    def clear(self) -> bool:
        """Clear all cache entries.

        Returns:
            True if successful, False otherwise
        """
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.execute("DELETE FROM llm_cache")
                    conn.commit()
                    logger.info("Cache cleared")
                    return True
            except Exception as e:
                logger.error(f"Failed to clear cache: {e}")
                return False

    def get_existing_keys(self) -> set:
        """Get all existing keys in the database.

        Returns:
            Set of keys
        """
        existing_keys = set()
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT key FROM llm_cache")
                for row in cursor.fetchall():
                    existing_keys.add(row[0])
                return existing_keys
        except Exception as e:
            logger.error(f"Error retrieving existing keys: {e}")
            return set()

    def bulk_insert(self, items: list, update: bool = False) -> int:
        """Insert multiple items into the cache.

        Args:
            items: List of (key, request_json, response_json) tuples
            update: Whether to update existing entries

        Returns:
            Number of items inserted
        """
        count = 0
        UPDATE_OR_IGNORE = "INSERT OR REPLACE" if update else "INSERT OR IGNORE"
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.executemany(
                        f"{UPDATE_OR_IGNORE} INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
                        items,
                    )
                    count = cursor.rowcount
                    conn.commit()
                return count
            except Exception as e:
                logger.error(f"Error during bulk insert: {e}")
                return 0


class LLMCache:
    def __init__(
        self, cache_dir: str = ".", hf_repo: str | None = None, cache_sync_interval: int = 3600, reset: bool = False
    ):
        self.cache_dir = Path(cache_dir)
        self.db_path = self.cache_dir / "llm_cache.db"
        self.hf_repo_id = hf_repo
        self.cache_sync_interval = cache_sync_interval
        self.last_sync_time = time.time()

        # Create cache directory if it doesn't exist
        self.cache_dir.mkdir(exist_ok=True, parents=True)

        # Initialize CacheDB
        self.db = CacheDB(self.db_path)
        if reset:
            self.db.clear()

        # Try to load from HF dataset if available
        try:
            self._load_cache_from_hf()
        except Exception as e:
            logger.warning(f"Failed to load cache from HF dataset: {e}")

    def response_format_to_dict(self, response_format: Any) -> dict[str, Any]:
        """Convert a response format to a dict."""
        # If it's a Pydantic model, use its schema
        if hasattr(response_format, "model_json_schema"):
            response_format = response_format.model_json_schema()

        # If it's a Pydantic model, use its dump
        elif hasattr(response_format, "model_dump"):
            response_format = response_format.model_dump()

        if not isinstance(response_format, dict):
            response_format = {"value": str(response_format)}

        return response_format

    def _generate_key(
        self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None = None
    ) -> str:
        """Generate a unique key for caching based on inputs."""
        response_format_dict = self.response_format_to_dict(response_format)
        response_format_str = json.dumps(response_format_dict, sort_keys=True)
        # Include temperature in the key
        key_content = f"{model}:{system}:{prompt}:{response_format_str}"
        if temperature is not None:
            key_content += f":{temperature:.2f}"
        return hashlib.md5(key_content.encode()).hexdigest()

    def _create_request_json(
        self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None
    ) -> str:
        """Create JSON string from request parameters."""
        logger.info(f"Creating request JSON with temperature: {temperature}")
        request_data = {
            "model": model,
            "system": system,
            "prompt": prompt,
            "response_format": self.response_format_to_dict(response_format),
            "temperature": temperature,
        }
        return json.dumps(request_data)

    def _check_request_match(
        self,
        cached_request: dict[str, Any],
        model: str,
        system: str,
        prompt: str,
        response_format: Any,
        temperature: float | None,
    ) -> bool:
        """Check if the cached request matches the new request."""
        # Check each field and log any mismatches
        if cached_request["model"] != model:
            logger.debug(f"Cache mismatch: model - cached: {cached_request['model']}, new: {model}")
            return False
        if cached_request["system"] != system:
            logger.debug(f"Cache mismatch: system - cached: {cached_request['system']}, new: {system}")
            return False
        if cached_request["prompt"] != prompt:
            logger.debug(f"Cache mismatch: prompt - cached: {cached_request['prompt']}, new: {prompt}")
            return False
        response_format_dict = self.response_format_to_dict(response_format)
        if cached_request["response_format"] != response_format_dict:
            logger.debug(
                f"Cache mismatch: response_format - cached: {cached_request['response_format']}, new: {response_format_dict}"
            )
            return False
        if cached_request["temperature"] != temperature:
            logger.debug(f"Cache mismatch: temperature - cached: {cached_request['temperature']}, new: {temperature}")
            return False

        return True

    def get(
        self, model: str, system: str, prompt: str, response_format: dict[str, Any], temperature: float | None = None
    ) -> Optional[dict[str, Any]]:
        """Get cached response if it exists."""
        key = self._generate_key(model, system, prompt, response_format, temperature)
        result = self.db.get(key)

        if not result:
            return None
        request_dict = json.loads(result["request"])
        if not self._check_request_match(request_dict, model, system, prompt, response_format, temperature):
            logger.warning(f"Cached request does not match new request for key: {key}")
            return None

        return json.loads(result["response"])

    def set(
        self,
        model: str,
        system: str,
        prompt: str,
        response_format: dict[str, Any],
        temperature: float | None,
        response: dict[str, Any],
    ) -> None:
        """Set response in cache and sync if needed."""
        key = self._generate_key(model, system, prompt, response_format, temperature)
        request_json = self._create_request_json(model, system, prompt, response_format, temperature)
        response_json = json.dumps(response)

        success = self.db.set(key, request_json, response_json)

        # Check if we should sync to HF
        if success and self.hf_repo_id and (time.time() - self.last_sync_time > self.cache_sync_interval):
            try:
                self.sync_to_hf()
                self.last_sync_time = time.time()
            except Exception as e:
                logger.error(f"Failed to sync cache to HF dataset: {e}")

    def _load_cache_from_hf(self) -> None:
        """Load cache from HF dataset if it exists and merge with local cache."""
        if not self.hf_repo_id:
            return

        try:
            # Check for new commits before loading the dataset
            dataset = load_dataset_from_hf(self.hf_repo_id, self.cache_dir / "hf_cache")
            if dataset:
                existing_keys = self.db.get_existing_keys()

                # Prepare batch items for insertion
                items_to_insert = []
                for item in dataset:
                    key = item["key"]
                    # Only update if not in local cache to prioritize local changes
                    if key in existing_keys:
                        continue
                    # Create request JSON
                    request_data = {
                        "model": item["model"],
                        "system": item["system"],
                        "prompt": item["prompt"],
                        "temperature": item["temperature"],
                        "response_format": None,  # We can't fully reconstruct this
                    }

                    items_to_insert.append(
                        (
                            key,
                            json.dumps(request_data),
                            item["response"],  # This is already a JSON string
                        )
                    )
                    logger.info(
                        f"Inserting item: {key} with temperature: {item['temperature']} and response: {item['response']}"
                    )

                # Bulk insert new items
                if items_to_insert:
                    inserted_count = self.db.bulk_insert(items_to_insert)
                    logger.info(f"Merged {inserted_count} items from HF dataset into SQLite cache")
                else:
                    logger.info("No new items to merge from HF dataset")
        except Exception as e:
            logger.warning(f"Could not load cache from HF dataset: {e}")

    def get_all_entries(self) -> dict[str, dict[str, Any]]:
        """Get all cache entries from the database."""
        cache = self.db.get_all_entries()
        entries = {}
        for key, entry in cache.items():
            request = json.loads(entry["request"])
            response = json.loads(entry["response"])
            entries[key] = {"request": request, "response": response}
        return entries

    def sync_to_hf(self) -> None:
        """Sync cache to HF dataset."""
        if not self.hf_repo_id:
            return

        # Get all entries from the database
        cache = self.db.get_all_entries()

        # Convert cache to dataset format
        entries = []
        for key, entry in cache.items():
            request = json.loads(entry["request"])
            response_str = entry["response"]
            entries.append(
                {
                    "key": key,
                    "model": request["model"],
                    "system": request["system"],
                    "prompt": request["prompt"],
                    "response_format": request["response_format"],
                    "temperature": request["temperature"],
                    "response": response_str,
                }
            )

        # Create and push dataset
        dataset = Dataset.from_list(entries)
        dataset.push_to_hub(self.hf_repo_id, private=True)
        logger.info(f"Synced {len(cache)} cached items to HF dataset {self.hf_repo_id}")

    def clear(self) -> None:
        """Clear all cache entries."""
        self.db.clear()