Spaces:
Running
Running
import sqlite3 | |
from pathlib import Path | |
from typing import Any, Dict | |
from pydantic import BaseModel | |
class PersistentSettings(BaseModel):
    """
    Pydantic model whose field values are backed by a SQLite key/value table.

    Every instantiation opens ``config.db`` (relative to the current working
    directory), creates the ``settings`` table if needed, and uses all stored
    rows as initial field values. Keyword arguments passed to the constructor
    take precedence over stored values. :meth:`update` writes the model back
    into the table, one row per key.

    Values live in a TEXT column; on load they are passed to pydantic, which
    converts them to the declared field types.
    """

    class Config:
        # Allows subclasses to declare a ``conn: sqlite3.Connection`` field,
        # a type pydantic has no built-in validator for.
        arbitrary_types_allowed = True

    def __init__(self, **data: Any):
        # Keep the connection in a local until the model is initialised:
        # pydantic rejects plain attribute assignment before
        # BaseModel.__init__ has run, and __init__ replaces the instance
        # __dict__, which would drop anything stored earlier.
        conn = sqlite3.connect("config.db")
        try:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS settings (
                    key TEXT PRIMARY KEY,
                    value TEXT
                )
                """
            )
            # Merge into one mapping: a key present both in the database and
            # in ``data`` would otherwise raise "got multiple values for
            # keyword argument". Explicit kwargs win.
            super().__init__(**{**self._read_settings(conn), **data})
        except BaseException:
            # Do not leak the connection when table creation or validation fails.
            conn.close()
            raise
        # object.__setattr__ bypasses pydantic's field checks, so this works
        # whether or not the concrete model declares a ``conn`` field.
        object.__setattr__(self, "conn", conn)

    @staticmethod
    def _read_settings(conn: sqlite3.Connection) -> Dict[str, Any]:
        """Return every ``key -> value`` row of the ``settings`` table."""
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT key, value FROM settings")
            return dict(cursor.fetchall())
        finally:
            cursor.close()

    def fetch_settings(self) -> Dict[str, Any]:
        """
        Retrieve all stored settings from the database as a ``key -> value`` dict.
        """
        return self._read_settings(self.conn)

    def update(self, **data: Any) -> None:
        """
        Persist the model's current values, overlaid with ``data``, into the
        ``settings`` table (INSERT OR REPLACE, one row per key).

        :param data: extra key/value pairs to store; they take precedence over
            the model's own value for the same key.
        """
        values = {**self.dict(exclude={"conn"}), **data}
        # The connection handle is not a setting, and sqlite3 cannot bind it
        # as a parameter value.
        values.pop("conn", None)
        # NOTE(review): ``data`` is only written to the database; it is not
        # assigned to the instance. A later ``update()`` therefore writes the
        # instance's old values back over these keys.
        # ``with conn`` commits on success and rolls back on error, so a
        # failure part-way through does not leave a partial batch pending.
        with self.conn:
            self.conn.executemany(
                "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
                values.items(),
            )

    def close(self) -> None:
        """
        Close the database connection.
        """
        self.conn.close()
class TortoiseConfig(PersistentSettings):
    """
    Persisted settings: an extra-voices directory, two checkpoint file paths
    and a low-VRAM flag.

    After loading, a checkpoint path that does not name an existing file is
    reset to ".".
    """

    EXTRA_VOICES_DIR: str = ""
    AR_CHECKPOINT: str = "."
    DIFF_CHECKPOINT: str = "."
    LOW_VRAM: bool = True
    # Declared as a field, presumably so pydantic permits a ``conn`` attribute
    # (the database connection) on the instance.
    conn: sqlite3.Connection = None

    def __init__(self, **data: Any):
        super().__init__(**data)
        # Replace checkpoint paths that are not existing files with ".".
        for field_name in ("AR_CHECKPOINT", "DIFF_CHECKPOINT"):
            candidate = Path(getattr(self, field_name))
            if candidate.is_file():
                continue
            setattr(self, field_name, ".")