Spaces:
Runtime error
Runtime error
import os | |
import logging | |
from sqlalchemy import create_engine, Column, Integer, String, DateTime | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.orm import sessionmaker, Session | |
from datetime import datetime | |
import pandas as pd | |
import numpy as np | |
from datasets import load_dataset | |
from rating_systems import compute_elo, compute_bootstrap_elo, get_median_elo_from_bootstrap | |
def is_running_in_space():
    """Tell whether the process is running inside a Space.

    Detection relies solely on the ``SPACE_ID`` environment variable.

    Returns:
        bool: True when ``SPACE_ID`` is present in the environment (even if
        its value is empty), False otherwise.
    """
    return os.environ.get("SPACE_ID") is not None
# Select the SQLite file for the current environment: one database when
# running inside a Space, a separate one for local development.
if is_running_in_space():
    DATABASE_URL = "sqlite:///./data/hf-votes.db"
else:
    DATABASE_URL = "sqlite:///./data/local2.db"

# SQLite creates a missing database file on first connect, but it does not
# create a missing parent directory. Make sure ./data exists so the first
# connection does not fail with "unable to open database file".
os.makedirs("./data", exist_ok=True)

# check_same_thread=False lets the underlying sqlite3 connection be used from
# threads other than the one that created it.
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})

# Session factory: explicit commits, no automatic flush before queries.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class that ORM models in this module inherit from.
Base = declarative_base()
# Database model | |
class Vote(Base):
    """ORM model for a single pairwise comparison vote (table ``votes``)."""

    __tablename__ = "votes"

    # Integer primary key.
    id = Column(Integer, primary_key=True, index=True)

    # Identifier of the image being judged (indexed).
    image_id = Column(String, index=True)

    # Names of the two competing models.
    model_a = Column(String)
    model_b = Column(String)

    # Outcome of the comparison; the rest of this module treats
    # "model_a", "model_b" and "tie" as the valid values.
    winner = Column(String)

    # Identifier of the voter (indexed).
    user_id = Column(String, index=True)

    # File paths associated with model_a / model_b respectively
    # (presumably the compared outputs -- confirm against the caller).
    fpath_a = Column(String)
    fpath_b = Column(String)

    # Defaults to the current naive UTC time when no value is supplied.
    timestamp = Column(DateTime, default=datetime.utcnow)
# Create every table registered on Base that does not exist yet; tables that
# are already present are left untouched (create_all checks first by default).
Base.metadata.create_all(engine)
# Dependency for database session | |
def get_db():
    """Yield a database session and close it when the consumer is finished.

    This is a generator: the session is handed to the caller at the ``yield``
    and is closed when the generator is resumed, closed, or an exception is
    thrown into it.

    Yields:
        Session: a new session created from ``SessionLocal``.
    """
    # The session's context manager closes it on exit, whatever the outcome.
    with SessionLocal() as session:
        yield session
def _parse_vote_timestamp(raw):
    """Convert a dataset timestamp value into a ``datetime``.

    Args:
        raw: Value read from a dataset record. ``datetime`` instances are
            returned unchanged; strings are parsed as ISO-8601.

    Returns:
        datetime: the parsed timestamp, or the current UTC time when ``raw``
        is missing, of an unsupported type, or cannot be parsed.
    """
    if isinstance(raw, datetime):
        # Already a datetime: keep the real vote time instead of discarding it.
        return raw
    if isinstance(raw, str):
        # fromisoformat() rejects the "Z" UTC suffix before Python 3.11, so
        # rewrite it as an explicit offset.
        text = raw[:-1] + "+00:00" if raw.endswith("Z") else raw
        try:
            return datetime.fromisoformat(text)
        except ValueError:
            # One malformed record should not abort the whole import.
            logging.warning(
                "Unparsable vote timestamp %r; using current time instead", raw
            )
    return datetime.utcnow()


def fill_database_once(dataset_name="bgsys/votes_datasets_test2"):
    """Seed the votes table from a dataset, but only if the table is empty.

    Args:
        dataset_name: Name passed to ``load_dataset``; its ``train`` split is
            imported, one ``Vote`` row per record.

    Missing text fields default to an empty string. Timestamps are normalised
    by ``_parse_vote_timestamp``. All rows are committed in one transaction.
    """
    text_fields = (
        "image_id",
        "model_a",
        "model_b",
        "winner",
        "user_id",
        "fpath_a",
        "fpath_b",
    )
    with SessionLocal() as db:
        # Any existing row means the database was already seeded.
        if db.query(Vote).first() is not None:
            logging.info("Database already filled, skipping dataset loading.")
            return

        dataset = load_dataset(dataset_name)
        for record in dataset['train']:
            vote_data = {field: record.get(field, "") for field in text_fields}
            vote_data["timestamp"] = _parse_vote_timestamp(record.get("timestamp"))
            db.add(Vote(**vote_data))
        db.commit()
        logging.info("Database filled with data from Hugging Face dataset: %s", dataset_name)
def add_vote(vote_data):
    """Store one vote and return a short summary of the saved row.

    Args:
        vote_data: Mapping of ``Vote`` column names to values; it is expanded
            into the ``Vote`` constructor as keyword arguments.

    Returns:
        dict: the stored row's ``id``, ``user_id`` and ``timestamp``.
    """
    with SessionLocal() as session:
        new_vote = Vote(**vote_data)
        session.add(new_vote)
        session.commit()
        # Reload the row so values set on insert (id, default timestamp) are
        # available on the instance.
        session.refresh(new_vote)
        logging.info("Vote registered with ID: %s, using database: %s", new_vote.id, DATABASE_URL)
        summary = {
            "id": new_vote.id,
            "user_id": new_vote.user_id,
            "timestamp": new_vote.timestamp,
        }
        return summary
# Function to get all votes | |
def get_all_votes():
    """Return every row of the votes table.

    Returns:
        list: all ``Vote`` objects. The session used to load them is closed
        before the list is returned.
    """
    with SessionLocal() as session:
        return session.query(Vote).all()
# Function to compute Elo scores | |
def compute_elo_scores():
    """Compute bootstrap Elo statistics from the valid votes in the database.

    A vote is kept only when both ``model_a`` and ``model_b`` are known model
    names and ``winner`` is one of "model_a", "model_b" or "tie".

    Returns:
        tuple: ``(median_elo_scores, model_rating_q025, model_rating_q975,
        variance)``. The first comes from ``get_median_elo_from_bootstrap``;
        the others are the 2.5% quantile, 97.5% quantile and variance of the
        result of ``compute_bootstrap_elo``.
    """
    valid_models = ["Photoroom", "RemoveBG", "BRIA RMBG 2.0", "Clipdrop", "BiRefNet v2"]
    valid_winners = ["model_a", "model_b", "tie"]
    columns = ["model_a", "model_b", "winner"]

    with SessionLocal() as db:
        votes = db.query(Vote).all()
        # Read the three needed attributes in one pass over the votes.
        records = [(vote.model_a, vote.model_b, vote.winner) for vote in votes]

    # Passing the column names explicitly keeps the frame well-formed even
    # when there are no votes at all.
    df = pd.DataFrame(records, columns=columns)
    init_size = df.shape[0]

    # Remove votes missing model_a, model_b or winner info
    df = df.dropna(subset=columns)

    # Validate models and winner with a vectorised boolean mask. A row-wise
    # df.apply(..., axis=1) does not return a boolean Series on an empty
    # frame, which makes df[...] select columns instead of filtering rows.
    is_valid = (
        df["model_a"].isin(valid_models)
        & df["model_b"].isin(valid_models)
        & df["winner"].isin(valid_winners)
    )
    df = df[is_valid]

    logging.info("Initial votes count: %d", init_size)
    logging.info("Votes count after validation: %d", df.shape[0])

    # Seed the random number generator for reproducibility
    np.random.seed(42)
    bootstrap_elo_scores = compute_bootstrap_elo(df)
    median_elo_scores = get_median_elo_from_bootstrap(bootstrap_elo_scores)
    model_rating_q025 = bootstrap_elo_scores.quantile(0.025)
    model_rating_q975 = bootstrap_elo_scores.quantile(0.975)
    variance = bootstrap_elo_scores.var()
    return median_elo_scores, model_rating_q025, model_rating_q975, variance
# Function to compute the number of wins for each model (ties and votes with an unrecognised winner are not counted)
def compute_votes_per_model():
    """Count how many votes each model has won.

    Only votes whose ``winner`` is "model_a" or "model_b" contribute; the
    count goes to the model named in the corresponding column. Every other
    outcome (for example a tie) is skipped.

    Returns:
        dict: model name -> number of winning votes.
    """
    wins_by_model = {}
    with SessionLocal() as session:
        for entry in session.query(Vote).all():
            if entry.winner not in ("model_a", "model_b"):
                continue
            # winner holds the name of the column that stores the winning model.
            winning_model = getattr(entry, entry.winner)
            wins_by_model[winning_model] = wins_by_model.get(winning_model, 0) + 1
    return wins_by_model