diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..aa96e1820a29689a9163b6a202ac15bf5d81fd8e --- /dev/null +++ b/.env.example @@ -0,0 +1,3 @@ +ENVIRONMENT=development +HF_TOKEN=xxx +HF_HOME=.cache diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..08e57889adbc8cb31f2809bb3232c4f42e283a21 --- /dev/null +++ b/.gitignore @@ -0,0 +1,45 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +__pycache__ +.cache/ + +# dependencies + +frontend/node_modules +/.pnp +.pnp.js + +# testing + +/coverage + +# production + +/build + +# misc + +.DS_Store +.env.local +.env.development.local +.env.test.local +.env.production.local + +npm-debug.log* +yarn-debug.log* +yarn-error.log\* + +src/dataframe.json + +yarn.lock +package-lock.json + +/public + +.claudesync/ + +# Environment variables +.env +.env.* +!.env.example + diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..f90da3272f5ea911250b43d7eaccf38b3a7b1412 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,62 @@ +# Build frontend +FROM node:18 as frontend-build +WORKDIR /app +COPY frontend/package*.json ./ +RUN npm install +COPY frontend/ ./ + +RUN npm run build + +# Build backend +FROM python:3.12-slim +WORKDIR /app + +# Create non-root user +RUN useradd -m -u 1000 user + +# Install poetry +RUN pip install poetry + +# Create and configure cache directory +RUN mkdir -p /app/.cache && \ + chown -R user:user /app + +# Copy and install backend dependencies +COPY backend/pyproject.toml backend/poetry.lock* ./ +RUN poetry config virtualenvs.create false \ + && poetry install --no-interaction --no-ansi --no-root --only main + +# Copy backend code +COPY backend/ . + +# Install Node.js and npm +RUN apt-get update && apt-get install -y \ + curl \ + netcat-openbsd \ + && curl -fsSL https://deb.nodesource.com/setup_18.x | bash - \ + && apt-get install -y nodejs \ + && rm -rf /var/lib/apt/lists/* + +# Copy frontend server and build +COPY --from=frontend-build /app/build ./frontend/build +COPY --from=frontend-build /app/package*.json ./frontend/ +COPY --from=frontend-build /app/server.js ./frontend/ + +# Install frontend production dependencies +WORKDIR /app/frontend +RUN npm install --production +WORKDIR /app + +# Environment variables +ENV HF_HOME=/app/.cache \ + HF_DATASETS_CACHE=/app/.cache \ + INTERNAL_API_PORT=7861 \ + PORT=7860 \ + NODE_ENV=production + +# Note: HF_TOKEN should be provided at runtime, not build time +USER user +EXPOSE 7860 + +# Start both servers with wait-for +CMD ["sh", "-c", "uvicorn app.asgi:app --host 0.0.0.0 --port 7861 & while ! 
nc -z localhost 7861; do sleep 1; done && cd frontend && npm run serve"] \ No newline at end of file diff --git a/README.md b/README.md index 1a82d98ee562926741eb0ecad6e97da379725e86..db797f57c118696c4e63a24d28d7dc6b02d2feec 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,91 @@ --- -title: Open Llm Leaderboard -emoji: 👀 -colorFrom: red -colorTo: blue +title: Open LLM Leaderboard +emoji: 🏆 +colorFrom: blue +colorTo: red sdk: docker -pinned: false +hf_oauth: true +pinned: true license: apache-2.0 -short_description: copy of official llm leaderboard +duplicated_from: open-llm-leaderboard/open_llm_leaderboard +tags: +- leaderboard +- modality:text +- submission:automatic +- test:public +- language:english +- eval:code +- eval:math +short_description: Track, rank and evaluate open LLMs and chatbots --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Open LLM Leaderboard + +Modern React interface for comparing Large Language Models (LLMs) in an open and reproducible way. + +## Features + +- 📊 Interactive table with advanced sorting and filtering +- 🔍 Semantic model search +- 📌 Pin models for comparison +- 📱 Responsive and modern interface +- 🎨 Dark/Light mode +- ⚡️ Optimized performance with virtualization + +## Architecture + +The project is split into two main parts: + +### Frontend (React) + +``` +frontend/ +├── src/ +│ ├── components/ # Reusable UI components +│ ├── pages/ # Application pages +│ ├── hooks/ # Custom React hooks +│ ├── context/ # React contexts +│ └── constants/ # Constants and configurations +├── public/ # Static assets +└── server.js # Express server for production +``` + +### Backend (FastAPI) + +``` +backend/ +├── app/ +│ ├── api/ # API router and endpoints +│ │ └── endpoints/ # Specific API endpoints +│ ├── core/ # Core functionality +│ ├── config/ # Configuration +│ └── services/ # Business logic services +│ ├── leaderboard.py +│ ├── models.py +│ ├── votes.py +│ └── hf_service.py +└── utils/ # Utility functions +``` + +## Technologies + +### Frontend + +- React +- Material-UI +- TanStack Table & Virtual +- Express.js + +### Backend + +- FastAPI +- Hugging Face API +- Docker + +## Development + +The application is containerized using Docker and can be run using: + +```bash +docker-compose up +``` diff --git a/backend/Dockerfile.dev b/backend/Dockerfile.dev new file mode 100644 index 0000000000000000000000000000000000000000..f802c87f0d5d730c559b1f21ed715b48cc9ca42a --- /dev/null +++ b/backend/Dockerfile.dev @@ -0,0 +1,25 @@ +FROM python:3.12-slim + +WORKDIR /app + +# Install required system dependencies +RUN apt-get update && apt-get install -y \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +# Install poetry +RUN pip install poetry + +# Copy Poetry configuration files +COPY pyproject.toml poetry.lock* ./ + +# Install dependencies +RUN poetry config virtualenvs.create false && \ + poetry install --no-interaction --no-ansi --no-root + +# Environment variables configuration for logs +ENV PYTHONUNBUFFERED=1 +ENV LOG_LEVEL=INFO + +# In dev, mount volume directly +CMD ["uvicorn", "app.asgi:app", "--host", "0.0.0.0", "--port", "7860", "--reload", "--log-level", "warning", "--no-access-log"] \ No newline at end of file diff --git a/backend/README.md b/backend/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4a9c1e60d0c77b9add02f0c9ba25acaabe6ab2a5 --- /dev/null +++ b/backend/README.md @@ -0,0 +1,352 @@ +# Backend - Open LLM Leaderboard 🏆 + +FastAPI backend for the Open LLM Leaderboard. 
This service is part of a larger architecture that includes a React frontend. For complete project installation, see the [main README](../README.md). + +## ✨ Features + +- 📊 REST API for LLM models leaderboard management +- 🗳️ Voting and ranking system +- 🔄 HuggingFace Hub integration +- 🚀 Caching and performance optimizations + +## 🏗 Architecture + +```mermaid +flowchart TD + Client(["**Frontend**
<br><br>React Application"]) --> API["**API Server**<br><br>FastAPI REST Endpoints"] + + subgraph Backend + API --> Core["**Core Layer**<br><br>• Middleware<br>• Cache<br>• Rate Limiting"] + Core --> Services["**Services Layer**<br><br>• Business Logic<br>• Data Processing"] + + subgraph Services Layer + Services --> Models["**Model Service**<br><br>• Model Submission<br>• Evaluation Pipeline"] + Services --> Votes["**Vote Service**<br><br>• Vote Management<br>• Data Synchronization"] + Services --> Board["**Leaderboard Service**<br><br>• Rankings<br>• Performance Metrics"] + end + + Models --> Cache["**Cache Layer**<br><br>• In-Memory Store<br>• Auto Invalidation"] + Votes --> Cache + Board --> Cache + + Models --> HF["**HuggingFace Hub**<br><br>• Models Repository<br>
• Datasets Access"] + Votes --> HF + Board --> HF + end + + style Client fill:#f9f,stroke:#333,stroke-width:2px + style Models fill:#bbf,stroke:#333,stroke-width:2px + style Votes fill:#bbf,stroke:#333,stroke-width:2px + style Board fill:#bbf,stroke:#333,stroke-width:2px + style HF fill:#bfb,stroke:#333,stroke-width:2px +``` + +## 🛠️ HuggingFace Datasets + +The application uses several datasets on the HuggingFace Hub: + +### 1. Requests Dataset (`{HF_ORGANIZATION}/requests`) + +- **Operations**: + - 📤 `POST /api/models/submit`: Adds a JSON file for each new model submission + - 📥 `GET /api/models/status`: Reads files to get models status +- **Format**: One JSON file per model with submission details +- **Updates**: On each new model submission + +### 2. Votes Dataset (`{HF_ORGANIZATION}/votes`) + +- **Operations**: + - 📤 `POST /api/votes/{model_id}`: Adds a new vote + - 📥 `GET /api/votes/model/{provider}/{model}`: Reads model votes + - 📥 `GET /api/votes/user/{user_id}`: Reads user votes +- **Format**: JSONL with one vote per line +- **Sync**: Bidirectional between local cache and Hub + +### 3. Contents Dataset (`{HF_ORGANIZATION}/contents`) + +- **Operations**: + - 📥 `GET /api/leaderboard`: Reads raw data + - 📥 `GET /api/leaderboard/formatted`: Reads and formats data +- **Format**: Main dataset containing all scores and metrics +- **Updates**: Automatic after model evaluations + +### 4. Official Providers Dataset (`{HF_ORGANIZATION}/official-providers`) + +- **Operations**: + - 📥 Read-only access for highlighted models +- **Format**: List of models selected by maintainers +- **Updates**: Manual by maintainers + +## 🛠 Local Development + +### Prerequisites + +- Python 3.9+ +- [Poetry](https://python-poetry.org/docs/#installation) + +### Standalone Installation (without Docker) + +```bash +# Install dependencies +poetry install + +# Setup configuration +cp .env.example .env + +# Start development server +poetry run uvicorn app.asgi:app --host 0.0.0.0 --port 7860 --reload +``` + +Server will be available at http://localhost:7860 + +## ⚙️ Configuration + +| Variable | Description | Default | +| ------------ | ------------------------------------ | ----------- | +| ENVIRONMENT | Environment (development/production) | development | +| HF_TOKEN | HuggingFace authentication token | - | +| PORT | Server port | 7860 | +| LOG_LEVEL | Logging level (INFO/DEBUG/WARNING) | INFO | +| CORS_ORIGINS | Allowed CORS origins | ["*"] | +| CACHE_TTL | Cache Time To Live in seconds | 300 | + +## 🔧 Middleware + +The backend uses several middleware layers for optimal performance and security: + +- **CORS Middleware**: Handles Cross-Origin Resource Sharing +- **GZIP Middleware**: Compresses responses > 500 bytes +- **Rate Limiting**: Prevents API abuse +- **Caching**: In-memory caching with automatic invalidation + +## 📝 Logging + +The application uses a structured logging system with: + +- Formatted console output +- Different log levels per component +- Request/Response logging +- Performance metrics +- Error tracking + +## 📁 File Structure + +``` +backend/ +├── app/ # Source code +│ ├── api/ # Routes and endpoints +│ │ └── endpoints/ # Endpoint handlers +│ ├── core/ # Configurations +│ ├── services/ # Business logic +│ └── utils/ # Utilities +└── tests/ # Tests +``` + +## 📚 API + +Swagger documentation available at http://localhost:7860/docs + +### Main Endpoints & Data Structures + +#### Leaderboard + +- `GET /api/leaderboard/formatted` - Formatted data with computed fields and metadata + + ```typescript + 
Response { + models: [{ + id: string, // eval_name + model: { + name: string, // fullname + sha: string, // Model sha + precision: string, // e.g. "fp16", "int8" + type: string, // e.g. "fined-tuned-on-domain-specific-dataset" + weight_type: string, + architecture: string, + average_score: number, + has_chat_template: boolean + }, + evaluations: { + ifeval: { + name: "IFEval", + value: number, // Raw score + normalized_score: number + }, + bbh: { + name: "BBH", + value: number, + normalized_score: number + }, + math: { + name: "MATH Level 5", + value: number, + normalized_score: number + }, + gpqa: { + name: "GPQA", + value: number, + normalized_score: number + }, + musr: { + name: "MUSR", + value: number, + normalized_score: number + }, + mmlu_pro: { + name: "MMLU-PRO", + value: number, + normalized_score: number + } + }, + features: { + is_not_available_on_hub: boolean, + is_merged: boolean, + is_moe: boolean, + is_flagged: boolean, + is_official_provider: boolean + }, + metadata: { + upload_date: string, + submission_date: string, + generation: string, + base_model: string, + hub_license: string, + hub_hearts: number, + params_billions: number, + co2_cost: number // CO₂ cost in kg + } + }] + } + ``` + +- `GET /api/leaderboard` - Raw data from the HuggingFace dataset + ```typescript + Response { + models: [{ + eval_name: string, + Precision: string, + Type: string, + "Weight type": string, + Architecture: string, + Model: string, + fullname: string, + "Model sha": string, + "Average ⬆️": number, + "Hub License": string, + "Hub ❤️": number, + "#Params (B)": number, + "Available on the hub": boolean, + Merged: boolean, + MoE: boolean, + Flagged: boolean, + "Chat Template": boolean, + "CO₂ cost (kg)": number, + "IFEval Raw": number, + IFEval: number, + "BBH Raw": number, + BBH: number, + "MATH Lvl 5 Raw": number, + "MATH Lvl 5": number, + "GPQA Raw": number, + GPQA: number, + "MUSR Raw": number, + MUSR: number, + "MMLU-PRO Raw": number, + "MMLU-PRO": number, + "Maintainer's Highlight": boolean, + "Upload To Hub Date": string, + "Submission Date": string, + Generation: string, + "Base Model": string + }] + } + ``` + +#### Models + +- `GET /api/models/status` - Get all models grouped by status + ```typescript + Response { + pending: [{ + name: string, + submitter: string, + revision: string, + wait_time: string, + submission_time: string, + status: "PENDING" | "EVALUATING" | "FINISHED", + precision: string + }], + evaluating: Array, + finished: Array + } + ``` +- `GET /api/models/pending` - Get pending models only +- `POST /api/models/submit` - Submit model + + ```typescript + Request { + user_id: string, + model_id: string, + base_model?: string, + precision?: string, + model_type: string + } + + Response { + status: string, + message: string + } + ``` + +- `GET /api/models/{model_id}/status` - Get model status + +#### Votes + +- `POST /api/votes/{model_id}` - Vote + + ```typescript + Request { + vote_type: "up" | "down", + user_id: string // HuggingFace username + } + + Response { + success: boolean, + message: string + } + ``` + +- `GET /api/votes/model/{provider}/{model}` - Get model votes + ```typescript + Response { + total_votes: number, + up_votes: number, + down_votes: number + } + ``` +- `GET /api/votes/user/{user_id}` - Get user votes + ```typescript + Response Array<{ + model_id: string, + vote_type: string, + timestamp: string + }> + ``` + +## 🔒 Authentication + +The backend uses HuggingFace token-based authentication for secure API access. Make sure to: + +1. 
Set your HF_TOKEN in the .env file +2. Include the token in API requests via Bearer authentication +3. Keep your token secure and never commit it to version control + +## 🚀 Performance + +The backend implements several optimizations: + +- In-memory caching with configurable TTL (Time To Live) +- Batch processing for model evaluations +- Rate limiting for API endpoints +- Efficient database queries with proper indexing +- Automatic cache invalidation for votes diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..41bd81293794127ec484666c9a9bf3b2cd0bbe3c --- /dev/null +++ b/backend/app/api/__init__.py @@ -0,0 +1,5 @@ +""" +API package initialization +""" + +__all__ = ["endpoints"] diff --git a/backend/app/api/dependencies.py b/backend/app/api/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..d9feaf42a38b8fa19b327989542659edfc635519 --- /dev/null +++ b/backend/app/api/dependencies.py @@ -0,0 +1,34 @@ +from fastapi import Depends, HTTPException +import logging +from app.services.models import ModelService +from app.services.votes import VoteService +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) + +model_service = ModelService() +vote_service = VoteService() + +async def get_model_service() -> ModelService: + """Dependency to get ModelService instance""" + try: + logger.info(LogFormatter.info("Initializing model service dependency")) + await model_service.initialize() + logger.info(LogFormatter.success("Model service initialized")) + return model_service + except Exception as e: + error_msg = "Failed to initialize model service" + logger.error(LogFormatter.error(error_msg, e)) + raise HTTPException(status_code=500, detail=str(e)) + +async def get_vote_service() -> VoteService: + """Dependency to get VoteService instance""" + try: + logger.info(LogFormatter.info("Initializing vote service dependency")) + await vote_service.initialize() + logger.info(LogFormatter.success("Vote service initialized")) + return vote_service + except Exception as e: + error_msg = "Failed to initialize vote service" + logger.error(LogFormatter.error(error_msg, e)) + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/backend/app/api/endpoints/leaderboard.py b/backend/app/api/endpoints/leaderboard.py new file mode 100644 index 0000000000000000000000000000000000000000..261e7f7f4e7309eefb979e238025fbca4c7e44f8 --- /dev/null +++ b/backend/app/api/endpoints/leaderboard.py @@ -0,0 +1,49 @@ +from fastapi import APIRouter +from typing import List, Dict, Any +from app.services.leaderboard import LeaderboardService +from app.core.fastapi_cache import cached, build_cache_key +import logging +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) +router = APIRouter() +leaderboard_service = LeaderboardService() + +def leaderboard_key_builder(func, namespace: str = "leaderboard", **kwargs): + """Build cache key for leaderboard data""" + key_type = "raw" if func.__name__ == "get_leaderboard" else "formatted" + key = build_cache_key(namespace, key_type) + logger.debug(LogFormatter.info(f"Built leaderboard cache key: {key}")) + return key + +@router.get("") +@cached(expire=300, key_builder=leaderboard_key_builder) +async def get_leaderboard() -> List[Dict[str, Any]]: + """ + Get raw leaderboard data + Response will be automatically GZIP compressed if size > 500 bytes + """ + try: + 
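        # Served from the in-memory cache for up to 300s: the @cached decorator
        # above stores this response under the "leaderboard:raw" key produced
        # by leaderboard_key_builder.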
logger.info(LogFormatter.info("Fetching raw leaderboard data")) + data = await leaderboard_service.fetch_raw_data() + logger.info(LogFormatter.success(f"Retrieved {len(data)} leaderboard entries")) + return data + except Exception as e: + logger.error(LogFormatter.error("Failed to fetch raw leaderboard data", e)) + raise + +@router.get("/formatted") +@cached(expire=300, key_builder=leaderboard_key_builder) +async def get_formatted_leaderboard() -> List[Dict[str, Any]]: + """ + Get formatted leaderboard data with restructured objects + Response will be automatically GZIP compressed if size > 500 bytes + """ + try: + logger.info(LogFormatter.info("Fetching formatted leaderboard data")) + data = await leaderboard_service.get_formatted_data() + logger.info(LogFormatter.success(f"Retrieved {len(data)} formatted entries")) + return data + except Exception as e: + logger.error(LogFormatter.error("Failed to fetch formatted leaderboard data", e)) + raise \ No newline at end of file diff --git a/backend/app/api/endpoints/models.py b/backend/app/api/endpoints/models.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba9bbe412296e55cf91b67199b24aed078fbebb --- /dev/null +++ b/backend/app/api/endpoints/models.py @@ -0,0 +1,116 @@ +from fastapi import APIRouter, HTTPException, Depends, Query +from typing import Dict, Any, List +import logging +from app.services.models import ModelService +from app.api.dependencies import get_model_service +from app.core.fastapi_cache import cached +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["models"]) + +@router.get("/status") +@cached(expire=300) +async def get_models_status( + model_service: ModelService = Depends(get_model_service) +) -> Dict[str, List[Dict[str, Any]]]: + """Get all models grouped by status""" + try: + logger.info(LogFormatter.info("Fetching status for all models")) + result = await model_service.get_models() + stats = { + status: len(models) for status, models in result.items() + } + for line in LogFormatter.stats(stats, "Models by Status"): + logger.info(line) + return result + except Exception as e: + logger.error(LogFormatter.error("Failed to get models status", e)) + raise HTTPException(status_code=500, detail=str(e)) + +@router.get("/pending") +@cached(expire=60) +async def get_pending_models( + model_service: ModelService = Depends(get_model_service) +) -> List[Dict[str, Any]]: + """Get all models waiting for evaluation""" + try: + logger.info(LogFormatter.info("Fetching pending models")) + models = await model_service.get_models() + pending = models.get("pending", []) + logger.info(LogFormatter.success(f"Found {len(pending)} pending models")) + return pending + except Exception as e: + logger.error(LogFormatter.error("Failed to get pending models", e)) + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/submit") +async def submit_model( + model_data: Dict[str, Any], + model_service: ModelService = Depends(get_model_service) +) -> Dict[str, Any]: + try: + logger.info(LogFormatter.section("MODEL SUBMISSION")) + + user_id = model_data.pop('user_id', None) + if not user_id: + error_msg = "user_id is required" + logger.error(LogFormatter.error("Validation failed", error_msg)) + raise ValueError(error_msg) + + # Log submission details + submission_info = { + "Model_ID": model_data.get("model_id"), + "User": user_id, + "Base_Model": model_data.get("base_model"), + "Precision": model_data.get("precision"), + "Model_Type": 
model_data.get("model_type") + } + for line in LogFormatter.tree(submission_info, "Submission Details"): + logger.info(line) + + result = await model_service.submit_model(model_data, user_id) + logger.info(LogFormatter.success("Model submitted successfully")) + return result + + except ValueError as e: + logger.error(LogFormatter.error("Invalid submission data", e)) + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(LogFormatter.error("Submission failed", e)) + raise HTTPException(status_code=500, detail=str(e)) + +@router.get("/organization/{organization}/submissions") +async def get_organization_submissions( + organization: str, + days: int = Query(default=7, ge=1, le=30), + model_service: ModelService = Depends(get_model_service) +) -> List[Dict[str, Any]]: + """Get all submissions from an organization in the last n days""" + try: + submissions = await model_service.get_organization_submissions(organization, days) + return submissions + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.get("/{model_id}/status") +async def get_model_status( + model_id: str, + model_service: ModelService = Depends(get_model_service) +) -> Dict[str, Any]: + try: + logger.info(LogFormatter.info(f"Checking status for model: {model_id}")) + status = await model_service.get_model_status(model_id) + + if status["status"] != "not_found": + logger.info(LogFormatter.success("Status found")) + for line in LogFormatter.tree(status, "Model Status"): + logger.info(line) + else: + logger.warning(LogFormatter.warning(f"No status found for model: {model_id}")) + + return status + + except Exception as e: + logger.error(LogFormatter.error("Failed to get model status", e)) + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/backend/app/api/endpoints/votes.py b/backend/app/api/endpoints/votes.py new file mode 100644 index 0000000000000000000000000000000000000000..ec78b20e3b207751ce46ed31ea75efe487fe14dc --- /dev/null +++ b/backend/app/api/endpoints/votes.py @@ -0,0 +1,126 @@ +from fastapi import APIRouter, HTTPException, Query, Depends, Response +from typing import Dict, Any, List +from app.services.votes import VoteService +from app.core.fastapi_cache import cached, build_cache_key, invalidate_cache_key +import logging +from app.core.formatting import LogFormatter +from datetime import datetime, timezone + +logger = logging.getLogger(__name__) +router = APIRouter() +vote_service = VoteService() + +CACHE_TTL = 30 # 30 seconds cache + +def model_votes_key_builder(func, namespace: str = "model_votes", **kwargs): + """Build cache key for model votes""" + provider = kwargs.get('provider') + model = kwargs.get('model') + key = build_cache_key(namespace, provider, model) + logger.debug(LogFormatter.info(f"Built model votes cache key: {key}")) + return key + +def user_votes_key_builder(func, namespace: str = "user_votes", **kwargs): + """Build cache key for user votes""" + user_id = kwargs.get('user_id') + key = build_cache_key(namespace, user_id) + logger.debug(LogFormatter.info(f"Built user votes cache key: {key}")) + return key + +@router.post("/{model_id:path}") +async def add_vote( + response: Response, + model_id: str, + vote_type: str = Query(..., description="Type of vote (up/down)"), + user_id: str = Query(..., description="HuggingFace username"), + vote_data: Dict[str, Any] = None +) -> Dict[str, Any]: + try: + logger.info(LogFormatter.section("ADDING VOTE")) + stats = { + "Model": model_id, + 
"User": user_id, + "Type": vote_type, + "Config": vote_data or {} + } + for line in LogFormatter.tree(stats, "Vote Details"): + logger.info(line) + + await vote_service.initialize() + result = await vote_service.add_vote(model_id, user_id, vote_type, vote_data) + + # Invalidate affected caches + try: + logger.info(LogFormatter.subsection("CACHE INVALIDATION")) + provider, model = model_id.split('/', 1) + + # Build and invalidate cache keys + model_cache_key = build_cache_key("model_votes", provider, model) + user_cache_key = build_cache_key("user_votes", user_id) + + await invalidate_cache_key(model_cache_key) + await invalidate_cache_key(user_cache_key) + + cache_stats = { + "Model_Cache": model_cache_key, + "User_Cache": user_cache_key + } + for line in LogFormatter.tree(cache_stats, "Invalidated Caches"): + logger.info(line) + + except Exception as e: + logger.error(LogFormatter.error("Failed to invalidate cache", e)) + + # Add cache control headers + response.headers["Cache-Control"] = "no-cache" + + return result + except Exception as e: + logger.error(LogFormatter.error("Failed to add vote", e)) + raise HTTPException(status_code=400, detail=str(e)) + +@router.get("/model/{provider}/{model}") +@cached(expire=CACHE_TTL, key_builder=model_votes_key_builder) +async def get_model_votes( + response: Response, + provider: str, + model: str +) -> Dict[str, Any]: + """Get all votes for a specific model""" + try: + logger.info(LogFormatter.info(f"Fetching votes for model: {provider}/{model}")) + await vote_service.initialize() + model_id = f"{provider}/{model}" + result = await vote_service.get_model_votes(model_id) + + # Add cache control headers + response.headers["Cache-Control"] = f"max-age={CACHE_TTL}" + response.headers["Last-Modified"] = vote_service._last_sync.strftime("%a, %d %b %Y %H:%M:%S GMT") + + logger.info(LogFormatter.success(f"Found {result.get('total_votes', 0)} votes")) + return result + except Exception as e: + logger.error(LogFormatter.error("Failed to get model votes", e)) + raise HTTPException(status_code=400, detail=str(e)) + +@router.get("/user/{user_id}") +@cached(expire=CACHE_TTL, key_builder=user_votes_key_builder) +async def get_user_votes( + response: Response, + user_id: str +) -> List[Dict[str, Any]]: + """Get all votes from a specific user""" + try: + logger.info(LogFormatter.info(f"Fetching votes for user: {user_id}")) + await vote_service.initialize() + votes = await vote_service.get_user_votes(user_id) + + # Add cache control headers + response.headers["Cache-Control"] = f"max-age={CACHE_TTL}" + response.headers["Last-Modified"] = vote_service._last_sync.strftime("%a, %d %b %Y %H:%M:%S GMT") + + logger.info(LogFormatter.success(f"Found {len(votes)} votes")) + return votes + except Exception as e: + logger.error(LogFormatter.error("Failed to get user votes", e)) + raise HTTPException(status_code=400, detail=str(e)) \ No newline at end of file diff --git a/backend/app/api/router.py b/backend/app/api/router.py new file mode 100644 index 0000000000000000000000000000000000000000..a2c952105c729b92abc72d59ae5882ee4394c017 --- /dev/null +++ b/backend/app/api/router.py @@ -0,0 +1,9 @@ +from fastapi import APIRouter + +from app.api.endpoints import leaderboard, votes, models + +router = APIRouter() + +router.include_router(leaderboard.router, prefix="/leaderboard", tags=["leaderboard"]) +router.include_router(votes.router, prefix="/votes", tags=["votes"]) +router.include_router(models.router, prefix="/models", tags=["models"]) \ No newline at end of file diff --git 
a/backend/app/asgi.py b/backend/app/asgi.py new file mode 100644 index 0000000000000000000000000000000000000000..4972047f0588791a59cf20ef2fa280e9ca98d38a --- /dev/null +++ b/backend/app/asgi.py @@ -0,0 +1,106 @@ +""" +ASGI entry point for the Open LLM Leaderboard API. +""" +import os +import uvicorn +import logging +import logging.config +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware +import sys + +from app.api.router import router +from app.core.fastapi_cache import setup_cache +from app.core.formatting import LogFormatter +from app.config import hf_config + +# Configure logging before anything else +LOGGING_CONFIG = { + "version": 1, + "disable_existing_loggers": True, + "formatters": { + "default": { + "format": "%(name)s - %(levelname)s - %(message)s", + } + }, + "handlers": { + "default": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + } + }, + "loggers": { + "uvicorn": { + "handlers": ["default"], + "level": "WARNING", + "propagate": False, + }, + "uvicorn.error": { + "level": "WARNING", + "handlers": ["default"], + "propagate": False, + }, + "uvicorn.access": { + "handlers": ["default"], + "level": "WARNING", + "propagate": False, + }, + "app": { + "handlers": ["default"], + "level": "WARNING", + "propagate": False, + } + }, + "root": { + "handlers": ["default"], + "level": "WARNING", + } +} + +# Apply logging configuration +logging.config.dictConfig(LOGGING_CONFIG) +logger = logging.getLogger("app") + +# Create FastAPI application +app = FastAPI( + title="Open LLM Leaderboard", + version="1.0.0", + docs_url="/docs", +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Add GZIP compression +app.add_middleware(GZipMiddleware, minimum_size=500) + +# Include API router +app.include_router(router, prefix="/api") + +@app.on_event("startup") +async def startup_event(): + """Initialize services on startup""" + logger.info("\n") + logger.info(LogFormatter.section("APPLICATION STARTUP")) + + # Log HF configuration + logger.info(LogFormatter.section("HUGGING FACE CONFIGURATION")) + logger.info(LogFormatter.info(f"Organization: {hf_config.HF_ORGANIZATION}")) + logger.info(LogFormatter.info(f"Token Status: {'Present' if hf_config.HF_TOKEN else 'Missing'}")) + logger.info(LogFormatter.info(f"Using repositories:")) + logger.info(LogFormatter.info(f" - Queue: {hf_config.QUEUE_REPO}")) + logger.info(LogFormatter.info(f" - Aggregated: {hf_config.AGGREGATED_REPO}")) + logger.info(LogFormatter.info(f" - Votes: {hf_config.VOTES_REPO}")) + logger.info(LogFormatter.info(f" - Official Providers: {hf_config.OFFICIAL_PROVIDERS_REPO}")) + + # Setup cache + setup_cache() + logger.info(LogFormatter.success("FastAPI Cache initialized with in-memory backend")) \ No newline at end of file diff --git a/backend/app/config/__init__.py b/backend/app/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8cea98b9ddb1daaf3c9e8e5d2c9be1fc94657e --- /dev/null +++ b/backend/app/config/__init__.py @@ -0,0 +1,6 @@ +""" +Configuration module for the Open LLM Leaderboard backend. +All configuration values are imported from base.py to avoid circular dependencies. 
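
Typical usage (illustrative):

    from app.config import HF_TOKEN, QUEUE_REPO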
+""" + +from .base import * diff --git a/backend/app/config/base.py b/backend/app/config/base.py new file mode 100644 index 0000000000000000000000000000000000000000..89a7e65b155fe2d781bc6178fdf2ecea163554b5 --- /dev/null +++ b/backend/app/config/base.py @@ -0,0 +1,38 @@ +import os +from pathlib import Path + +# Server configuration +HOST = "0.0.0.0" +PORT = 7860 +WORKERS = 4 +RELOAD = True if os.environ.get("ENVIRONMENT") == "development" else False + +# CORS configuration +ORIGINS = ["http://localhost:3000"] if os.getenv("ENVIRONMENT") == "development" else ["*"] + +# Cache configuration +CACHE_TTL = int(os.environ.get("CACHE_TTL", 300)) # 5 minutes default + +# Rate limiting +RATE_LIMIT_PERIOD = 7 # days +RATE_LIMIT_QUOTA = 5 +HAS_HIGHER_RATE_LIMIT = [] + +# HuggingFace configuration +HF_TOKEN = os.environ.get("HF_TOKEN") +HF_ORGANIZATION = "open-llm-leaderboard" +API = { + "INFERENCE": "https://api-inference.huggingface.co/models", + "HUB": "https://huggingface.co" +} + +# Cache paths +CACHE_ROOT = Path(os.environ.get("HF_HOME", ".cache")) +DATASETS_CACHE = CACHE_ROOT / "datasets" +MODELS_CACHE = CACHE_ROOT / "models" +VOTES_CACHE = CACHE_ROOT / "votes" +EVAL_CACHE = CACHE_ROOT / "eval-queue" + +# Repository configuration +QUEUE_REPO = f"{HF_ORGANIZATION}/requests" +EVAL_REQUESTS_PATH = EVAL_CACHE / "eval_requests.jsonl" \ No newline at end of file diff --git a/backend/app/config/hf_config.py b/backend/app/config/hf_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c1c6ee93de45159127c4f861dc537bff63917b --- /dev/null +++ b/backend/app/config/hf_config.py @@ -0,0 +1,30 @@ +import os +import logging +from typing import Optional +from huggingface_hub import HfApi +from pathlib import Path +from app.core.cache import cache_config + +logger = logging.getLogger(__name__) + +# Organization or user who owns the datasets +HF_ORGANIZATION = "open-llm-leaderboard" + +# Get HF token directly from environment +HF_TOKEN = os.environ.get("HF_TOKEN") +if not HF_TOKEN: + logger.warning("HF_TOKEN not found in environment variables. Some features may be limited.") + +# Initialize HF API +API = HfApi(token=HF_TOKEN) + +# Repository configuration +QUEUE_REPO = f"{HF_ORGANIZATION}/requests" +AGGREGATED_REPO = f"{HF_ORGANIZATION}/contents" +VOTES_REPO = f"{HF_ORGANIZATION}/votes" +OFFICIAL_PROVIDERS_REPO = f"{HF_ORGANIZATION}/official-providers" + +# File paths from cache config +VOTES_PATH = cache_config.votes_file +EVAL_REQUESTS_PATH = cache_config.eval_requests_file +MODEL_CACHE_DIR = cache_config.models_cache \ No newline at end of file diff --git a/backend/app/config/logging_config.py b/backend/app/config/logging_config.py new file mode 100644 index 0000000000000000000000000000000000000000..96be6f6749cdd79defb975141d857ff216aac420 --- /dev/null +++ b/backend/app/config/logging_config.py @@ -0,0 +1,38 @@ +import logging +import sys +from tqdm import tqdm + +def get_tqdm_handler(): + """ + Creates a special handler for tqdm that doesn't interfere with other logs. + """ + class TqdmLoggingHandler(logging.Handler): + def emit(self, record): + try: + msg = self.format(record) + tqdm.write(msg) + self.flush() + except Exception: + self.handleError(record) + + return TqdmLoggingHandler() + +def setup_service_logger(service_name: str) -> logging.Logger: + """ + Configure a specific logger for a given service. 
+ """ + logger = logging.getLogger(f"app.services.{service_name}") + + # If the logger already has handlers, don't reconfigure it + if logger.handlers: + return logger + + # Add tqdm handler for this service + tqdm_handler = get_tqdm_handler() + tqdm_handler.setFormatter(logging.Formatter('%(name)s - %(levelname)s - %(message)s')) + logger.addHandler(tqdm_handler) + + # Don't propagate logs to parent loggers + logger.propagate = False + + return logger \ No newline at end of file diff --git a/backend/app/core/cache.py b/backend/app/core/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..070f81bc1bb2932109f69b9dc67332b18a1a5fbb --- /dev/null +++ b/backend/app/core/cache.py @@ -0,0 +1,109 @@ +import os +import shutil +from pathlib import Path +from datetime import timedelta +import logging +from app.core.formatting import LogFormatter +from app.config.base import ( + CACHE_ROOT, + DATASETS_CACHE, + MODELS_CACHE, + VOTES_CACHE, + EVAL_CACHE, + CACHE_TTL +) + +logger = logging.getLogger(__name__) + +class CacheConfig: + def __init__(self): + # Get cache paths from config + self.cache_root = CACHE_ROOT + self.datasets_cache = DATASETS_CACHE + self.models_cache = MODELS_CACHE + self.votes_cache = VOTES_CACHE + self.eval_cache = EVAL_CACHE + + # Specific files + self.votes_file = self.votes_cache / "votes_data.jsonl" + self.eval_requests_file = self.eval_cache / "eval_requests.jsonl" + + # Cache TTL + self.cache_ttl = timedelta(seconds=CACHE_TTL) + + self._initialize_cache_dirs() + self._setup_environment() + + def _initialize_cache_dirs(self): + """Initialize all necessary cache directories""" + try: + logger.info(LogFormatter.section("CACHE INITIALIZATION")) + + cache_dirs = { + "Root": self.cache_root, + "Datasets": self.datasets_cache, + "Models": self.models_cache, + "Votes": self.votes_cache, + "Eval": self.eval_cache + } + + for name, cache_dir in cache_dirs.items(): + cache_dir.mkdir(parents=True, exist_ok=True) + logger.info(LogFormatter.success(f"{name} cache directory: {cache_dir}")) + + except Exception as e: + logger.error(LogFormatter.error("Failed to create cache directories", e)) + raise + + def _setup_environment(self): + """Configure HuggingFace environment variables""" + logger.info(LogFormatter.subsection("ENVIRONMENT SETUP")) + + env_vars = { + "HF_HOME": str(self.cache_root), + "HF_DATASETS_CACHE": str(self.datasets_cache) + } + + for var, value in env_vars.items(): + os.environ[var] = value + logger.info(LogFormatter.info(f"Set {var}={value}")) + + + def get_cache_path(self, cache_type: str) -> Path: + """Returns the path for a specific cache type""" + cache_paths = { + "datasets": self.datasets_cache, + "models": self.models_cache, + "votes": self.votes_cache, + "eval": self.eval_cache + } + return cache_paths.get(cache_type, self.cache_root) + + def flush_cache(self, cache_type: str = None): + """Flush specified cache or all caches if no type is specified""" + try: + if cache_type: + logger.info(LogFormatter.section(f"FLUSHING {cache_type.upper()} CACHE")) + cache_dir = self.get_cache_path(cache_type) + if cache_dir.exists(): + stats = { + "Cache_Type": cache_type, + "Directory": str(cache_dir) + } + for line in LogFormatter.tree(stats, "Cache Details"): + logger.info(line) + shutil.rmtree(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + logger.info(LogFormatter.success("Cache cleared successfully")) + else: + logger.info(LogFormatter.section("FLUSHING ALL CACHES")) + for cache_type in ["datasets", "models", "votes", "eval"]: + 
self.flush_cache(cache_type) + logger.info(LogFormatter.success("All caches cleared successfully")) + + except Exception as e: + logger.error(LogFormatter.error("Failed to flush cache", e)) + raise + +# Singleton instance of cache configuration +cache_config = CacheConfig() \ No newline at end of file diff --git a/backend/app/core/fastapi_cache.py b/backend/app/core/fastapi_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..f68434b64faee4e349a6f500bc242c6df9cb4d0c --- /dev/null +++ b/backend/app/core/fastapi_cache.py @@ -0,0 +1,76 @@ +from fastapi_cache import FastAPICache +from fastapi_cache.backends.inmemory import InMemoryBackend +from fastapi_cache.decorator import cache +from datetime import timedelta +from app.config import CACHE_TTL +import logging +from app.core.formatting import LogFormatter +from typing import Optional, Any + +logger = logging.getLogger(__name__) + +class CustomInMemoryBackend(InMemoryBackend): + def __init__(self): + """Initialize the cache backend""" + super().__init__() + self.cache = {} + + async def delete(self, key: str) -> bool: + """Delete a key from the cache""" + try: + if key in self.cache: + del self.cache[key] + return True + return False + except Exception as e: + logger.error(LogFormatter.error(f"Failed to delete key {key} from cache", e)) + return False + + async def get(self, key: str) -> Any: + """Get a value from the cache""" + return self.cache.get(key) + + async def set(self, key: str, value: Any, expire: Optional[int] = None) -> None: + """Set a value in the cache""" + self.cache[key] = value + +def setup_cache(): + """Initialize FastAPI Cache with in-memory backend""" + try: + logger.info(LogFormatter.section("CACHE INITIALIZATION")) + FastAPICache.init( + backend=CustomInMemoryBackend(), + prefix="fastapi-cache" + ) + logger.info(LogFormatter.success("Cache initialized successfully")) + except Exception as e: + logger.error(LogFormatter.error("Failed to initialize cache", e)) + raise + +async def invalidate_cache_key(key: str): + """Invalidate a specific cache key""" + try: + backend = FastAPICache.get_backend() + if hasattr(backend, 'delete'): + await backend.delete(key) + logger.info(LogFormatter.success(f"Cache invalidated for key: {key}")) + else: + logger.warning(LogFormatter.warning("Cache backend does not support deletion")) + except Exception as e: + logger.error(LogFormatter.error(f"Failed to invalidate cache key: {key}", e)) + +def build_cache_key(*args) -> str: + """Build a cache key from multiple arguments""" + return ":".join(str(arg) for arg in args if arg is not None) + +def cached(expire: int = CACHE_TTL, key_builder=None): + """Decorator for caching endpoint responses + + Args: + expire (int): Cache TTL in seconds + key_builder (callable, optional): Custom key builder function + """ + return cache( + expire=expire, + key_builder=key_builder + ) \ No newline at end of file diff --git a/backend/app/core/formatting.py b/backend/app/core/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..0d5b0643019dcc0eaca92cd695c94aeda64cfc94 --- /dev/null +++ b/backend/app/core/formatting.py @@ -0,0 +1,104 @@ +import logging +from typing import Dict, Any, List, Optional + +logger = logging.getLogger(__name__) + +class LogFormatter: + """Utility class for consistent log formatting across the application""" + + @staticmethod + def section(title: str) -> str: + """Create a section header""" + return f"\n{'='*20} {title.upper()} {'='*20}" + + @staticmethod + def subsection(title: str) 
-> str: + """Create a subsection header""" + return f"\n{'─'*20} {title} {'─'*20}" + + @staticmethod + def tree(items: Dict[str, Any], title: str = None) -> List[str]: + """Create a tree view of dictionary data""" + lines = [] + if title: + lines.append(f"📊 {title}:") + + # Get the maximum length for alignment + max_key_length = max(len(str(k)) for k in items.keys()) + + # Format each item + for i, (key, value) in enumerate(items.items()): + prefix = "└──" if i == len(items) - 1 else "├──" + if isinstance(value, (int, float)): + value = f"{value:,}" # Add thousand separators + lines.append(f"{prefix} {str(key):<{max_key_length}}: {value}") + + return lines + + @staticmethod + def stats(stats: Dict[str, int], title: str = None) -> List[str]: + """Format statistics with icons""" + lines = [] + if title: + lines.append(f"📊 {title}:") + + # Get the maximum length for alignment + max_key_length = max(len(str(k)) for k in stats.keys()) + + # Format each stat with an appropriate icon + icons = { + "total": "📌", + "success": "✅", + "error": "❌", + "pending": "⏳", + "processing": "⚙️", + "finished": "✨", + "evaluating": "🔄", + "downloads": "⬇️", + "files": "📁", + "cached": "💾", + "size": "📏", + "time": "⏱️", + "rate": "🚀" + } + + # Format each item + for i, (key, value) in enumerate(stats.items()): + prefix = "└──" if i == len(stats) - 1 else "├──" + icon = icons.get(key.lower().split('_')[0], "•") + if isinstance(value, (int, float)): + value = f"{value:,}" # Add thousand separators + lines.append(f"{prefix} {icon} {str(key):<{max_key_length}}: {value}") + + return lines + + @staticmethod + def progress_bar(current: int, total: int, width: int = 20) -> str: + """Create a progress bar""" + percentage = (current * 100) // total + filled = "█" * (percentage * width // 100) + empty = "░" * (width - len(filled)) + return f"{filled}{empty} {percentage:3d}%" + + @staticmethod + def error(message: str, error: Optional[Exception] = None) -> str: + """Format error message""" + error_msg = f"\n❌ Error: {message}" + if error: + error_msg += f"\n └── Details: {str(error)}" + return error_msg + + @staticmethod + def success(message: str) -> str: + """Format success message""" + return f"✅ {message}" + + @staticmethod + def warning(message: str) -> str: + """Format warning message""" + return f"⚠️ {message}" + + @staticmethod + def info(message: str) -> str: + """Format info message""" + return f"ℹ️ {message}" \ No newline at end of file diff --git a/backend/app/main.py b/backend/app/main.py new file mode 100644 index 0000000000000000000000000000000000000000..86a00401700d1a97f9c7e3cd67509f51d7808c84 --- /dev/null +++ b/backend/app/main.py @@ -0,0 +1,18 @@ +from fastapi import FastAPI +from app.config.logging_config import setup_logging +import logging + +# Initialize logging configuration +setup_logging() +logger = logging.getLogger(__name__) + +app = FastAPI(title="Open LLM Leaderboard API") + +@app.on_event("startup") +async def startup_event(): + logger.info("Starting up the application...") + +# Import and include routers after app initialization +from app.api import models, votes +app.include_router(models.router, prefix="/api", tags=["models"]) +app.include_router(votes.router, prefix="/api", tags=["votes"]) \ No newline at end of file diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..399192f82143e7bf446fa183fa9e7779adab2bd7 --- /dev/null +++ b/backend/app/services/__init__.py @@ -0,0 +1,3 @@ +from . 
import hf_service, leaderboard, votes, models + +__all__ = ["hf_service", "leaderboard", "votes", "models"] diff --git a/backend/app/services/hf_service.py b/backend/app/services/hf_service.py new file mode 100644 index 0000000000000000000000000000000000000000..5f8ff28aa9ad352e0994848a00c5cae2a6b1f6d7 --- /dev/null +++ b/backend/app/services/hf_service.py @@ -0,0 +1,50 @@ +from typing import Optional +from huggingface_hub import HfApi +from app.config import HF_TOKEN, API +from app.core.cache import cache_config +from app.core.formatting import LogFormatter +import logging + +logger = logging.getLogger(__name__) + +class HuggingFaceService: + def __init__(self): + self.api = API + self.token = HF_TOKEN + self.cache_dir = cache_config.models_cache + + async def check_authentication(self) -> bool: + """Check if the HF token is valid""" + if not self.token: + return False + try: + logger.info(LogFormatter.info("Checking HF token validity...")) + self.api.get_token_permission() + logger.info(LogFormatter.success("HF token is valid")) + return True + except Exception as e: + logger.error(LogFormatter.error("HF token validation failed", e)) + return False + + async def get_user_info(self) -> Optional[dict]: + """Get information about the authenticated user""" + try: + logger.info(LogFormatter.info("Fetching user information...")) + info = self.api.get_token_permission() + logger.info(LogFormatter.success(f"User info retrieved for: {info.get('user', 'Unknown')}")) + return info + except Exception as e: + logger.error(LogFormatter.error("Failed to get user info", e)) + return None + + def _log_repo_operation(self, operation: str, repo: str, details: str = None): + """Helper to log repository operations""" + logger.info(LogFormatter.section(f"HF REPOSITORY OPERATION - {operation.upper()}")) + stats = { + "Operation": operation, + "Repository": repo, + } + if details: + stats["Details"] = details + for line in LogFormatter.tree(stats): + logger.info(line) \ No newline at end of file diff --git a/backend/app/services/leaderboard.py b/backend/app/services/leaderboard.py new file mode 100644 index 0000000000000000000000000000000000000000..a83172295b004945c90e0e679af974b61917ab39 --- /dev/null +++ b/backend/app/services/leaderboard.py @@ -0,0 +1,208 @@ +from app.core.cache import cache_config +from datetime import datetime +from typing import List, Dict, Any +import datasets +from fastapi import HTTPException +import logging +from app.config.base import HF_ORGANIZATION +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) + +class LeaderboardService: + def __init__(self): + pass + + async def fetch_raw_data(self) -> List[Dict[str, Any]]: + """Fetch raw leaderboard data from HuggingFace dataset""" + try: + logger.info(LogFormatter.section("FETCHING LEADERBOARD DATA")) + logger.info(LogFormatter.info(f"Loading dataset from {HF_ORGANIZATION}/contents")) + + dataset = datasets.load_dataset( + f"{HF_ORGANIZATION}/contents", + cache_dir=cache_config.get_cache_path("datasets") + )["train"] + + df = dataset.to_pandas() + data = df.to_dict('records') + + stats = { + "Total_Entries": len(data), + "Dataset_Size": f"{df.memory_usage(deep=True).sum() / 1024 / 1024:.1f}MB" + } + for line in LogFormatter.stats(stats, "Dataset Statistics"): + logger.info(line) + + return data + + except Exception as e: + logger.error(LogFormatter.error("Failed to fetch leaderboard data", e)) + raise HTTPException(status_code=500, detail=str(e)) + + async def get_formatted_data(self) -> List[Dict[str, 
Any]]: + """Get formatted leaderboard data""" + try: + logger.info(LogFormatter.section("FORMATTING LEADERBOARD DATA")) + + raw_data = await self.fetch_raw_data() + formatted_data = [] + type_counts = {} + error_count = 0 + + # Initialize progress tracking + total_items = len(raw_data) + logger.info(LogFormatter.info(f"Processing {total_items:,} entries...")) + + for i, item in enumerate(raw_data, 1): + try: + formatted_item = await self.transform_data(item) + formatted_data.append(formatted_item) + + # Count model types + model_type = formatted_item["model"]["type"] + type_counts[model_type] = type_counts.get(model_type, 0) + 1 + + except Exception as e: + error_count += 1 + logger.error(LogFormatter.error(f"Failed to format entry {i}/{total_items}", e)) + continue + + # Log progress every 10% + if i % max(1, total_items // 10) == 0: + progress = (i / total_items) * 100 + logger.info(LogFormatter.info(f"Progress: {LogFormatter.progress_bar(i, total_items)}")) + + # Log final statistics + stats = { + "Total_Processed": total_items, + "Successful": len(formatted_data), + "Failed": error_count + } + logger.info(LogFormatter.section("PROCESSING SUMMARY")) + for line in LogFormatter.stats(stats, "Processing Statistics"): + logger.info(line) + + # Log model type distribution + type_stats = {f"Type_{k}": v for k, v in type_counts.items()} + logger.info(LogFormatter.subsection("MODEL TYPE DISTRIBUTION")) + for line in LogFormatter.stats(type_stats): + logger.info(line) + + return formatted_data + + except Exception as e: + logger.error(LogFormatter.error("Failed to format leaderboard data", e)) + raise HTTPException(status_code=500, detail=str(e)) + + async def transform_data(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Transform raw data into the format expected by the frontend""" + try: + # Extract model name for logging + model_name = data.get("fullname", "Unknown") + logger.debug(LogFormatter.info(f"Transforming data for model: {model_name}")) + + # Create unique ID combining model name, precision, sha and chat template status + unique_id = f"{data.get('fullname', 'Unknown')}_{data.get('Precision', 'Unknown')}_{data.get('Model sha', 'Unknown')}_{str(data.get('Chat Template', False))}" + + evaluations = { + "ifeval": { + "name": "IFEval", + "value": data.get("IFEval Raw", 0), + "normalized_score": data.get("IFEval", 0) + }, + "bbh": { + "name": "BBH", + "value": data.get("BBH Raw", 0), + "normalized_score": data.get("BBH", 0) + }, + "math": { + "name": "MATH Level 5", + "value": data.get("MATH Lvl 5 Raw", 0), + "normalized_score": data.get("MATH Lvl 5", 0) + }, + "gpqa": { + "name": "GPQA", + "value": data.get("GPQA Raw", 0), + "normalized_score": data.get("GPQA", 0) + }, + "musr": { + "name": "MUSR", + "value": data.get("MUSR Raw", 0), + "normalized_score": data.get("MUSR", 0) + }, + "mmlu_pro": { + "name": "MMLU-PRO", + "value": data.get("MMLU-PRO Raw", 0), + "normalized_score": data.get("MMLU-PRO", 0) + } + } + + features = { + "is_not_available_on_hub": data.get("Available on the hub", False), + "is_merged": data.get("Merged", False), + "is_moe": data.get("MoE", False), + "is_flagged": data.get("Flagged", False), + "is_official_provider": data.get("Official Providers", False) + } + + metadata = { + "upload_date": data.get("Upload To Hub Date"), + "submission_date": data.get("Submission Date"), + "generation": data.get("Generation"), + "base_model": data.get("Base Model"), + "hub_license": data.get("Hub License"), + "hub_hearts": data.get("Hub ❤️"), + "params_billions": 
data.get("#Params (B)"), + "co2_cost": data.get("CO₂ cost (kg)", 0) + } + + # Clean model type by removing emojis if present + original_type = data.get("Type", "") + model_type = original_type.lower().strip() + + # Remove emojis and parentheses + if "(" in model_type: + model_type = model_type.split("(")[0].strip() + model_type = ''.join(c for c in model_type if not c in '🔶🟢🟩💬🤝🌸 ') + + # Map old model types to new ones + model_type_mapping = { + "fine-tuned": "fined-tuned-on-domain-specific-dataset", + "fine tuned": "fined-tuned-on-domain-specific-dataset", + "finetuned": "fined-tuned-on-domain-specific-dataset", + "fine_tuned": "fined-tuned-on-domain-specific-dataset", + "ft": "fined-tuned-on-domain-specific-dataset", + "finetuning": "fined-tuned-on-domain-specific-dataset", + "fine tuning": "fined-tuned-on-domain-specific-dataset", + "fine-tuning": "fined-tuned-on-domain-specific-dataset" + } + + mapped_type = model_type_mapping.get(model_type.lower().strip(), model_type) + + if mapped_type != model_type: + logger.debug(LogFormatter.info(f"Model type mapped: {original_type} -> {mapped_type}")) + + transformed_data = { + "id": unique_id, + "model": { + "name": data.get("fullname"), + "sha": data.get("Model sha"), + "precision": data.get("Precision"), + "type": mapped_type, + "weight_type": data.get("Weight type"), + "architecture": data.get("Architecture"), + "average_score": data.get("Average ⬆️"), + "has_chat_template": data.get("Chat Template", False) + }, + "evaluations": evaluations, + "features": features, + "metadata": metadata + } + + logger.debug(LogFormatter.success(f"Successfully transformed data for {model_name}")) + return transformed_data + + except Exception as e: + logger.error(LogFormatter.error(f"Failed to transform data for {data.get('fullname', 'Unknown')}", e)) + raise \ No newline at end of file diff --git a/backend/app/services/models.py b/backend/app/services/models.py new file mode 100644 index 0000000000000000000000000000000000000000..834f230e93d72177a741a2d14a1eb050afd23343 --- /dev/null +++ b/backend/app/services/models.py @@ -0,0 +1,668 @@ +from datetime import datetime, timezone, timedelta +from typing import Dict, Any, Optional, List +import json +import os +from pathlib import Path +import logging +import aiohttp +import asyncio +import time +from huggingface_hub import HfApi, CommitOperationAdd +from huggingface_hub.utils import build_hf_headers +from datasets import disable_progress_bar +import sys +import contextlib +from concurrent.futures import ThreadPoolExecutor +import tempfile + +from app.config import ( + QUEUE_REPO, + HF_TOKEN, + EVAL_REQUESTS_PATH +) +from app.config.hf_config import HF_ORGANIZATION +from app.services.hf_service import HuggingFaceService +from app.utils.model_validation import ModelValidator +from app.services.votes import VoteService +from app.core.cache import cache_config +from app.core.formatting import LogFormatter + +# Disable datasets progress bars globally +disable_progress_bar() + +logger = logging.getLogger(__name__) + +# Context manager to temporarily disable stdout and stderr +@contextlib.contextmanager +def suppress_output(): + stdout = sys.stdout + stderr = sys.stderr + devnull = open(os.devnull, 'w') + try: + sys.stdout = devnull + sys.stderr = devnull + yield + finally: + sys.stdout = stdout + sys.stderr = stderr + devnull.close() + +class ProgressTracker: + def __init__(self, total: int, desc: str = "Progress", update_frequency: int = 10): + self.total = total + self.current = 0 + self.desc = desc + 
self.start_time = time.time() + self.update_frequency = update_frequency # Percentage steps + self.last_update = -1 + + # Initial log with fancy formatting + logger.info(LogFormatter.section(desc)) + logger.info(LogFormatter.info(f"Starting processing of {total:,} items...")) + sys.stdout.flush() + + def update(self, n: int = 1): + self.current += n + current_percentage = (self.current * 100) // self.total + + # Only update on frequency steps (e.g., 0%, 10%, 20%, etc.) + if current_percentage >= self.last_update + self.update_frequency or current_percentage == 100: + elapsed = time.time() - self.start_time + rate = self.current / elapsed if elapsed > 0 else 0 + remaining = (self.total - self.current) / rate if rate > 0 else 0 + + # Create progress stats + stats = { + "Progress": LogFormatter.progress_bar(self.current, self.total), + "Items": f"{self.current:,}/{self.total:,}", + "Time": f"⏱️ {elapsed:.1f}s elapsed, {remaining:.1f}s remaining", + "Rate": f"🚀 {rate:.1f} items/s" + } + + # Log progress using tree format + for line in LogFormatter.tree(stats): + logger.info(line) + sys.stdout.flush() + + self.last_update = (current_percentage // self.update_frequency) * self.update_frequency + + def close(self): + elapsed = time.time() - self.start_time + rate = self.total / elapsed if elapsed > 0 else 0 + + # Final summary with fancy formatting + logger.info(LogFormatter.section("COMPLETED")) + stats = { + "Total": f"{self.total:,} items", + "Time": f"{elapsed:.1f}s", + "Rate": f"{rate:.1f} items/s" + } + for line in LogFormatter.stats(stats): + logger.info(line) + logger.info("="*50) + sys.stdout.flush() + +class ModelService(HuggingFaceService): + _instance: Optional['ModelService'] = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + logger.info(LogFormatter.info("Creating new ModelService instance")) + cls._instance = super(ModelService, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, '_init_done'): + logger.info(LogFormatter.section("MODEL SERVICE INITIALIZATION")) + super().__init__() + self.validator = ModelValidator() + self.vote_service = VoteService() + self.eval_requests_path = cache_config.eval_requests_file + logger.info(LogFormatter.info(f"Using eval requests path: {self.eval_requests_path}")) + + self.eval_requests_path.parent.mkdir(parents=True, exist_ok=True) + self.hf_api = HfApi(token=HF_TOKEN) + self.cached_models = None + self.last_cache_update = 0 + self.cache_ttl = cache_config.cache_ttl.total_seconds() + self._init_done = True + logger.info(LogFormatter.success("Initialization complete")) + + async def _download_and_process_file(self, file: str, session: aiohttp.ClientSession, progress: ProgressTracker) -> Optional[Dict]: + """Download and process a file asynchronously""" + try: + # Build file URL + url = f"https://huggingface.co/datasets/{QUEUE_REPO}/resolve/main/{file}" + headers = build_hf_headers(token=self.token) + + # Download file + async with session.get(url, headers=headers) as response: + if response.status != 200: + logger.error(LogFormatter.error(f"Failed to download {file}", f"HTTP {response.status}")) + progress.update() + return None + + try: + # First read content as text + text_content = await response.text() + # Then parse JSON + content = json.loads(text_content) + except json.JSONDecodeError as e: + logger.error(LogFormatter.error(f"Failed to decode JSON from {file}", e)) + progress.update() + return None + + # Get status and determine target status + status = content.get("status", 
"PENDING").upper() + target_status = None + status_map = { + "PENDING": ["PENDING"], + "EVALUATING": ["RUNNING"], + "FINISHED": ["FINISHED"] + } + + for target, source_statuses in status_map.items(): + if status in source_statuses: + target_status = target + break + + if not target_status: + progress.update() + return None + + # Calculate wait time + try: + submit_time = datetime.fromisoformat(content["submitted_time"].replace("Z", "+00:00")) + if submit_time.tzinfo is None: + submit_time = submit_time.replace(tzinfo=timezone.utc) + current_time = datetime.now(timezone.utc) + wait_time = current_time - submit_time + + model_info = { + "name": content["model"], + "submitter": content.get("sender", "Unknown"), + "revision": content["revision"], + "wait_time": f"{wait_time.total_seconds():.1f}s", + "submission_time": content["submitted_time"], + "status": target_status, + "precision": content.get("precision", "Unknown") + } + + progress.update() + return model_info + + except (ValueError, TypeError) as e: + logger.error(LogFormatter.error(f"Failed to process {file}", e)) + progress.update() + return None + + except Exception as e: + logger.error(LogFormatter.error(f"Failed to load {file}", e)) + progress.update() + return None + + async def _refresh_models_cache(self): + """Refresh the models cache""" + try: + logger.info(LogFormatter.section("CACHE REFRESH")) + self._log_repo_operation("read", f"{HF_ORGANIZATION}/requests", "Refreshing models cache") + + # Initialize models dictionary + models = { + "finished": [], + "evaluating": [], + "pending": [] + } + + try: + logger.info(LogFormatter.subsection("DATASET LOADING")) + logger.info(LogFormatter.info("Loading dataset...")) + + # Download entire dataset snapshot + with suppress_output(): + local_dir = self.hf_api.snapshot_download( + repo_id=QUEUE_REPO, + repo_type="dataset", + token=self.token + ) + + # List JSON files in local directory + local_path = Path(local_dir) + json_files = list(local_path.glob("**/*.json")) + total_files = len(json_files) + + # Log repository stats + stats = { + "Total_Files": total_files, + "Local_Dir": str(local_path), + } + for line in LogFormatter.stats(stats, "Repository Statistics"): + logger.info(line) + + if not json_files: + raise Exception("No JSON files found in repository") + + # Initialize progress tracker + progress = ProgressTracker(total_files, "PROCESSING FILES") + + # Process local files + model_submissions = {} # Dict to track latest submission for each (model_id, revision, precision) + for file_path in json_files: + try: + with open(file_path, 'r') as f: + content = json.load(f) + + # Get status and determine target status + status = content.get("status", "PENDING").upper() + target_status = None + status_map = { + "PENDING": ["PENDING"], + "EVALUATING": ["RUNNING"], + "FINISHED": ["FINISHED"] + } + + for target, source_statuses in status_map.items(): + if status in source_statuses: + target_status = target + break + + if not target_status: + progress.update() + continue + + # Calculate wait time + try: + submit_time = datetime.fromisoformat(content["submitted_time"].replace("Z", "+00:00")) + if submit_time.tzinfo is None: + submit_time = submit_time.replace(tzinfo=timezone.utc) + current_time = datetime.now(timezone.utc) + wait_time = current_time - submit_time + + model_info = { + "name": content["model"], + "submitter": content.get("sender", "Unknown"), + "revision": content["revision"], + "wait_time": f"{wait_time.total_seconds():.1f}s", + "submission_time": content["submitted_time"], + 
"status": target_status, + "precision": content.get("precision", "Unknown") + } + + # Use (model_id, revision, precision) as key to track latest submission + key = (content["model"], content["revision"], content.get("precision", "Unknown")) + if key not in model_submissions or submit_time > datetime.fromisoformat(model_submissions[key]["submission_time"].replace("Z", "+00:00")): + model_submissions[key] = model_info + + except (ValueError, TypeError) as e: + logger.error(LogFormatter.error(f"Failed to process {file_path.name}", e)) + + except Exception as e: + logger.error(LogFormatter.error(f"Failed to load {file_path.name}", e)) + finally: + progress.update() + + # Populate models dict with deduplicated submissions + for model_info in model_submissions.values(): + models[model_info["status"].lower()].append(model_info) + + progress.close() + + # Final summary with fancy formatting + logger.info(LogFormatter.section("CACHE SUMMARY")) + stats = { + "Finished": len(models["finished"]), + "Evaluating": len(models["evaluating"]), + "Pending": len(models["pending"]) + } + for line in LogFormatter.stats(stats, "Models by Status"): + logger.info(line) + logger.info("="*50) + + except Exception as e: + logger.error(LogFormatter.error("Error processing files", e)) + raise + + # Update cache + self.cached_models = models + self.last_cache_update = time.time() + logger.info(LogFormatter.success("Cache updated successfully")) + + return models + + except Exception as e: + logger.error(LogFormatter.error("Cache refresh failed", e)) + raise + + async def initialize(self): + """Initialize the model service""" + if self._initialized: + logger.info(LogFormatter.info("Service already initialized, using cached data")) + return + + try: + logger.info(LogFormatter.section("MODEL SERVICE INITIALIZATION")) + + # Check if cache already exists + cache_path = cache_config.get_cache_path("datasets") + if not cache_path.exists() or not any(cache_path.iterdir()): + logger.info(LogFormatter.info("No existing cache found, initializing datasets cache...")) + cache_config.flush_cache("datasets") + else: + logger.info(LogFormatter.info("Using existing datasets cache")) + + # Ensure eval requests directory exists + self.eval_requests_path.parent.mkdir(parents=True, exist_ok=True) + logger.info(LogFormatter.info(f"Eval requests directory: {self.eval_requests_path}")) + + # List existing files + if self.eval_requests_path.exists(): + files = list(self.eval_requests_path.glob("**/*.json")) + stats = { + "Total_Files": len(files), + "Directory": str(self.eval_requests_path) + } + for line in LogFormatter.stats(stats, "Eval Requests"): + logger.info(line) + + # Load initial cache + await self._refresh_models_cache() + + self._initialized = True + logger.info(LogFormatter.success("Model service initialization complete")) + + except Exception as e: + logger.error(LogFormatter.error("Initialization failed", e)) + raise + + async def get_models(self) -> Dict[str, List[Dict[str, Any]]]: + """Get all models with their status""" + if not self._initialized: + logger.info(LogFormatter.info("Service not initialized, initializing now...")) + await self.initialize() + + current_time = time.time() + cache_age = current_time - self.last_cache_update + + # Check if cache needs refresh + if not self.cached_models: + logger.info(LogFormatter.info("No cached data available, refreshing cache...")) + return await self._refresh_models_cache() + elif cache_age > self.cache_ttl: + logger.info(LogFormatter.info(f"Cache expired ({cache_age:.1f}s old, 
TTL: {self.cache_ttl}s)")) + return await self._refresh_models_cache() + else: + logger.info(LogFormatter.info(f"Using cached data ({cache_age:.1f}s old)")) + return self.cached_models + + async def submit_model( + self, + model_data: Dict[str, Any], + user_id: str + ) -> Dict[str, Any]: + logger.info(LogFormatter.section("MODEL SUBMISSION")) + self._log_repo_operation("write", f"{HF_ORGANIZATION}/requests", f"Submitting model {model_data['model_id']} by {user_id}") + stats = { + "Model": model_data["model_id"], + "User": user_id, + "Revision": model_data["revision"], + "Precision": model_data["precision"], + "Type": model_data["model_type"] + } + for line in LogFormatter.tree(stats, "Submission Details"): + logger.info(line) + + # Validate required fields + required_fields = [ + "model_id", "base_model", "revision", "precision", + "weight_type", "model_type", "use_chat_template" + ] + for field in required_fields: + if field not in model_data: + raise ValueError(f"Missing required field: {field}") + + # Get model info and validate it exists on HuggingFace + try: + logger.info(LogFormatter.subsection("MODEL VALIDATION")) + + # Get the model info to check if it exists + model_info = self.hf_api.model_info( + model_data["model_id"], + revision=model_data["revision"], + token=self.token + ) + + if not model_info: + raise Exception(f"Model {model_data['model_id']} not found on HuggingFace Hub") + + logger.info(LogFormatter.success("Model exists on HuggingFace Hub")) + + except Exception as e: + logger.error(LogFormatter.error("Model validation failed", e)) + raise + + # Update model revision with commit sha + model_data["revision"] = model_info.sha + + # Check if model already exists in the system + try: + logger.info(LogFormatter.subsection("CHECKING EXISTING SUBMISSIONS")) + existing_models = await self.get_models() + + # Call the official provider status check + is_valid, error_message = await self.validator.check_official_provider_status( + model_data["model_id"], + existing_models + ) + if not is_valid: + raise ValueError(error_message) + + # Check in all statuses (pending, evaluating, finished) + for status, models in existing_models.items(): + for model in models: + if model["name"] == model_data["model_id"] and model["revision"] == model_data["revision"]: + error_msg = f"Model {model_data['model_id']} revision {model_data['revision']} is already in the system with status: {status}" + logger.error(LogFormatter.error("Submission rejected", error_msg)) + raise ValueError(error_msg) + + logger.info(LogFormatter.success("No existing submission found")) + except ValueError: + raise + except Exception as e: + logger.error(LogFormatter.error("Failed to check existing submissions", e)) + raise + + # Check that model on hub and valid + valid, error, model_config = await self.validator.is_model_on_hub( + model_data["model_id"], + model_data["revision"], + test_tokenizer=True + ) + if not valid: + logger.error(LogFormatter.error("Model on hub validation failed", error)) + raise Exception(error) + logger.info(LogFormatter.success("Model on hub validation passed")) + + # Validate model card + valid, error, model_card = await self.validator.check_model_card( + model_data["model_id"] + ) + if not valid: + logger.error(LogFormatter.error("Model card validation failed", error)) + raise Exception(error) + logger.info(LogFormatter.success("Model card validation passed")) + + # Check size limits + model_size, error = await self.validator.get_model_size( + model_info, + model_data["precision"], + 
model_data["base_model"], + revision=model_data["revision"] + ) + if model_size is None: + logger.error(LogFormatter.error("Model size validation failed", error)) + raise Exception(error) + logger.info(LogFormatter.success(f"Model size validation passed: {model_size:.1f}B")) + + # Size limits based on precision + if model_data["precision"] in ["float16", "bfloat16"] and model_size > 100: + error_msg = f"Model too large for {model_data['precision']} (limit: 100B)" + logger.error(LogFormatter.error("Size limit exceeded", error_msg)) + raise Exception(error_msg) + + # Chat template validation if requested + if model_data["use_chat_template"]: + valid, error = await self.validator.check_chat_template( + model_data["model_id"], + model_data["revision"] + ) + if not valid: + logger.error(LogFormatter.error("Chat template validation failed", error)) + raise Exception(error) + logger.info(LogFormatter.success("Chat template validation passed")) + + + architectures = model_info.config.get("architectures", "") + if architectures: + architectures = ";".join(architectures) + + # Create eval entry + eval_entry = { + "model": model_data["model_id"], + "base_model": model_data["base_model"], + "revision": model_info.sha, + "precision": model_data["precision"], + "params": model_size, + "architectures": architectures, + "weight_type": model_data["weight_type"], + "status": "PENDING", + "submitted_time": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + "model_type": model_data["model_type"], + "job_id": -1, + "job_start_time": None, + "use_chat_template": model_data["use_chat_template"], + "sender": user_id + } + + logger.info(LogFormatter.subsection("EVALUATION ENTRY")) + for line in LogFormatter.tree(eval_entry): + logger.info(line) + + # Upload to HF dataset + try: + logger.info(LogFormatter.subsection("UPLOADING TO HUGGINGFACE")) + logger.info(LogFormatter.info(f"Uploading to {HF_ORGANIZATION}/requests...")) + + # Construct the path in the dataset + org_or_user = model_data["model_id"].split("/")[0] if "/" in model_data["model_id"] else "" + model_path = model_data["model_id"].split("/")[-1] + relative_path = f"{org_or_user}/{model_path}_eval_request_False_{model_data['precision']}_{model_data['weight_type']}.json" + + # Create a temporary file with the request + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as temp_file: + json.dump(eval_entry, temp_file, indent=2) + temp_file.flush() + temp_path = temp_file.name + + # Upload file directly + self.hf_api.upload_file( + path_or_fileobj=temp_path, + path_in_repo=relative_path, + repo_id=f"{HF_ORGANIZATION}/requests", + repo_type="dataset", + commit_message=f"Add {model_data['model_id']} to eval queue", + token=self.token + ) + + # Clean up temp file + os.unlink(temp_path) + + logger.info(LogFormatter.success("Upload successful")) + + except Exception as e: + logger.error(LogFormatter.error("Upload failed", e)) + raise + + # Add automatic vote + try: + logger.info(LogFormatter.subsection("AUTOMATIC VOTE")) + logger.info(LogFormatter.info(f"Adding upvote for {model_data['model_id']} by {user_id}")) + await self.vote_service.add_vote( + model_data["model_id"], + user_id, + "up", + { + "precision": model_data["precision"], + "revision": model_data["revision"] + } + ) + logger.info(LogFormatter.success("Vote recorded successfully")) + except Exception as e: + logger.error(LogFormatter.error("Failed to record vote", e)) + # Don't raise here as the main submission was successful + + return { + "status": "success", + 
"message": "The model was submitted successfully, and the vote has been recorded" + } + + async def get_model_status(self, model_id: str) -> Dict[str, Any]: + """Get evaluation status of a model""" + logger.info(LogFormatter.info(f"Checking status for model: {model_id}")) + eval_path = self.eval_requests_path + + for user_folder in eval_path.iterdir(): + if user_folder.is_dir(): + for file in user_folder.glob("*.json"): + with open(file, "r") as f: + data = json.load(f) + if data["model"] == model_id: + status = { + "status": data["status"], + "submitted_time": data["submitted_time"], + "job_id": data.get("job_id", -1) + } + logger.info(LogFormatter.success("Status found")) + for line in LogFormatter.tree(status, "Model Status"): + logger.info(line) + return status + + logger.warning(LogFormatter.warning(f"No status found for model: {model_id}")) + return {"status": "not_found"} + + async def get_organization_submissions(self, organization: str, days: int = 7) -> List[Dict[str, Any]]: + """Get all submissions from a user in the last n days""" + try: + # Get all models + all_models = await self.get_models() + current_time = datetime.now(timezone.utc) + cutoff_time = current_time - timedelta(days=days) + + # Filter models by submitter and submission time + user_submissions = [] + for status, models in all_models.items(): + for model in models: + # Check if model was submitted by the user + if model["submitter"] == organization: + # Parse submission time + submit_time = datetime.fromisoformat( + model["submission_time"].replace("Z", "+00:00") + ) + # Check if within time window + if submit_time > cutoff_time: + user_submissions.append({ + "name": model["name"], + "status": status, + "submission_time": model["submission_time"], + "precision": model["precision"] + }) + + return sorted( + user_submissions, + key=lambda x: x["submission_time"], + reverse=True + ) + + except Exception as e: + logger.error(LogFormatter.error(f"Failed to get submissions for {organization}", e)) + raise \ No newline at end of file diff --git a/backend/app/services/rate_limiter.py b/backend/app/services/rate_limiter.py new file mode 100644 index 0000000000000000000000000000000000000000..988c68e2f7d7f3847d6691c70f55975648aa3c8f --- /dev/null +++ b/backend/app/services/rate_limiter.py @@ -0,0 +1,72 @@ +""" +import logging +from datetime import datetime, timedelta, timezone +from typing import Tuple, Dict, List + +logger = logging.getLogger(__name__) + +class RateLimiter: + def __init__(self, period_days: int = 7, quota: int = 5): + self.period_days = period_days + self.quota = quota + self.submission_history: Dict[str, List[datetime]] = {} + self.higher_quota_users = set() # Users with higher quotas + self.unlimited_users = set() # Users with no quota limits + + def add_unlimited_user(self, user_id: str): + """Add a user to the unlimited users list""" + self.unlimited_users.add(user_id) + + def add_higher_quota_user(self, user_id: str): + """Add a user to the higher quota users list""" + self.higher_quota_users.add(user_id) + + def record_submission(self, user_id: str): + """Record a new submission for a user""" + current_time = datetime.now(timezone.utc) + if user_id not in self.submission_history: + self.submission_history[user_id] = [] + self.submission_history[user_id].append(current_time) + + def clean_old_submissions(self, user_id: str): + """Remove submissions older than the period""" + if user_id not in self.submission_history: + return + + current_time = datetime.now(timezone.utc) + cutoff_time = current_time - 
timedelta(days=self.period_days) + + self.submission_history[user_id] = [ + time for time in self.submission_history[user_id] + if time > cutoff_time + ] + + async def check_rate_limit(self, user_id: str) -> Tuple[bool, str]: + """Check if a user has exceeded their rate limit + + Returns: + Tuple[bool, str]: (is_allowed, error_message) + """ + # Unlimited users bypass all checks + if user_id in self.unlimited_users: + return True, "" + + # Clean old submissions + self.clean_old_submissions(user_id) + + # Get current submission count + submission_count = len(self.submission_history.get(user_id, [])) + + # Calculate user's quota + user_quota = self.quota * 2 if user_id in self.higher_quota_users else self.quota + + # Check if user has exceeded their quota + if submission_count >= user_quota: + error_msg = ( + f"User '{user_id}' has reached the limit of {user_quota} submissions " + f"in the last {self.period_days} days. Please wait before submitting again." + ) + return False, error_msg + + return True, "" +""" \ No newline at end of file diff --git a/backend/app/services/votes.py b/backend/app/services/votes.py new file mode 100644 index 0000000000000000000000000000000000000000..929062ac34752f70f279af30d55b39276440cda4 --- /dev/null +++ b/backend/app/services/votes.py @@ -0,0 +1,441 @@ +from datetime import datetime, timezone +from typing import Dict, Any, List, Set, Tuple, Optional +import json +import logging +import asyncio +from pathlib import Path +import aiohttp +from huggingface_hub import HfApi +import tempfile +import os + +from app.services.hf_service import HuggingFaceService +from app.config import HF_TOKEN +from app.config.hf_config import HF_ORGANIZATION +from app.core.cache import cache_config +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) + +class VoteService(HuggingFaceService): + _instance: Optional['VoteService'] = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super(VoteService, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, '_init_done'): + super().__init__() + self.votes_file = cache_config.votes_file + self.votes_to_upload: List[Dict[str, Any]] = [] + self.vote_check_set: Set[Tuple[str, str, str, str]] = set() + self._votes_by_model: Dict[str, List[Dict[str, Any]]] = {} + self._votes_by_user: Dict[str, List[Dict[str, Any]]] = {} + self._last_sync = None + self._sync_interval = 300 # 5 minutes + self._total_votes = 0 + self._last_vote_timestamp = None + self._max_retries = 3 + self._retry_delay = 1 # seconds + self.hf_api = HfApi(token=HF_TOKEN) + self._init_done = True + + async def initialize(self): + """Initialize the vote service""" + if self._initialized: + await self._check_for_new_votes() + return + + try: + logger.info(LogFormatter.section("VOTE SERVICE INITIALIZATION")) + + # Ensure votes directory exists + self.votes_file.parent.mkdir(parents=True, exist_ok=True) + + # Load remote votes + remote_votes = await self._fetch_remote_votes() + if remote_votes: + logger.info(LogFormatter.info(f"Loaded {len(remote_votes)} votes from hub")) + + # Save to local file + with open(self.votes_file, 'w') as f: + for vote in remote_votes: + json.dump(vote, f) + f.write('\n') + + # Load into memory + await self._load_existing_votes() + else: + logger.warning(LogFormatter.warning("No votes found on hub")) + + self._initialized = True + self._last_sync = datetime.now(timezone.utc) + + # Final summary + stats = { + "Total_Votes": self._total_votes, + 
"Last_Sync": self._last_sync.strftime("%Y-%m-%d %H:%M:%S UTC") + } + logger.info(LogFormatter.section("INITIALIZATION COMPLETE")) + for line in LogFormatter.stats(stats): + logger.info(line) + + except Exception as e: + logger.error(LogFormatter.error("Initialization failed", e)) + raise + + async def _fetch_remote_votes(self) -> List[Dict[str, Any]]: + """Fetch votes from HF hub""" + url = f"https://huggingface.co/datasets/{HF_ORGANIZATION}/votes/raw/main/votes_data.jsonl" + headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} + + try: + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as response: + if response.status == 200: + votes = [] + async for line in response.content: + if line.strip(): + try: + vote = json.loads(line.decode()) + votes.append(vote) + except json.JSONDecodeError: + continue + return votes + else: + logger.error(f"Failed to get remote votes: HTTP {response.status}") + return [] + except Exception as e: + logger.error(f"Error fetching remote votes: {str(e)}") + return [] + + async def _check_for_new_votes(self): + """Check for new votes on the hub and sync if needed""" + try: + remote_votes = await self._fetch_remote_votes() + if len(remote_votes) != self._total_votes: + logger.info(f"Vote count changed: Local ({self._total_votes}) ≠ Remote ({len(remote_votes)})") + # Save to local file + with open(self.votes_file, 'w') as f: + for vote in remote_votes: + json.dump(vote, f) + f.write('\n') + + # Reload into memory + await self._load_existing_votes() + else: + logger.info("Votes are in sync") + + except Exception as e: + logger.error(f"Error checking for new votes: {str(e)}") + + async def _sync_with_hub(self): + """Sync votes with HuggingFace hub""" + try: + logger.info(LogFormatter.section("VOTE SYNC")) + + # Get current remote votes + remote_votes = await self._fetch_remote_votes() + logger.info(LogFormatter.info(f"Loaded {len(remote_votes)} votes from hub")) + + # If we have pending votes to upload + if self.votes_to_upload: + logger.info(LogFormatter.info(f"Adding {len(self.votes_to_upload)} pending votes...")) + + # Add new votes to remote votes + remote_votes.extend(self.votes_to_upload) + + # Create temporary file with all votes + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as temp_file: + for vote in remote_votes: + json.dump(vote, temp_file) + temp_file.write('\n') + temp_path = temp_file.name + + try: + # Upload JSONL file directly + self.hf_api.upload_file( + path_or_fileobj=temp_path, + path_in_repo="votes_data.jsonl", + repo_id=f"{HF_ORGANIZATION}/votes", + repo_type="dataset", + commit_message=f"Update votes: +{len(self.votes_to_upload)} new votes", + token=self.token + ) + + # Clear pending votes only if upload succeeded + self.votes_to_upload.clear() + logger.info(LogFormatter.success("Pending votes uploaded successfully")) + + except Exception as e: + logger.error(LogFormatter.error("Failed to upload votes to hub", e)) + raise + finally: + # Clean up temp file + os.unlink(temp_path) + + # Update local state + with open(self.votes_file, 'w') as f: + for vote in remote_votes: + json.dump(vote, f) + f.write('\n') + + # Reload votes in memory + await self._load_existing_votes() + logger.info(LogFormatter.success("Sync completed successfully")) + + self._last_sync = datetime.now(timezone.utc) + + except Exception as e: + logger.error(LogFormatter.error("Sync failed", e)) + raise + + async def _load_existing_votes(self): + """Load existing votes from file""" + 
if not self.votes_file.exists(): + logger.warning(LogFormatter.warning("No votes file found")) + return + + try: + logger.info(LogFormatter.section("LOADING VOTES")) + + # Clear existing data structures + self.vote_check_set.clear() + self._votes_by_model.clear() + self._votes_by_user.clear() + + vote_count = 0 + latest_timestamp = None + + with open(self.votes_file, "r") as f: + for line in f: + try: + vote = json.loads(line.strip()) + vote_count += 1 + + # Track latest timestamp + try: + vote_timestamp = datetime.fromisoformat(vote["timestamp"].replace("Z", "+00:00")) + if not latest_timestamp or vote_timestamp > latest_timestamp: + latest_timestamp = vote_timestamp + vote["timestamp"] = vote_timestamp.strftime("%Y-%m-%dT%H:%M:%SZ") + except (KeyError, ValueError) as e: + logger.warning(LogFormatter.warning(f"Invalid timestamp in vote: {str(e)}")) + continue + + if vote_count % 1000 == 0: + logger.info(LogFormatter.info(f"Processed {vote_count:,} votes...")) + + self._add_vote_to_memory(vote) + + except json.JSONDecodeError as e: + logger.error(LogFormatter.error("Vote parsing failed", e)) + continue + except Exception as e: + logger.error(LogFormatter.error("Vote processing failed", e)) + continue + + self._total_votes = vote_count + self._last_vote_timestamp = latest_timestamp + + # Final summary + stats = { + "Total_Votes": vote_count, + "Latest_Vote": latest_timestamp.strftime("%Y-%m-%d %H:%M:%S UTC") if latest_timestamp else "None", + "Unique_Models": len(self._votes_by_model), + "Unique_Users": len(self._votes_by_user) + } + + logger.info(LogFormatter.section("VOTE SUMMARY")) + for line in LogFormatter.stats(stats): + logger.info(line) + + except Exception as e: + logger.error(LogFormatter.error("Failed to load votes", e)) + raise + + def _add_vote_to_memory(self, vote: Dict[str, Any]): + """Add vote to memory structures""" + try: + # Create a unique identifier tuple that includes precision + check_tuple = ( + vote["model"], + vote.get("revision", "main"), + vote["username"], + vote.get("precision", "unknown") + ) + + # Skip if we already have this vote + if check_tuple in self.vote_check_set: + return + + self.vote_check_set.add(check_tuple) + + # Update model votes + if vote["model"] not in self._votes_by_model: + self._votes_by_model[vote["model"]] = [] + self._votes_by_model[vote["model"]].append(vote) + + # Update user votes + if vote["username"] not in self._votes_by_user: + self._votes_by_user[vote["username"]] = [] + self._votes_by_user[vote["username"]].append(vote) + + except KeyError as e: + logger.error(LogFormatter.error("Malformed vote data, missing key", str(e))) + except Exception as e: + logger.error(LogFormatter.error("Error adding vote to memory", str(e))) + + async def get_user_votes(self, user_id: str) -> List[Dict[str, Any]]: + """Get all votes from a specific user""" + logger.info(LogFormatter.info(f"Fetching votes for user: {user_id}")) + + # Check if we need to refresh votes + if (datetime.now(timezone.utc) - self._last_sync).total_seconds() > self._sync_interval: + logger.info(LogFormatter.info("Cache expired, refreshing votes...")) + await self._check_for_new_votes() + + votes = self._votes_by_user.get(user_id, []) + logger.info(LogFormatter.success(f"Found {len(votes):,} votes")) + return votes + + async def get_model_votes(self, model_id: str) -> Dict[str, Any]: + """Get all votes for a specific model""" + logger.info(LogFormatter.info(f"Fetching votes for model: {model_id}")) + + # Check if we need to refresh votes + if (datetime.now(timezone.utc) - 
self._last_sync).total_seconds() > self._sync_interval: + logger.info(LogFormatter.info("Cache expired, refreshing votes...")) + await self._check_for_new_votes() + + votes = self._votes_by_model.get(model_id, []) + + # Group votes by revision and precision + votes_by_config = {} + for vote in votes: + revision = vote.get("revision", "main") + precision = vote.get("precision", "unknown") + config_key = f"{revision}_{precision}" + if config_key not in votes_by_config: + votes_by_config[config_key] = { + "revision": revision, + "precision": precision, + "count": 0 + } + votes_by_config[config_key]["count"] += 1 + + stats = { + "Total_Votes": len(votes), + **{f"Config_{k}": v["count"] for k, v in votes_by_config.items()} + } + + logger.info(LogFormatter.section("VOTE STATISTICS")) + for line in LogFormatter.stats(stats): + logger.info(line) + + return { + "total_votes": len(votes), + "votes_by_config": votes_by_config, + "votes": votes + } + + async def _get_model_revision(self, model_id: str) -> str: + """Get current revision of a model with retries""" + logger.info(f"Getting revision for model: {model_id}") + for attempt in range(self._max_retries): + try: + model_info = await asyncio.to_thread(self.hf_api.model_info, model_id) + logger.info(f"Successfully got revision {model_info.sha} for model {model_id}") + return model_info.sha + except Exception as e: + logger.error(f"Error getting model revision for {model_id} (attempt {attempt + 1}): {str(e)}") + if attempt < self._max_retries - 1: + retry_delay = self._retry_delay * (attempt + 1) + logger.info(f"Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + else: + logger.warning(f"Using 'main' as fallback revision for {model_id} after {self._max_retries} failed attempts") + return "main" + + async def add_vote(self, model_id: str, user_id: str, vote_type: str, vote_data: Dict[str, Any] = None) -> Dict[str, Any]: + """Add a vote for a model""" + try: + self._log_repo_operation("add", f"{HF_ORGANIZATION}/votes", f"Adding {vote_type} vote for {model_id} by {user_id}") + logger.info(LogFormatter.section("NEW VOTE")) + stats = { + "Model": model_id, + "User": user_id, + "Type": vote_type, + "Config": vote_data or {} + } + for line in LogFormatter.tree(stats, "Vote Details"): + logger.info(line) + + # Use provided configuration or fallback to model info + precision = None + revision = None + + if vote_data: + precision = vote_data.get("precision") + revision = vote_data.get("revision") + + # If any info is missing, try to get it from model info + if not all([precision, revision]): + try: + model_info = await asyncio.to_thread(self.hf_api.model_info, model_id) + model_card_data = model_info.cardData if hasattr(model_info, 'cardData') else {} + + if not precision: + precision = model_card_data.get("precision", "unknown") + if not revision: + revision = model_info.sha + except Exception as e: + logger.warning(LogFormatter.warning(f"Failed to get model info: {str(e)}. 
Using default values.")) + precision = precision or "unknown" + revision = revision or "main" + + # Check if vote already exists with this configuration + check_tuple = (model_id, revision, user_id, precision) + + if check_tuple in self.vote_check_set: + raise ValueError(f"Vote already recorded for this model configuration (precision: {precision}, revision: {revision[:7] if revision else 'unknown'})") + + vote = { + "model": model_id, + "revision": revision, + "username": user_id, + "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + "vote_type": vote_type, + "precision": precision + } + + # Update local storage + with open(self.votes_file, "a") as f: + f.write(json.dumps(vote) + "\n") + + self._add_vote_to_memory(vote) + self.votes_to_upload.append(vote) + + stats = { + "Status": "Success", + "Queue_Size": len(self.votes_to_upload), + "Model_Config": { + "Precision": precision, + "Revision": revision[:7] if revision else "unknown" + } + } + for line in LogFormatter.stats(stats): + logger.info(line) + + # Force immediate sync + logger.info(LogFormatter.info("Forcing immediate sync with hub")) + await self._sync_with_hub() + + return {"status": "success", "message": "Vote added successfully"} + + except Exception as e: + logger.error(LogFormatter.error("Failed to add vote", e)) + raise \ No newline at end of file diff --git a/backend/app/utils/__init__.py b/backend/app/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69a93acb760828c13400cfcd19da2822dfd83e5e --- /dev/null +++ b/backend/app/utils/__init__.py @@ -0,0 +1,3 @@ +from . import model_validation + +__all__ = ["model_validation"] diff --git a/backend/app/utils/logging.py b/backend/app/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..3a720f0c226faa0d0390a0c561be75db0194ca7f --- /dev/null +++ b/backend/app/utils/logging.py @@ -0,0 +1,3 @@ +from app.core.formatting import LogFormatter + +__all__ = ['LogFormatter'] \ No newline at end of file diff --git a/backend/app/utils/model_validation.py b/backend/app/utils/model_validation.py new file mode 100644 index 0000000000000000000000000000000000000000..7cec5e092d07a0759deecbe5a4afdda4471bbf19 --- /dev/null +++ b/backend/app/utils/model_validation.py @@ -0,0 +1,266 @@ +import json +import logging +import asyncio +from typing import Tuple, Optional, Dict, Any +from datasets import load_dataset +from huggingface_hub import HfApi, ModelCard, hf_hub_download +from huggingface_hub import hf_api +from transformers import AutoConfig, AutoTokenizer +from app.config.base import HF_TOKEN +from app.config.hf_config import OFFICIAL_PROVIDERS_REPO +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) + +class ModelValidator: + def __init__(self): + self.token = HF_TOKEN + self.api = HfApi(token=self.token) + self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} + + async def check_model_card(self, model_id: str) -> Tuple[bool, str, Optional[Dict[str, Any]]]: + """Check if model has a valid model card""" + try: + logger.info(LogFormatter.info(f"Checking model card for {model_id}")) + + # Get model card content using ModelCard.load + try: + model_card = await asyncio.to_thread( + ModelCard.load, + model_id + ) + logger.info(LogFormatter.success("Model card found")) + except Exception as e: + error_msg = "Please add a model card to your model to explain how you trained/fine-tuned it." 
+ logger.error(LogFormatter.error(error_msg, e)) + return False, error_msg, None + + # Check license in model card data + if model_card.data.license is None and not ("license_name" in model_card.data and "license_link" in model_card.data): + error_msg = "License not found. Please add a license to your model card using the `license` metadata or a `license_name`/`license_link` pair." + logger.warning(LogFormatter.warning(error_msg)) + return False, error_msg, None + + # Enforce card content length + if len(model_card.text) < 200: + error_msg = "Please add a description to your model card, it is too short." + logger.warning(LogFormatter.warning(error_msg)) + return False, error_msg, None + + logger.info(LogFormatter.success("Model card validation passed")) + return True, "", model_card + + except Exception as e: + error_msg = "Failed to validate model card" + logger.error(LogFormatter.error(error_msg, e)) + return False, str(e), None + + async def get_safetensors_metadata(self, model_id: str, is_adapter: bool = False, revision: str = "main") -> Optional[Dict]: + """Get metadata from a safetensors file""" + try: + if is_adapter: + metadata = await asyncio.to_thread( + hf_api.parse_safetensors_file_metadata, + model_id, + "adapter_model.safetensors", + token=self.token, + revision=revision, + ) + else: + metadata = await asyncio.to_thread( + hf_api.get_safetensors_metadata, + repo_id=model_id, + token=self.token, + revision=revision, + ) + return metadata + + except Exception as e: + logger.error(f"Failed to get safetensors metadata: {str(e)}") + return None + + async def get_model_size( + self, + model_info: Any, + precision: str, + base_model: str, + revision: str + ) -> Tuple[Optional[float], Optional[str]]: + """Get model size in billions of parameters""" + try: + logger.info(LogFormatter.info(f"Checking model size for {model_info.modelId}")) + + # Check if model is adapter + is_adapter = any(s.rfilename == "adapter_config.json" for s in model_info.siblings if hasattr(s, 'rfilename')) + + # Try to get size from safetensors first + model_size = None + + if is_adapter and base_model: + # For adapters, we need both adapter and base model sizes + adapter_meta = await self.get_safetensors_metadata(model_info.id, is_adapter=True, revision=revision) + base_meta = await self.get_safetensors_metadata(base_model, revision="main") + + if adapter_meta and base_meta: + adapter_size = sum(adapter_meta.parameter_count.values()) + base_size = sum(base_meta.parameter_count.values()) + model_size = adapter_size + base_size + else: + # For regular models, just get the model size + meta = await self.get_safetensors_metadata(model_info.id, revision=revision) + if meta: + model_size = sum(meta.parameter_count.values()) # total params + + if model_size is None: + # If model size could not be determined, return an error + return None, "Model size could not be determined" + + # Adjust size for GPTQ models + size_factor = 8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1 + model_size = model_size / 1e9 # Convert to billions, assuming float16 + model_size = round(size_factor * model_size, 3) + + logger.info(LogFormatter.success(f"Model size: {model_size}B parameters")) + return model_size, None + + except Exception as e: + logger.error(LogFormatter.error(f"Error while determining model size: {e}")) + return None, str(e) + + + async def check_chat_template( + self, + model_id: str, + revision: str + ) -> Tuple[bool, Optional[str]]: + """Check if model has a valid chat template""" + try: + 
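+            # The check below downloads tokenizer_config.json and requires a
+            # top-level "chat_template" key. Illustrative (not a real template)
+            # example of a passing config:
+            # {"tokenizer_class": "LlamaTokenizer", "chat_template": "{% for message in messages %}...{% endfor %}"}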
logger.info(LogFormatter.info(f"Checking chat template for {model_id}")) + + try: + config_file = await asyncio.to_thread( + hf_hub_download, + repo_id=model_id, + filename="tokenizer_config.json", + revision=revision, + repo_type="model" + ) + + with open(config_file, 'r') as f: + tokenizer_config = json.load(f) + + if 'chat_template' not in tokenizer_config: + error_msg = f"The model {model_id} doesn't have a chat_template in its tokenizer_config.json. Please add a chat_template before submitting or submit without it." + logger.error(LogFormatter.error(error_msg)) + return False, error_msg + + logger.info(LogFormatter.success("Valid chat template found")) + return True, None + + except Exception as e: + error_msg = f"Error checking chat_template: {str(e)}" + logger.error(LogFormatter.error(error_msg)) + return False, error_msg + + except Exception as e: + error_msg = "Failed to check chat template" + logger.error(LogFormatter.error(error_msg, e)) + return False, str(e) + + async def is_model_on_hub( + self, + model_name: str, + revision: str, + test_tokenizer: bool = False, + trust_remote_code: bool = False + ) -> Tuple[bool, Optional[str], Optional[Any]]: + """Check if model exists and is properly configured on the Hub""" + try: + config = await asyncio.to_thread( + AutoConfig.from_pretrained, + model_name, + revision=revision, + trust_remote_code=trust_remote_code, + token=self.token, + force_download=True + ) + + if test_tokenizer: + try: + await asyncio.to_thread( + AutoTokenizer.from_pretrained, + model_name, + revision=revision, + trust_remote_code=trust_remote_code, + token=self.token + ) + except ValueError as e: + return False, f"The tokenizer is not available in an official Transformers release: {e}", None + except Exception: + return False, "The tokenizer cannot be loaded. Ensure the tokenizer class is part of a stable Transformers release and correctly configured.", None + + return True, None, config + + except ValueError: + return False, "The model requires `trust_remote_code=True` to launch, and for safety reasons, we don't accept such models automatically.", None + except Exception as e: + if "You are trying to access a gated repo." in str(e): + return True, "The model is gated and requires special access permissions.", None + return False, f"The model was not found or is misconfigured on the Hub. Error: {e.args[0]}", None + + async def check_official_provider_status( + self, + model_id: str, + existing_models: Dict[str, list] + ) -> Tuple[bool, Optional[str]]: + """ + Check if model is from official provider and has finished submission. 
+ + Args: + model_id: The model identifier (org/model-name) + existing_models: Dictionary of models by status from get_models() + + Returns: + Tuple[bool, Optional[str]]: (is_valid, error_message) + """ + try: + logger.info(LogFormatter.info(f"Checking official provider status for {model_id}")) + + # Get model organization + model_org = model_id.split('/')[0] if '/' in model_id else None + + if not model_org: + return True, None + + # Load official providers dataset + dataset = load_dataset(OFFICIAL_PROVIDERS_REPO) + official_providers = dataset["train"][0]["CURATED_SET"] + + # Check if model org is in official providers + is_official = model_org in official_providers + + if is_official: + logger.info(LogFormatter.info(f"Model organization '{model_org}' is an official provider")) + + # Check for finished submissions + if "finished" in existing_models: + for model in existing_models["finished"]: + if model["name"] == model_id: + error_msg = ( + f"Model {model_id} is an official provider model " + f"with a completed evaluation. " + f"To re-evaluate, please open a discussion." + ) + logger.error(LogFormatter.error("Validation failed", error_msg)) + return False, error_msg + + logger.info(LogFormatter.success("No finished submission found for this official provider model")) + else: + logger.info(LogFormatter.info(f"Model organization '{model_org}' is not an official provider")) + + return True, None + + except Exception as e: + error_msg = f"Failed to check official provider status: {str(e)}" + logger.error(LogFormatter.error(error_msg)) + return False, error_msg diff --git a/backend/pyproject.toml b/backend/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..2fb3fbb2bcdfbfaf6d8f27d1f2e0bec595cbc389 --- /dev/null +++ b/backend/pyproject.toml @@ -0,0 +1,31 @@ +[tool.poetry] +name = "llm-leaderboard-backend" +version = "0.1.0" +description = "Backend for the Open LLM Leaderboard" +authors = ["Your Name "] + +[tool.poetry.dependencies] +python = "^3.12" +fastapi = "^0.115.6" +uvicorn = {extras = ["standard"], version = "^0.34.0"} +numpy = "^2.2.0" +pandas = "^2.2.3" +datasets = "^3.3.2" +pyarrow = "^18.1.0" +python-multipart = "^0.0.20" +huggingface-hub = "0.29.1" +transformers = "4.49.0" +safetensors = "^0.5.3" +aiofiles = "^24.1.0" +fastapi-cache2 = "^0.2.1" +python-dotenv = "^1.0.1" + +[tool.poetry.group.dev.dependencies] +pytest = "^8.3.4" +black = "^24.10.0" +isort = "^5.13.2" +flake8 = "^6.1.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/backend/utils/analyze_prod_datasets.py b/backend/utils/analyze_prod_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..346d4f7dc543c8ea3e08ab7124d3008e2a5530b5 --- /dev/null +++ b/backend/utils/analyze_prod_datasets.py @@ -0,0 +1,170 @@ +import os +import json +import logging +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, List +from huggingface_hub import HfApi +from dotenv import load_dotenv +from app.config.hf_config import HF_ORGANIZATION + +# Get the backend directory path +BACKEND_DIR = Path(__file__).parent.parent +ROOT_DIR = BACKEND_DIR.parent + +# Load environment variables from .env file in root directory +load_dotenv(ROOT_DIR / ".env") + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(message)s' +) +logger = logging.getLogger(__name__) + +# Initialize Hugging Face API +HF_TOKEN = os.getenv("HF_TOKEN") +if not HF_TOKEN: + raise 
ValueError("HF_TOKEN not found in environment variables") +api = HfApi(token=HF_TOKEN) + +def analyze_dataset(repo_id: str) -> Dict[str, Any]: + """Analyze a dataset and return statistics""" + try: + # Get dataset info + dataset_info = api.dataset_info(repo_id=repo_id) + + # Get file list + files = api.list_repo_files(repo_id, repo_type="dataset") + + # Get last commit info + commits = api.list_repo_commits(repo_id, repo_type="dataset") + last_commit = next(commits, None) + + # Count lines in jsonl files + total_entries = 0 + for file in files: + if file.endswith('.jsonl'): + try: + # Download file content + content = api.hf_hub_download( + repo_id=repo_id, + filename=file, + repo_type="dataset" + ) + + # Count lines + with open(content, 'r') as f: + for _ in f: + total_entries += 1 + + except Exception as e: + logger.error(f"Error processing file {file}: {str(e)}") + continue + + # Special handling for requests dataset + if repo_id == f"{HF_ORGANIZATION}/requests": + pending_count = 0 + completed_count = 0 + + try: + content = api.hf_hub_download( + repo_id=repo_id, + filename="eval_requests.jsonl", + repo_type="dataset" + ) + + with open(content, 'r') as f: + for line in f: + try: + entry = json.loads(line) + if entry.get("status") == "pending": + pending_count += 1 + elif entry.get("status") == "completed": + completed_count += 1 + except json.JSONDecodeError: + continue + + except Exception as e: + logger.error(f"Error analyzing requests: {str(e)}") + + # Build response + response = { + "id": repo_id, + "last_modified": last_commit.created_at if last_commit else None, + "total_entries": total_entries, + "file_count": len(files), + "size_bytes": dataset_info.size_in_bytes, + "downloads": dataset_info.downloads, + } + + # Add request-specific info if applicable + if repo_id == f"{HF_ORGANIZATION}/requests": + response.update({ + "pending_requests": pending_count, + "completed_requests": completed_count + }) + + return response + + except Exception as e: + logger.error(f"Error analyzing dataset {repo_id}: {str(e)}") + return { + "id": repo_id, + "error": str(e) + } + +def main(): + """Main function to analyze all datasets""" + try: + # List of datasets to analyze + datasets = [ + { + "id": f"{HF_ORGANIZATION}/contents", + "description": "Aggregated results" + }, + { + "id": f"{HF_ORGANIZATION}/requests", + "description": "Evaluation requests" + }, + { + "id": f"{HF_ORGANIZATION}/votes", + "description": "User votes" + }, + { + "id": f"{HF_ORGANIZATION}/official-providers", + "description": "Highlighted models" + } + ] + + # Analyze each dataset + results = [] + for dataset in datasets: + logger.info(f"\nAnalyzing {dataset['description']} ({dataset['id']})...") + result = analyze_dataset(dataset['id']) + results.append(result) + + if 'error' in result: + logger.error(f"❌ Error: {result['error']}") + else: + logger.info(f"✓ {result['total_entries']} entries") + logger.info(f"✓ {result['file_count']} files") + logger.info(f"✓ {result['size_bytes'] / 1024:.1f} KB") + logger.info(f"✓ {result['downloads']} downloads") + + if 'pending_requests' in result: + logger.info(f"✓ {result['pending_requests']} pending requests") + logger.info(f"✓ {result['completed_requests']} completed requests") + + if result['last_modified']: + last_modified = datetime.fromisoformat(result['last_modified'].replace('Z', '+00:00')) + logger.info(f"✓ Last modified: {last_modified.strftime('%Y-%m-%d %H:%M:%S')}") + + return results + + except Exception as e: + logger.error(f"Global error: {str(e)}") + return [] + +if 
__name__ == "__main__": + main() \ No newline at end of file diff --git a/backend/utils/analyze_prod_models.py b/backend/utils/analyze_prod_models.py new file mode 100644 index 0000000000000000000000000000000000000000..90a066dbb76e98ae1e13f1e969f527b695146cce --- /dev/null +++ b/backend/utils/analyze_prod_models.py @@ -0,0 +1,106 @@ +import os +import json +import logging +from datetime import datetime +from pathlib import Path +from huggingface_hub import HfApi +from dotenv import load_dotenv +from app.config.hf_config import HF_ORGANIZATION + +# Get the backend directory path +BACKEND_DIR = Path(__file__).parent.parent +ROOT_DIR = BACKEND_DIR.parent + +# Load environment variables from .env file in root directory +load_dotenv(ROOT_DIR / ".env") + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(message)s' +) +logger = logging.getLogger(__name__) + +# Initialize Hugging Face API +HF_TOKEN = os.getenv("HF_TOKEN") +if not HF_TOKEN: + raise ValueError("HF_TOKEN not found in environment variables") +api = HfApi(token=HF_TOKEN) + +def count_evaluated_models(): + """Count the number of evaluated models""" + try: + # Get dataset info + dataset_info = api.dataset_info(repo_id=f"{HF_ORGANIZATION}/contents", repo_type="dataset") + + # Get file list + files = api.list_repo_files(f"{HF_ORGANIZATION}/contents", repo_type="dataset") + + # Get last commit info + commits = api.list_repo_commits(f"{HF_ORGANIZATION}/contents", repo_type="dataset") + last_commit = next(commits, None) + + # Count lines in jsonl files + total_entries = 0 + for file in files: + if file.endswith('.jsonl'): + try: + # Download file content + content = api.hf_hub_download( + repo_id=f"{HF_ORGANIZATION}/contents", + filename=file, + repo_type="dataset" + ) + + # Count lines + with open(content, 'r') as f: + for _ in f: + total_entries += 1 + + except Exception as e: + logger.error(f"Error processing file {file}: {str(e)}") + continue + + # Build response + response = { + "total_models": total_entries, + "last_modified": last_commit.created_at if last_commit else None, + "file_count": len(files), + "size_bytes": dataset_info.size_in_bytes, + "downloads": dataset_info.downloads + } + + return response + + except Exception as e: + logger.error(f"Error counting evaluated models: {str(e)}") + return { + "error": str(e) + } + +def main(): + """Main function to count evaluated models""" + try: + logger.info("\nAnalyzing evaluated models...") + result = count_evaluated_models() + + if 'error' in result: + logger.error(f"❌ Error: {result['error']}") + else: + logger.info(f"✓ {result['total_models']} models evaluated") + logger.info(f"✓ {result['file_count']} files") + logger.info(f"✓ {result['size_bytes'] / 1024:.1f} KB") + logger.info(f"✓ {result['downloads']} downloads") + + if result['last_modified']: + last_modified = datetime.fromisoformat(result['last_modified'].replace('Z', '+00:00')) + logger.info(f"✓ Last modified: {last_modified.strftime('%Y-%m-%d %H:%M:%S')}") + + return result + + except Exception as e: + logger.error(f"Global error: {str(e)}") + return {"error": str(e)} + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/backend/utils/fix_wrong_model_size.py b/backend/utils/fix_wrong_model_size.py new file mode 100644 index 0000000000000000000000000000000000000000..3b464f873c6465d077e4da735935b14884ace254 --- /dev/null +++ b/backend/utils/fix_wrong_model_size.py @@ -0,0 +1,110 @@ +import os +import json +import pytz +import logging +import asyncio +from datetime import 
datetime +from pathlib import Path +import huggingface_hub +from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError +from dotenv import load_dotenv +from git import Repo +from datetime import datetime +from tqdm.auto import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm + +from app.config.hf_config import HF_TOKEN, API + +from app.utils.model_validation import ModelValidator + +huggingface_hub.logging.set_verbosity_error() +huggingface_hub.utils.disable_progress_bars() + +logging.basicConfig( + level=logging.ERROR, + format='%(message)s' +) +logger = logging.getLogger(__name__) +load_dotenv() + +validator = ModelValidator() + +def get_changed_files(repo_path, start_date, end_date): + repo = Repo(repo_path) + start = datetime.strptime(start_date, '%Y-%m-%d') + end = datetime.strptime(end_date, '%Y-%m-%d') + + changed_files = set() + pbar = tqdm(repo.iter_commits(), desc=f"Reading commits from {end_date} to {start_date}") + for commit in pbar: + commit_date = datetime.fromtimestamp(commit.committed_date) + pbar.set_postfix_str(f"Commit date: {commit_date}") + if start <= commit_date <= end: + changed_files.update(item.a_path for item in commit.diff(commit.parents[0])) + + if commit_date < start: + break + + return changed_files + + +def read_json(repo_path, file): + with open(f"{repo_path}/{file}") as file: + return json.load(file) + + +def write_json(repo_path, file, content): + with open(f"{repo_path}/{file}", "w") as file: + json.dump(content, file, indent=2) + + +def main(): + requests_path = "/requests" + start_date = "2024-12-09" + end_date = "2025-01-07" + + changed_files = get_changed_files(requests_path, start_date, end_date) + + for file in tqdm(changed_files): + try: + request_data = read_json(requests_path, file) + except FileNotFoundError as e: + tqdm.write(f"File {file} not found") + continue + + try: + model_info = API.model_info( + repo_id=request_data["model"], + revision=request_data["revision"], + token=HF_TOKEN + ) + except (RepositoryNotFoundError, RevisionNotFoundError) as e: + tqdm.write(f"Model info for {request_data["model"]} not found") + continue + + with logging_redirect_tqdm(): + new_model_size, error = asyncio.run(validator.get_model_size( + model_info=model_info, + precision=request_data["precision"], + base_model=request_data["base_model"], + revision=request_data["revision"] + )) + + if error: + tqdm.write(f"Error getting model size info for {request_data["model"]}, {error}") + continue + + old_model_size = request_data["params"] + if old_model_size != new_model_size: + if new_model_size > 100: + tqdm.write(f"Model: {request_data["model"]}, size is more 100B: {new_model_size}") + + tqdm.write(f"Model: {request_data["model"]}, old size: {request_data["params"]} new size: {new_model_size}") + tqdm.write(f"Updating request file {file}") + + request_data["params"] = new_model_size + write_json(requests_path, file, content=request_data) + + +if __name__ == "__main__": + main() diff --git a/backend/utils/last_activity.py b/backend/utils/last_activity.py new file mode 100644 index 0000000000000000000000000000000000000000..9f403ef0d223f79c9f7d2633ecbee5c3044ed5ae --- /dev/null +++ b/backend/utils/last_activity.py @@ -0,0 +1,164 @@ +import os +import json +import logging +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, List, Tuple +from huggingface_hub import HfApi +from dotenv import load_dotenv + +# Get the backend directory path +BACKEND_DIR = Path(__file__).parent.parent +ROOT_DIR = 
BACKEND_DIR.parent + +# Load environment variables from .env file in root directory +load_dotenv(ROOT_DIR / ".env") + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(message)s' +) +logger = logging.getLogger(__name__) + +# Initialize Hugging Face API +HF_TOKEN = os.getenv("HF_TOKEN") +if not HF_TOKEN: + raise ValueError("HF_TOKEN not found in environment variables") +api = HfApi(token=HF_TOKEN) + +# Default organization +HF_ORGANIZATION = os.getenv('HF_ORGANIZATION', 'open-llm-leaderboard') + +def get_last_votes(limit: int = 5) -> List[Dict]: + """Get the last votes from the votes dataset""" + try: + logger.info("\nFetching last votes...") + + # Download and read votes file + logger.info("Downloading votes file...") + votes_file = api.hf_hub_download( + repo_id=f"{HF_ORGANIZATION}/votes", + filename="votes_data.jsonl", + repo_type="dataset" + ) + + logger.info("Reading votes file...") + votes = [] + with open(votes_file, 'r') as f: + for line in f: + try: + vote = json.loads(line) + votes.append(vote) + except json.JSONDecodeError: + continue + + # Sort by timestamp and get last n votes + logger.info("Sorting votes...") + votes.sort(key=lambda x: x.get('timestamp', ''), reverse=True) + last_votes = votes[:limit] + + logger.info(f"✓ Found {len(last_votes)} recent votes") + return last_votes + + except Exception as e: + logger.error(f"Error reading votes: {str(e)}") + return [] + +def get_last_models(limit: int = 5) -> List[Dict]: + """Get the last models from the requests dataset using commit history""" + try: + logger.info("\nFetching last model submissions...") + + # Get commit history + logger.info("Getting commit history...") + commits = list(api.list_repo_commits( + repo_id=f"{HF_ORGANIZATION}/requests", + repo_type="dataset" + )) + logger.info(f"Found {len(commits)} commits") + + # Track processed files to avoid duplicates + processed_files = set() + models = [] + + # Process commits until we have enough models + for i, commit in enumerate(commits): + logger.info(f"Processing commit {i+1}/{len(commits)} ({commit.created_at})") + + # Look at added/modified files in this commit + files_to_process = [f for f in (commit.added + commit.modified) if f.endswith('.json')] + if files_to_process: + logger.info(f"Found {len(files_to_process)} JSON files in commit") + + for file in files_to_process: + if file in processed_files: + continue + + processed_files.add(file) + logger.info(f"Downloading {file}...") + + try: + # Download and read the file + content = api.hf_hub_download( + repo_id=f"{HF_ORGANIZATION}/requests", + filename=file, + repo_type="dataset" + ) + + with open(content, 'r') as f: + model_data = json.load(f) + models.append(model_data) + logger.info(f"✓ Added model {model_data.get('model', 'Unknown')}") + + if len(models) >= limit: + logger.info("Reached desired number of models") + break + + except Exception as e: + logger.error(f"Error reading file {file}: {str(e)}") + continue + + if len(models) >= limit: + break + + logger.info(f"✓ Found {len(models)} recent model submissions") + return models + + except Exception as e: + logger.error(f"Error reading models: {str(e)}") + return [] + +def main(): + """Display last activities from the leaderboard""" + try: + # Get last votes + logger.info("\n=== Last Votes ===") + last_votes = get_last_votes() + if last_votes: + for vote in last_votes: + logger.info(f"\nModel: {vote.get('model')}") + logger.info(f"User: {vote.get('username')}") + logger.info(f"Timestamp: {vote.get('timestamp')}") + else: + 
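+            # Note: get_last_votes() returns [] on any read error, so an empty
+            # result can mean either "no votes yet" or "votes file unavailable"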
logger.info("No votes found") + + # Get last model submissions + logger.info("\n=== Last Model Submissions ===") + last_models = get_last_models() + if last_models: + for model in last_models: + logger.info(f"\nModel: {model.get('model')}") + logger.info(f"Submitter: {model.get('sender', 'Unknown')}") + logger.info(f"Status: {model.get('status', 'Unknown')}") + logger.info(f"Submission Time: {model.get('submitted_time', 'Unknown')}") + logger.info(f"Precision: {model.get('precision', 'Unknown')}") + logger.info(f"Weight Type: {model.get('weight_type', 'Unknown')}") + else: + logger.info("No models found") + + except Exception as e: + logger.error(f"Global error: {str(e)}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/backend/utils/sync_datasets_locally.py b/backend/utils/sync_datasets_locally.py new file mode 100644 index 0000000000000000000000000000000000000000..c06326899e79b974d9d621e6531ea8f6b9563f9c --- /dev/null +++ b/backend/utils/sync_datasets_locally.py @@ -0,0 +1,130 @@ +import os +import shutil +import tempfile +import logging +from pathlib import Path +from huggingface_hub import HfApi, snapshot_download, upload_folder, create_repo +from dotenv import load_dotenv + +# Configure source and destination usernames +SOURCE_USERNAME = "open-llm-leaderboard" +DESTINATION_USERNAME = "tfrere" + +# Get the backend directory path +BACKEND_DIR = Path(__file__).parent.parent +ROOT_DIR = BACKEND_DIR.parent + +# Load environment variables from .env file in root directory +load_dotenv(ROOT_DIR / ".env") + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(message)s' +) +logger = logging.getLogger(__name__) + +# List of dataset names to sync +DATASET_NAMES = [ + "votes", + "results", + "requests", + "contents", + "official-providers", +] + +# Build list of datasets with their source and destination paths +DATASETS = [ + (name, f"{SOURCE_USERNAME}/{name}", f"{DESTINATION_USERNAME}/{name}") + for name in DATASET_NAMES +] + +# Initialize Hugging Face API +api = HfApi() + +def ensure_repo_exists(repo_id, token): + """Ensure the repository exists, create it if it doesn't""" + try: + api.repo_info(repo_id=repo_id, repo_type="dataset") + logger.info(f"✓ Repository {repo_id} already exists") + except Exception: + logger.info(f"Creating repository {repo_id}...") + create_repo( + repo_id=repo_id, + repo_type="dataset", + token=token, + private=True + ) + logger.info(f"✓ Repository {repo_id} created") + +def process_dataset(dataset_info, token): + """Process a single dataset""" + name, source_dataset, destination_dataset = dataset_info + try: + logger.info(f"\n📥 Processing dataset: {name}") + + # Ensure destination repository exists + ensure_repo_exists(destination_dataset, token) + + # Create a temporary directory for this dataset + with tempfile.TemporaryDirectory() as temp_dir: + try: + # List files in source dataset + logger.info(f"Listing files in {source_dataset}...") + files = api.list_repo_files(source_dataset, repo_type="dataset") + logger.info(f"Detected structure: {len(files)} files") + + # Download dataset + logger.info(f"Downloading from {source_dataset}...") + local_dir = snapshot_download( + repo_id=source_dataset, + repo_type="dataset", + local_dir=temp_dir, + token=token + ) + logger.info(f"✓ Download complete") + + # Upload to destination while preserving structure + logger.info(f"📤 Uploading to {destination_dataset}...") + api.upload_folder( + folder_path=local_dir, + repo_id=destination_dataset, + repo_type="dataset", + 
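+                    # Same HF_TOKEN is reused for the write; upload_folder keeps the snapshot's folder structure intact in the destination repo.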
token=token + ) + logger.info(f"✅ {name} copied successfully!") + return True + + except Exception as e: + logger.error(f"❌ Error processing {name}: {str(e)}") + return False + + except Exception as e: + logger.error(f"❌ Error for {name}: {str(e)}") + return False + +def copy_datasets(): + try: + logger.info("🔑 Checking authentication...") + # Get token from .env file + token = os.getenv("HF_TOKEN") + if not token: + raise ValueError("HF_TOKEN not found in .env file") + + # Process datasets sequentially + results = [] + for dataset_info in DATASETS: + success = process_dataset(dataset_info, token) + results.append((dataset_info[0], success)) + + # Print final summary + logger.info("\n📊 Final summary:") + for dataset, success in results: + status = "✅ Success" if success else "❌ Failure" + logger.info(f"{dataset}: {status}") + + except Exception as e: + logger.error(f"❌ Global error: {str(e)}") + +if __name__ == "__main__": + copy_datasets() \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..f2893d3eef1fc6b5af918d0debd30e9c45de5733 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,33 @@ +services: + backend: + build: + context: ./backend + dockerfile: Dockerfile.dev + args: + - HF_TOKEN=${HF_TOKEN} + ports: + - "${BACKEND_PORT:-8000}:8000" + volumes: + - ./backend:/app + environment: + - ENVIRONMENT=${ENVIRONMENT:-development} + - HF_TOKEN=${HF_TOKEN} + - HF_HOME=${HF_HOME:-/.cache} + command: uvicorn app.asgi:app --host 0.0.0.0 --port 8000 --reload + + frontend: + build: + context: ./frontend + dockerfile: Dockerfile.dev + ports: + - "${FRONTEND_PORT:-7860}:7860" + volumes: + - ./frontend:/app + - /app/node_modules + environment: + - NODE_ENV=${ENVIRONMENT:-development} + - CHOKIDAR_USEPOLLING=true + - PORT=${FRONTEND_PORT:-7860} + command: npm start + stdin_open: true + tty: true \ No newline at end of file diff --git a/frontend/Dockerfile.dev b/frontend/Dockerfile.dev new file mode 100644 index 0000000000000000000000000000000000000000..259f7c9d8746db26bee8ee531d96cbe0d619321e --- /dev/null +++ b/frontend/Dockerfile.dev @@ -0,0 +1,15 @@ +FROM node:18 + +WORKDIR /app + +# Install required global dependencies +RUN npm install -g react-scripts + +# Copy package.json and package-lock.json +COPY package*.json ./ + +# Install project dependencies +RUN npm install + +# Volume will be mounted here, no need for COPY +CMD ["npm", "start"] \ No newline at end of file diff --git a/frontend/README.md b/frontend/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7ef4ff265f3c870efce128f47bdda8d266689a88 --- /dev/null +++ b/frontend/README.md @@ -0,0 +1,80 @@ +# Frontend - Open LLM Leaderboard 🏆 + +React interface for exploring and comparing open-source language models. + +## 🏗 Architecture + +```mermaid +flowchart TD + Client(["User Browser"]) --> Components["React Components"] + + subgraph Frontend + Components --> Context["Context Layer
• LeaderboardContext
• Global State"] + + API["API Layer
• /api/leaderboard/formatted
• TanStack Query"] --> |Data Feed| Context + + Context --> Hooks["Hooks Layer
• Data Processing
• Filtering
• Caching"] + + Hooks --> Features["Features
• Table Management
• Search & Filters
• Display Options"] + Features --> Cache["Cache Layer
• LocalStorage
• URL State"] + end + + API --> Backend["Backend Server"] + + style Backend fill:#f96,stroke:#333,stroke-width:2px +``` + +## ✨ Core Features + +- 🔍 **Search & Filters**: Real-time filtering, regex search, advanced filters +- 📊 **Data Visualization**: Interactive table, customizable columns, sorting +- 🔄 **State Management**: URL sync, client-side caching (5min TTL) +- 📱 **Responsive Design**: Mobile-friendly, dark/light themes + +## 🛠 Tech Stack + +- React 18 + Material-UI +- TanStack Query & Table +- React Router v6 + +## 📁 Project Structure + +``` +src/ +├── pages/ +│ └── LeaderboardPage/ +│ ├── components/ # UI Components +│ ├── context/ # Global State +│ └── hooks/ # Data Processing +├── components/ # Shared Components +└── utils/ # Helper Functions +``` + +## 🚀 Development + +```bash +# Install dependencies +npm install + +# Start development server +npm start + +# Production build +npm run build +``` + +## 🔧 Environment Variables + +```env +# API Configuration +REACT_APP_API_URL=http://localhost:8000 +REACT_APP_CACHE_DURATION=300000 # 5 minutes +``` + +## 🔄 Data Flow + +1. API fetches leaderboard data from backend +2. Context stores and manages global state +3. Hooks handle data processing and filtering +4. Components render based on processed data +5. Cache maintains user preferences and URL state diff --git a/frontend/package.json b/frontend/package.json new file mode 100644 index 0000000000000000000000000000000000000000..93de14fd49415a97be66fa06310e2a1249b85ad6 --- /dev/null +++ b/frontend/package.json @@ -0,0 +1,55 @@ +{ + "name": "open-llm-leaderboard", + "version": "0.1.0", + "private": true, + "dependencies": { + "@emotion/react": "^11.13.3", + "@emotion/styled": "^11.13.0", + "@huggingface/hub": "^0.14.0", + "@mui/icons-material": "^6.1.7", + "@mui/lab": "^6.0.0-beta.16", + "@mui/material": "^6.1.6", + "@mui/x-data-grid": "^7.22.2", + "@tanstack/react-query": "^5.62.2", + "@tanstack/react-table": "^8.20.5", + "@tanstack/react-virtual": "^3.10.9", + "@testing-library/jest-dom": "^5.17.0", + "@testing-library/react": "^13.4.0", + "@testing-library/user-event": "^13.5.0", + "compression": "^1.7.4", + "cors": "^2.8.5", + "express": "^4.18.2", + "react": "^18.3.1", + "react-dom": "^18.3.1", + "react-router-dom": "^6.28.0", + "react-scripts": "5.0.1", + "serve-static": "^1.15.0", + "web-vitals": "^2.1.4" + }, + "scripts": { + "start": "react-scripts start", + "build": "react-scripts build", + "test": "react-scripts test", + "eject": "react-scripts eject", + "serve": "node server.js" + }, + "eslintConfig": { + "extends": [ + "react-app", + "react-app/jest" + ] + }, + "browserslist": { + "production": [ + ">0.2%", + "not dead", + "not op_mini all" + ], + "development": [ + "last 1 chrome version", + "last 1 firefox version", + "last 1 safari version" + ] + }, + "proxy": "http://backend:8000" +} diff --git a/frontend/public/index.html b/frontend/public/index.html new file mode 100644 index 0000000000000000000000000000000000000000..a8591a1fee67f55b23e147afb2b8a5e7afc5005a --- /dev/null +++ b/frontend/public/index.html @@ -0,0 +1,96 @@ + + + + + + + + + + + + + + + + + + + + + + + + Open LLM Leaderboard - Compare Open Source Large Language Models + + + + + + +
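+    <!-- React mounts the entire app into the root div above; all page content is rendered client-side. -->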
+ + + diff --git a/frontend/public/logo256.png b/frontend/public/logo256.png new file mode 100644 index 0000000000000000000000000000000000000000..58547e134af0ac1200a4608fb1c800b3e8e9ddf1 Binary files /dev/null and b/frontend/public/logo256.png differ diff --git a/frontend/public/logo32.png b/frontend/public/logo32.png new file mode 100644 index 0000000000000000000000000000000000000000..1b6e8fbd42dd1bcc599649bf6f230fde89a6908a Binary files /dev/null and b/frontend/public/logo32.png differ diff --git a/frontend/public/og-image.jpg b/frontend/public/og-image.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1d4a3f3cb7d838489ef0a5dde1ce7c493273f98d Binary files /dev/null and b/frontend/public/og-image.jpg differ diff --git a/frontend/public/robots.txt b/frontend/public/robots.txt new file mode 100644 index 0000000000000000000000000000000000000000..e9e57dc4d41b9b46e05112e9f45b7ea6ac0ba15e --- /dev/null +++ b/frontend/public/robots.txt @@ -0,0 +1,3 @@ +# https://www.robotstxt.org/robotstxt.html
+User-agent: *
+Disallow: diff --git a/frontend/server.js b/frontend/server.js new file mode 100644 index 0000000000000000000000000000000000000000..653befea69419568b117ce809871639d86d65581 --- /dev/null +++ b/frontend/server.js @@ -0,0 +1,85 @@ +const express = require("express");
+const cors = require("cors");
+const compression = require("compression");
+const path = require("path");
+const serveStatic = require("serve-static");
+const { createProxyMiddleware } = require("http-proxy-middleware");
+
+const app = express();
+const port = process.env.PORT || 7860;
+const apiPort = process.env.INTERNAL_API_PORT || 7861;
+
+// Enable CORS for all routes
+app.use(cors());
+
+// Enable GZIP compression
+app.use(compression());
+
+// Proxy all API requests to the Python backend
+app.use(
+  "/api",
+  createProxyMiddleware({
+    target: `http://127.0.0.1:${apiPort}`,
+    changeOrigin: true,
+    onError: (err, req, res) => {
+      console.error("Proxy Error:", err);
+      res.status(500).json({ error: "Proxy Error", details: err.message });
+    },
+  })
+);
+
+// Serve static files from the build directory
+app.use(
+  express.static(path.join(__dirname, "build"), {
+    // Don't cache HTML files (filePath avoids shadowing the `path` module)
+    setHeaders: (res, filePath) => {
+      if (filePath.endsWith(".html")) {
+        res.setHeader("Cache-Control", "no-cache, no-store, must-revalidate");
+        res.setHeader("Pragma", "no-cache");
+        res.setHeader("Expires", "0");
+      } else {
+        // Cache other static resources for 1 year
+        res.setHeader("Cache-Control", "public, max-age=31536000");
+      }
+    },
+  })
+);
+
+// Middleware to preserve URL parameters
+app.use((req, res, next) => {
+  // Don't interfere with API requests
+  if (req.url.startsWith("/api")) {
+    return next();
+  }
+
+  // Preserve original URL parameters
+  req.originalUrl = req.url;
+  next();
+});
+
+// Handle all other routes by serving index.html
+app.get("*", (req, res, next) => {
+  // Don't interfere with API requests
+  if (req.url.startsWith("/api")) {
+    return next();
+  }
+
+  // Headers for client-side routing
+  res.set({
+    "Cache-Control": "no-cache, no-store, must-revalidate",
+    Pragma: "no-cache",
+    Expires: "0",
+  });
+
+  // Send index.html for all other routes
+  res.sendFile(path.join(__dirname, "build", "index.html"));
+});
+
+app.listen(port, "0.0.0.0", () => {
+  console.log(
+    `Frontend server is running on port ${port} in ${
+      process.env.NODE_ENV || "development"
+    } mode`
+  );
+  console.log(`API proxy target: http://127.0.0.1:${apiPort}`);
+}); diff --git a/frontend/src/App.js b/frontend/src/App.js new file mode 
100644 index 0000000000000000000000000000000000000000..3eccae0cac8ea671be9be4dafb53be5e1ea4f70a --- /dev/null +++ b/frontend/src/App.js @@ -0,0 +1,121 @@ +import React, { useEffect } from "react";
+import {
+  HashRouter as Router,
+  Routes,
+  Route,
+  useSearchParams,
+  useLocation,
+} from "react-router-dom";
+import { ThemeProvider } from "@mui/material/styles";
+import { Box, CssBaseline } from "@mui/material";
+import Navigation from "./components/Navigation/Navigation";
+import LeaderboardPage from "./pages/LeaderboardPage/LeaderboardPage";
+import QuotePage from "./pages/QuotePage/QuotePage";
+import Footer from "./components/Footer/Footer";
+import getTheme from "./config/theme";
+import { useThemeMode } from "./hooks/useThemeMode";
+import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
+import LeaderboardProvider from "./pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext";
+
+const queryClient = new QueryClient({
+  defaultOptions: {
+    queries: {
+      retry: 1,
+      refetchOnWindowFocus: false,
+    },
+  },
+});
+
+function UrlHandler() {
+  const location = useLocation();
+  const [searchParams] = useSearchParams();
+
+  // Keep the URL in sync with the parent HF page
+  useEffect(() => {
+    // Check whether we are running inside an HF Space iframe
+    const isHFSpace = window.location !== window.parent.location;
+    if (!isHFSpace) return;
+
+    // Sync query and hash from this embedded app to the parent page URL
+    const queryString = window.location.search;
+    const hash = window.location.hash;
+
+    // HF Spaces' special message type to update the query string and the hash in the parent page URL
+    window.parent.postMessage(
+      {
+        queryString,
+        hash,
+      },
+      "https://huggingface.co"
+    );
+  }, [location, searchParams]);
+
+  // Read the updated hash reactively
+  useEffect(() => {
+    const handleHashChange = (event) => {
+      console.log("hash change event", event);
+    };
+
+    window.addEventListener("hashchange", handleHashChange);
+    return () => window.removeEventListener("hashchange", handleHashChange);
+  }, []);
+
+  return null;
+}
+
+function App() {
+  const { mode, toggleTheme } = useThemeMode();
+  const theme = getTheme(mode);
+
+  return (
+    
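+    // Provider stack implied by the imports above: QueryClientProvider supplies
+    // the TanStack Query cache, ThemeProvider + CssBaseline apply the MUI theme,
+    // HashRouter + UrlHandler keep the URL in sync inside HF Spaces, and
+    // LeaderboardProvider shares leaderboard state with the routed pages.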
+ + + + + + + + + + + } /> + } /> + + +
+ + + + + +
+ ); +} + +export default App; diff --git a/frontend/src/components/Footer/Footer.js b/frontend/src/components/Footer/Footer.js new file mode 100644 index 0000000000000000000000000000000000000000..2064e062f55de1cf477fd80211f2bd5d9835fb63 --- /dev/null +++ b/frontend/src/components/Footer/Footer.js @@ -0,0 +1,30 @@ +import React from "react"; +import { Box, Typography, Link } from "@mui/material"; + +const Footer = () => { + return ( + + + © 2024 Hugging Face - Open LLM Leaderboard - Made with 🤗 by the HF team + -{" "} + + huggingface.co + + + + ); +}; + +export default Footer; diff --git a/frontend/src/components/Logo/HFLogo.js b/frontend/src/components/Logo/HFLogo.js new file mode 100644 index 0000000000000000000000000000000000000000..e49263da5f52e62f50db806f6f295d94e75be47f --- /dev/null +++ b/frontend/src/components/Logo/HFLogo.js @@ -0,0 +1,19 @@ +import React from 'react'; + +const HFLogo = () => ( + + hg-logo + + +); + +export default HFLogo; \ No newline at end of file diff --git a/frontend/src/components/Logo/Logo.js b/frontend/src/components/Logo/Logo.js new file mode 100644 index 0000000000000000000000000000000000000000..55db4a876d67bdc378ac86c8a5aba2276ff6df33 --- /dev/null +++ b/frontend/src/components/Logo/Logo.js @@ -0,0 +1,56 @@ +import React from "react"; +import { useNavigate, useSearchParams, useLocation } from "react-router-dom"; +import { Box } from "@mui/material"; +import HFLogo from "./HFLogo"; +import { useLeaderboard } from "../../pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext"; + +const Logo = ({ height = "40px" }) => { + const navigate = useNavigate(); + const [searchParams, setSearchParams] = useSearchParams(); + const location = useLocation(); + const { actions } = useLeaderboard(); + + const handleReset = () => { + // Reset all leaderboard state first + actions.resetAll(); + + // Then clean URL in one go + if ( + location.pathname !== "/" || + searchParams.toString() !== "" || + location.hash !== "" + ) { + window.history.replaceState(null, "", "/"); + navigate("/", { replace: true, state: { skipUrlSync: true } }); + setSearchParams({}, { replace: true, state: { skipUrlSync: true } }); + } + }; + + return ( + + + + + + ); +}; + +export default Logo; diff --git a/frontend/src/components/Navigation/Navigation.js b/frontend/src/components/Navigation/Navigation.js new file mode 100644 index 0000000000000000000000000000000000000000..eabd32711d87e7d4434d36828e2cb50a1f311560 --- /dev/null +++ b/frontend/src/components/Navigation/Navigation.js @@ -0,0 +1,474 @@ +import React, { useState } from "react"; +import { + AppBar, + Toolbar, + Box, + Link as MuiLink, + IconButton, + Tooltip, + ButtonBase, + Typography, +} from "@mui/material"; +import { useLocation, useNavigate, useSearchParams } from "react-router-dom"; +import OpenInNewIcon from "@mui/icons-material/OpenInNew"; +import LightModeOutlinedIcon from "@mui/icons-material/LightModeOutlined"; +import DarkModeOutlinedIcon from "@mui/icons-material/DarkModeOutlined"; +import { alpha } from "@mui/material/styles"; +import MenuIcon from "@mui/icons-material/Menu"; +import { Menu, MenuItem, useMediaQuery, useTheme } from "@mui/material"; + +const Navigation = ({ onToggleTheme, mode }) => { + const location = useLocation(); + const navigate = useNavigate(); + const [searchParams] = useSearchParams(); + const [anchorEl, setAnchorEl] = useState(null); + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down("md")); + const [hasChanged, setHasChanged] = useState(false); + 
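+  // hasChanged gates the rotate-in animation below, so the theme icon only
+  // animates on an actual toggle and not on the initial render.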
+ const handleThemeToggle = () => { + setHasChanged(true); + onToggleTheme(); + }; + + const iconStyle = { + fontSize: "1.125rem", + ...(hasChanged && { + animation: "rotateIn 0.3s cubic-bezier(0.4, 0, 0.2, 1)", + "@keyframes rotateIn": { + "0%": { + opacity: 0, + transform: + mode === "light" + ? "rotate(-90deg) scale(0.8)" + : "rotate(90deg) scale(0.8)", + }, + "100%": { + opacity: 1, + transform: "rotate(0) scale(1)", + }, + }, + }), + }; + + // Function to sync URL with parent HF page + const syncUrlWithParent = (queryString, hash) => { + // Check if we're in an HF Space iframe + const isHFSpace = window.location !== window.parent.location; + if (isHFSpace) { + try { + // Build complete URL with hash + const fullPath = `${queryString}${hash ? "#" + hash : ""}`; + window.parent.postMessage( + { + type: "urlUpdate", + path: fullPath, + }, + "https://huggingface.co" + ); + } catch (e) { + console.warn("Unable to sync URL with parent:", e); + } + } + }; + + const linkStyle = (isActive = false) => ({ + textDecoration: "none", + color: isActive ? "text.primary" : "text.secondary", + fontSize: "0.8125rem", + opacity: isActive ? 1 : 0.8, + display: "flex", + alignItems: "center", + gap: 0.5, + paddingBottom: "2px", + cursor: "pointer", + position: "relative", + "&:hover": { + opacity: 1, + color: "text.primary", + }, + "&::after": isActive + ? { + content: '""', + position: "absolute", + bottom: "-4px", + left: "0", + width: "100%", + height: "2px", + backgroundColor: (theme) => + alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.3 : 0.2 + ), + borderRadius: "2px", + } + : {}, + }); + + const Separator = () => ( + ({ + width: "4px", + height: "4px", + borderRadius: "100%", + backgroundColor: alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.2 : 0.15 + ), + })} + /> + ); + + const handleNavigation = (path) => (e) => { + e.preventDefault(); + const searchString = searchParams.toString(); + const queryString = searchString ? `?${searchString}` : ""; + const newPath = `${path}${queryString}`; + + // Local navigation via React Router + navigate(newPath); + + // If in HF Space, sync with parent + if (window.location !== window.parent.location) { + syncUrlWithParent(queryString, newPath); + } + }; + + const handleMenuOpen = (event) => { + setAnchorEl(event.currentTarget); + }; + + const handleMenuClose = () => { + setAnchorEl(null); + }; + + return ( + + + {isMobile ? ( + + + + + + + `1px solid ${alpha(theme.palette.divider, 0.1)}`, + backgroundColor: (theme) => + theme.palette.mode === "dark" + ? alpha(theme.palette.background.paper, 0.8) + : theme.palette.background.paper, + backdropFilter: "blur(20px)", + "& .MuiList-root": { + py: 1, + }, + "& .MuiMenuItem-root": { + px: 2, + py: 1, + fontSize: "0.8125rem", + color: "text.secondary", + transition: "all 0.2s ease-in-out", + position: "relative", + "&:hover": { + backgroundColor: (theme) => + alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.1 : 0.06 + ), + color: "text.primary", + }, + "&.Mui-selected": { + backgroundColor: "transparent", + color: "text.primary", + "&::after": { + content: '""', + position: "absolute", + left: "8px", + width: "4px", + height: "100%", + top: "0", + backgroundColor: (theme) => + alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.3 : 0.2 + ), + borderRadius: "2px", + }, + "&:hover": { + backgroundColor: (theme) => + alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 
0.1 : 0.06 + ), + }, + }, + }, + }, + }} + transformOrigin={{ horizontal: "left", vertical: "top" }} + anchorOrigin={{ horizontal: "left", vertical: "bottom" }} + > + {/* Navigation Section */} + + + Navigation + + + { + handleNavigation("/")(e); + handleMenuClose(); + }} + selected={location.pathname === "/"} + > + Leaderboard + + { + handleNavigation("/quote")(e); + handleMenuClose(); + }} + selected={location.pathname === "/quote"} + > + Citations + + + {/* Separator */} + + `1px solid ${alpha(theme.palette.divider, 0.1)}`, + }} + /> + + {/* External Links Section */} + + + External links + + + + Compare models + + + + About + + + + + + ({ + color: "text.secondary", + borderRadius: "100%", + padding: 0, + width: "36px", + height: "36px", + display: "flex", + alignItems: "center", + justifyContent: "center", + transition: "all 0.2s ease-in-out", + "&:hover": { + color: "text.primary", + backgroundColor: alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.1 : 0.06 + ), + }, + "&.MuiButtonBase-root": { + overflow: "hidden", + }, + "& .MuiTouchRipple-root": { + color: alpha(theme.palette.text.primary, 0.3), + }, + })} + > + {mode === "light" ? ( + + ) : ( + + )} + + + + ) : ( + // Desktop version + + {/* Internal navigation */} + + + Leaderboard + + + Citations + + + + + + {/* External links */} + + + Compare models + + + + About + + + + + + + {/* Dark mode toggle */} + + ({ + color: "text.secondary", + borderRadius: "100%", + padding: 0, + width: "36px", + height: "36px", + display: "flex", + alignItems: "center", + justifyContent: "center", + transition: "all 0.2s ease-in-out", + "&:hover": { + color: "text.primary", + backgroundColor: alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.1 : 0.06 + ), + }, + "&.MuiButtonBase-root": { + overflow: "hidden", + }, + "& .MuiTouchRipple-root": { + color: alpha(theme.palette.text.primary, 0.3), + }, + })} + > + {mode === "light" ? 
( + + ) : ( + + )} + + + + )} + + + ); +}; + +export default Navigation; diff --git a/frontend/src/components/shared/AuthContainer.js b/frontend/src/components/shared/AuthContainer.js new file mode 100644 index 0000000000000000000000000000000000000000..ca79ed8645929ab583964e33be5c1810eef620ab --- /dev/null +++ b/frontend/src/components/shared/AuthContainer.js @@ -0,0 +1,168 @@ +import React from "react"; +import { + Box, + Typography, + Button, + Chip, + Stack, + Paper, + CircularProgress, + useTheme, + useMediaQuery, +} from "@mui/material"; +import HFLogo from "../Logo/HFLogo"; +import { useAuth } from "../../hooks/useAuth"; +import LogoutIcon from "@mui/icons-material/Logout"; +import { useNavigate } from "react-router-dom"; + +function AuthContainer({ actionText = "DO_ACTION" }) { + const { isAuthenticated, user, login, logout, loading } = useAuth(); + const navigate = useNavigate(); + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down("sm")); + + const handleLogout = () => { + if (isAuthenticated && logout) { + logout(); + navigate("/", { replace: true }); + window.location.reload(); + } + }; + + if (loading) { + return ( + + + + ); + } + + if (!isAuthenticated) { + return ( + + + Login to {actionText} + + + You need to be logged in with your Hugging Face account to{" "} + {actionText.toLowerCase()} + + + + ); + } + + return ( + + + + + Connected as {user?.username} + + + + + + + ); +} + +export default AuthContainer; diff --git a/frontend/src/components/shared/CodeBlock.js b/frontend/src/components/shared/CodeBlock.js new file mode 100644 index 0000000000000000000000000000000000000000..6f06f6eed1f6a17dd70334d3a7bb4d0ab897355c --- /dev/null +++ b/frontend/src/components/shared/CodeBlock.js @@ -0,0 +1,37 @@ +import React from 'react'; +import { Box, IconButton } from '@mui/material'; +import ContentCopyIcon from '@mui/icons-material/ContentCopy'; + +const CodeBlock = ({ code }) => ( + + navigator.clipboard.writeText(code)} + sx={{ + position: 'absolute', + top: 8, + right: 8, + color: 'grey.500', + '&:hover': { color: 'grey.300' }, + }} + > + + + + {code} + + +); + +export default CodeBlock; \ No newline at end of file diff --git a/frontend/src/components/shared/FilterTag.js b/frontend/src/components/shared/FilterTag.js new file mode 100644 index 0000000000000000000000000000000000000000..3cd154cb61a699bf94a2af0ba78286e3588aa754 --- /dev/null +++ b/frontend/src/components/shared/FilterTag.js @@ -0,0 +1,139 @@ +import React from "react"; +import { Chip } from "@mui/material"; +import { useTheme } from "@mui/material/styles"; +import { alpha } from "@mui/material/styles"; +import CheckBoxOutlineBlankIcon from "@mui/icons-material/CheckBoxOutlineBlank"; +import CheckBoxOutlinedIcon from "@mui/icons-material/CheckBoxOutlined"; + +const FilterTag = ({ + label, + checked, + onChange, + count, + isHideFilter = false, + totalCount = 0, + variant = "tag", + showCheckbox = false, + stacked = false, + sx = {}, +}) => { + const theme = useTheme(); + + const formatCount = (count) => { + if (count === undefined) return ""; + return `${count}`; + }; + + const mainLabel = label; + const countLabel = count !== undefined ? formatCount(count) : ""; + + return ( + + ) : ( + + ) + ) : null + } + label={ + + {mainLabel} + {countLabel && ( + <> + + {countLabel} + + )} + + } + onClick={onChange} + variant="outlined" + color={ + checked + ? variant === "secondary" + ? 
"secondary" + : "primary" + : "default" + } + size="small" + data-checked={checked} + sx={{ + height: "32px", + fontWeight: 600, + opacity: checked ? 1 : 0.8, + borderRadius: "5px", + borderWidth: "1px", + borderStyle: "solid", + cursor: "pointer", + pl: showCheckbox ? 0.5 : 0, + mr: 0.5, + mb: 0.5, + transition: "opacity 0.2s ease, border-color 0.2s ease", + "& .MuiChip-label": { + px: 0.75, + pl: showCheckbox ? 0.6 : 0.75, + }, + "& .MuiChip-icon": { + mr: 0.5, + pl: 0.2, + }, + "&:hover": { + opacity: 1, + backgroundColor: checked + ? alpha( + theme.palette[variant === "secondary" ? "secondary" : "primary"] + .main, + theme.palette.mode === "light" ? 0.08 : 0.16 + ) + : "action.hover", + borderWidth: "1px", + }, + backgroundColor: checked + ? alpha( + theme.palette[variant === "secondary" ? "secondary" : "primary"] + .main, + theme.palette.mode === "light" ? 0.08 : 0.16 + ) + : "background.paper", + borderColor: checked + ? variant === "secondary" + ? "secondary.main" + : "primary.main" + : "divider", + ...sx, + }} + /> + ); +}; + +export default FilterTag; diff --git a/frontend/src/components/shared/InfoIconWithTooltip.js b/frontend/src/components/shared/InfoIconWithTooltip.js new file mode 100644 index 0000000000000000000000000000000000000000..2b307ccaf8d7bebb91c81b2ff7cf746a4fbac05e --- /dev/null +++ b/frontend/src/components/shared/InfoIconWithTooltip.js @@ -0,0 +1,87 @@ +import React from "react"; +import { Box, Tooltip, Portal, Backdrop } from "@mui/material"; +import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined"; + +const InfoIconWithTooltip = ({ tooltip, iconProps = {}, sx = {} }) => { + const [open, setOpen] = React.useState(false); + + return ( + <> + setOpen(true)} + onClose={() => setOpen(false)} + componentsProps={{ + tooltip: { + sx: { + bgcolor: "rgba(33, 33, 33, 0.95)", + padding: "12px 16px", + maxWidth: "none !important", + width: "auto", + minWidth: "200px", + fontSize: "0.875rem", + lineHeight: 1.5, + position: "relative", + zIndex: 1501, + "& .MuiTooltip-arrow": { + color: "rgba(33, 33, 33, 0.95)", + }, + }, + }, + popper: { + sx: { + zIndex: 1501, + maxWidth: "min(600px, 90vw) !important", + '&[data-popper-placement*="bottom"] .MuiTooltip-tooltip': { + marginTop: "10px", + }, + '&[data-popper-placement*="top"] .MuiTooltip-tooltip': { + marginBottom: "10px", + }, + }, + }, + }} + > + + + + + {open && ( + + + + )} + + ); +}; + +export default InfoIconWithTooltip; diff --git a/frontend/src/components/shared/PageHeader.js b/frontend/src/components/shared/PageHeader.js new file mode 100644 index 0000000000000000000000000000000000000000..4e3e255933e84a6c4e2354eff643277ee0256017 --- /dev/null +++ b/frontend/src/components/shared/PageHeader.js @@ -0,0 +1,29 @@ +import React from "react"; +import { Box, Typography } from "@mui/material"; + +const PageHeader = ({ title, subtitle }) => { + return ( + + + {title} + + {subtitle && ( + + {subtitle} + + )} + + ); +}; + +export default PageHeader; diff --git a/frontend/src/config/auth.js b/frontend/src/config/auth.js new file mode 100644 index 0000000000000000000000000000000000000000..250e7b0a8de7128983ac3e5f36f9fd1f82046122 --- /dev/null +++ b/frontend/src/config/auth.js @@ -0,0 +1,7 @@ +export const HF_CONFIG = { + CLIENT_ID: "18fe6b93-6921-444c-9a20-5c22c578f2d8", + STORAGE_KEY: "hf_oauth", + SCOPE: "openid profile", + PROD_URL: "https://open-llm-leaderboard-open-llm-leaderboard.hf.space", + DEV_URL: "http://localhost:7860" +}; \ No newline at end of file diff --git a/frontend/src/config/theme.js 
b/frontend/src/config/theme.js new file mode 100644 index 0000000000000000000000000000000000000000..4bd6e4ae0ac0810a89f7aafb480b3b12fbe0f524 --- /dev/null +++ b/frontend/src/config/theme.js @@ -0,0 +1,390 @@ +import { createTheme, alpha } from "@mui/material/styles"; + +const getDesignTokens = (mode) => ({ + typography: { + fontFamily: [ + "-apple-system", + "BlinkMacSystemFont", + '"Segoe UI"', + "Roboto", + '"Helvetica Neue"', + "Arial", + "sans-serif", + ].join(","), + h1: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + h2: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + h3: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + h4: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + h5: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + h6: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + subtitle1: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + subtitle2: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + }, + palette: { + mode, + primary: { + main: "#4F86C6", + light: mode === "light" ? "#7BA7D7" : "#6B97D7", + dark: mode === "light" ? "#2B5C94" : "#3B6CA4", + 50: mode === "light" ? alpha("#4F86C6", 0.05) : alpha("#4F86C6", 0.15), + 100: mode === "light" ? alpha("#4F86C6", 0.1) : alpha("#4F86C6", 0.2), + 200: mode === "light" ? alpha("#4F86C6", 0.2) : alpha("#4F86C6", 0.3), + contrastText: "#fff", + }, + background: { + default: mode === "light" ? "#f8f9fa" : "#0a0a0a", + paper: mode === "light" ? "#fff" : "#1a1a1a", + subtle: mode === "light" ? "grey.100" : "grey.900", + hover: mode === "light" ? "action.hover" : alpha("#fff", 0.08), + tooltip: mode === "light" ? alpha("#212121", 0.9) : alpha("#fff", 0.9), + }, + text: { + primary: mode === "light" ? "rgba(0, 0, 0, 0.87)" : "#fff", + secondary: + mode === "light" ? "rgba(0, 0, 0, 0.6)" : "rgba(255, 255, 255, 0.7)", + disabled: + mode === "light" ? "rgba(0, 0, 0, 0.38)" : "rgba(255, 255, 255, 0.5)", + hint: + mode === "light" ? "rgba(0, 0, 0, 0.38)" : "rgba(255, 255, 255, 0.5)", + }, + divider: + mode === "light" ? "rgba(0, 0, 0, 0.12)" : "rgba(255, 255, 255, 0.12)", + action: { + active: + mode === "light" ? "rgba(0, 0, 0, 0.54)" : "rgba(255, 255, 255, 0.7)", + hover: + mode === "light" ? "rgba(0, 0, 0, 0.04)" : "rgba(255, 255, 255, 0.08)", + selected: + mode === "light" ? "rgba(0, 0, 0, 0.08)" : "rgba(255, 255, 255, 0.16)", + disabled: + mode === "light" ? "rgba(0, 0, 0, 0.26)" : "rgba(255, 255, 255, 0.3)", + disabledBackground: + mode === "light" ? "rgba(0, 0, 0, 0.12)" : "rgba(255, 255, 255, 0.12)", + }, + }, + shape: { + borderRadius: 8, + }, + components: { + MuiCssBaseline: { + styleOverrides: { + "html, body": { + backgroundColor: "background.default", + color: mode === "dark" ? "#fff" : "#000", + }, + body: { + "& *::-webkit-scrollbar": { + width: 8, + height: 8, + backgroundColor: "transparent", + }, + "& *::-webkit-scrollbar-thumb": { + borderRadius: 8, + backgroundColor: + mode === "light" ? alpha("#000", 0.2) : alpha("#fff", 0.1), + "&:hover": { + backgroundColor: + mode === "light" ? alpha("#000", 0.3) : alpha("#fff", 0.15), + }, + }, + }, + }, + }, + MuiButton: { + styleOverrides: { + root: { + borderRadius: 8, + }, + }, + }, + MuiPaper: { + defaultProps: { + elevation: 0, + }, + styleOverrides: { + root: { + backgroundImage: "none", + boxShadow: "none", + border: "1px solid", + borderColor: + mode === "light" + ? 
"rgba(0, 0, 0, 0.12)!important" + : "rgba(255, 255, 255, 0.25)!important", + }, + rounded: { + borderRadius: 12, + }, + }, + }, + + MuiTableCell: { + styleOverrides: { + root: { + borderColor: (theme) => + alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + ), + }, + head: { + backgroundColor: mode === "light" ? "grey.50" : "grey.900", + color: "text.primary", + fontWeight: 600, + }, + }, + }, + MuiTableRow: { + styleOverrides: { + root: { + backgroundColor: "transparent", + }, + }, + }, + MuiTableContainer: { + styleOverrides: { + root: { + backgroundColor: "background.paper", + borderRadius: 8, + border: "none", + boxShadow: "none", + }, + }, + }, + MuiSlider: { + styleOverrides: { + root: { + "& .MuiSlider-valueLabel": { + backgroundColor: "background.paper", + color: "text.primary", + border: "1px solid", + borderColor: "divider", + boxShadow: + mode === "light" + ? "0px 2px 4px rgba(0, 0, 0, 0.1)" + : "0px 2px 4px rgba(0, 0, 0, 0.3)", + }, + }, + thumb: { + "&:hover": { + boxShadow: (theme) => + `0px 0px 0px 8px ${alpha( + theme.palette.primary.main, + mode === "light" ? 0.08 : 0.16 + )}`, + }, + "&.Mui-active": { + boxShadow: (theme) => + `0px 0px 0px 12px ${alpha( + theme.palette.primary.main, + mode === "light" ? 0.08 : 0.16 + )}`, + }, + }, + track: { + border: "none", + }, + rail: { + opacity: mode === "light" ? 0.38 : 0.3, + }, + mark: { + backgroundColor: mode === "light" ? "grey.400" : "grey.600", + }, + markLabel: { + color: "text.secondary", + }, + }, + }, + MuiTextField: { + styleOverrides: { + root: { + "& .MuiOutlinedInput-root": { + borderRadius: 8, + }, + }, + }, + }, + MuiChip: { + styleOverrides: { + root: { + borderRadius: 8, + }, + outlinedInfo: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "info.100", + borderColor: "info.400", + color: "info.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "info.200", + }, + }, + outlinedWarning: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "warning.100", + borderColor: "warning.400", + color: "warning.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "warning.200", + }, + }, + outlinedSuccess: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "success.100", + borderColor: "success.400", + color: "success.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "success.200", + }, + }, + outlinedError: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "error.100", + borderColor: "error.400", + color: "error.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "error.200", + }, + }, + outlinedPrimary: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "primary.100", + borderColor: "primary.400", + color: "primary.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "primary.200", + }, + }, + outlinedSecondary: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "secondary.100", + borderColor: "secondary.400", + color: "secondary.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "secondary.200", + }, + }, + }, + }, + MuiIconButton: { + styleOverrides: { + root: { + borderRadius: 8, + padding: "8px", + "&.MuiIconButton-sizeSmall": { + padding: "4px", + borderRadius: 6, + }, + }, + }, + }, + MuiTooltip: { + styleOverrides: { + tooltip: { + backgroundColor: + mode === "light" ? 
alpha("#212121", 0.9) : alpha("#424242", 0.9),
+          color: "#fff",
+          fontSize: "0.875rem",
+          padding: "8px 12px",
+          maxWidth: 400,
+          borderRadius: 8,
+          lineHeight: 1.4,
+          border: "1px solid",
+          borderColor:
+            mode === "light" ? alpha("#fff", 0.1) : alpha("#fff", 0.05),
+          boxShadow:
+            mode === "light"
+              ? "0 2px 8px rgba(0, 0, 0, 0.15)"
+              : "0 2px 8px rgba(0, 0, 0, 0.5)",
+          "& b": {
+            fontWeight: 600,
+            color: "inherit",
+          },
+          "& a": {
+            color: mode === "light" ? "#90caf9" : "#64b5f6",
+            textDecoration: "none",
+            "&:hover": {
+              textDecoration: "underline",
+            },
+          },
+        },
+        arrow: {
+          color:
+            mode === "light" ? alpha("#212121", 0.9) : alpha("#424242", 0.9),
+          "&:before": {
+            border: "1px solid",
+            borderColor:
+              mode === "light" ? alpha("#fff", 0.1) : alpha("#fff", 0.05),
+          },
+        },
+      },
+      defaultProps: {
+        arrow: true,
+        enterDelay: 400,
+        leaveDelay: 200,
+      },
+    },
+    MuiAppBar: {
+      styleOverrides: {
+        root: {
+          border: "none",
+          borderBottom: "none",
+        },
+      },
+    },
+  },
+  breakpoints: {
+    values: {
+      xs: 0,
+      sm: 600,
+      md: 900,
+      lg: 1240,
+      xl: 1536,
+    },
+  },
+});
+
+const getTheme = (mode) => {
+  const tokens = getDesignTokens(mode);
+  return createTheme(tokens);
+};
+
+export default getTheme; diff --git a/frontend/src/hooks/useAuth.js b/frontend/src/hooks/useAuth.js new file mode 100644 index 0000000000000000000000000000000000000000..166d61aaaea425b8ec6e0c1d6bcf16311a94f369 --- /dev/null +++ b/frontend/src/hooks/useAuth.js @@ -0,0 +1,173 @@ +import { useState, useEffect } from "react";
+import { useLocation, useNavigate } from "react-router-dom";
+import { oauthLoginUrl, oauthHandleRedirectIfPresent } from "@huggingface/hub";
+import { HF_CONFIG } from "../config/auth";
+
+async function fetchUserInfo(token) {
+  const response = await fetch("https://huggingface.co/api/whoami-v2", {
+    headers: {
+      Authorization: `Bearer ${token}`,
+    },
+  });
+  if (!response.ok) {
+    throw new Error("Failed to fetch user info");
+  }
+  return response.json();
+}
+
+export function useAuth() {
+  const [isAuthenticated, setIsAuthenticated] = useState(false);
+  const [user, setUser] = useState(null);
+  const [loading, setLoading] = useState(true);
+  const [error, setError] = useState(null);
+  const location = useLocation();
+  const navigate = useNavigate();
+
+  // Initialize authentication state
+  useEffect(() => {
+    let mounted = true;
+    const initAuth = async () => {
+      try {
+        console.group("Auth Initialization");
+        setLoading(true);
+
+        // Check for an OAuth redirect first
+        let oauthResult = await oauthHandleRedirectIfPresent();
+
+        // If there is no redirect, fall back to localStorage
+        if (!oauthResult) {
+          const storedAuth = localStorage.getItem(HF_CONFIG.STORAGE_KEY);
+          if (storedAuth) {
+            try {
+              oauthResult = JSON.parse(storedAuth);
+              console.log("Found existing auth");
+              const userInfo = await fetchUserInfo(oauthResult.access_token);
+              if (mounted) {
+                setIsAuthenticated(true);
+                setUser({
+                  username: userInfo.name,
+                  token: oauthResult.access_token,
+                });
+              }
+            } catch (err) {
+              console.log("Invalid stored auth data, clearing...", err);
+              localStorage.removeItem(HF_CONFIG.STORAGE_KEY);
+              if (mounted) {
+                setIsAuthenticated(false);
+                setUser(null);
+              }
+            }
+          }
+        } else {
+          console.log("Processing OAuth redirect");
+          const token = oauthResult.accessToken;
+          const userInfo = await fetchUserInfo(token);
+
+          const authData = {
+            access_token: token,
+            username: userInfo.name,
+          };
+
+          localStorage.setItem(HF_CONFIG.STORAGE_KEY, JSON.stringify(authData));
+
+          if (mounted) {
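+            // Guard: only touch state while the hook is still mounted (the effect cleanup flips `mounted`).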
setIsAuthenticated(true);
+            setUser({
+              username: userInfo.name,
+              token: token,
+            });
+          }
+
+          // Redirect back to the original page
+          const returnTo = localStorage.getItem("auth_return_to");
+          if (returnTo) {
+            navigate(returnTo);
+            localStorage.removeItem("auth_return_to");
+          }
+        }
+      } catch (err) {
+        console.error("Auth initialization error:", err);
+        if (mounted) {
+          setError(err.message);
+          setIsAuthenticated(false);
+          setUser(null);
+        }
+      } finally {
+        if (mounted) {
+          setLoading(false);
+        }
+        console.groupEnd();
+      }
+    };
+
+    initAuth();
+
+    return () => {
+      mounted = false;
+    };
+  }, [navigate, location.pathname]);
+
+  const login = async () => {
+    try {
+      console.group("Login Process");
+      setLoading(true);
+
+      // Save the current route for the post-auth redirect
+      const currentRoute = window.location.hash.replace("#", "") || "/";
+      localStorage.setItem("auth_return_to", currentRoute);
+
+      // Pick the redirect URL based on the environment
+      const redirectUrl =
+        window.location.hostname === "localhost" ||
+        window.location.hostname === "127.0.0.1"
+          ? HF_CONFIG.DEV_URL
+          : HF_CONFIG.PROD_URL;
+
+      console.log("Using redirect URL:", redirectUrl);
+
+      // Generate the login URL and redirect
+      const loginUrl = await oauthLoginUrl({
+        clientId: HF_CONFIG.CLIENT_ID,
+        redirectUrl,
+        scope: HF_CONFIG.SCOPE,
+      });
+
+      window.location.href = loginUrl + "&prompt=consent";
+
+      console.groupEnd();
+    } catch (err) {
+      console.error("Login error:", err);
+      setError(err.message);
+      setLoading(false);
+      console.groupEnd();
+    }
+  };
+
+  const logout = () => {
+    console.group("Logout Process");
+    setLoading(true);
+    try {
+      console.log("Clearing auth data...");
+      localStorage.removeItem(HF_CONFIG.STORAGE_KEY);
+      localStorage.removeItem("auth_return_to");
+      setIsAuthenticated(false);
+      setUser(null);
+      console.log("Logged out successfully");
+    } catch (err) {
+      console.error("Logout error:", err);
+      setError(err.message);
+    } finally {
+      setLoading(false);
+      console.groupEnd();
+    }
+  };
+
+  return {
+    isAuthenticated,
+    user,
+    loading,
+    error,
+    login,
+    logout,
+  };
+} diff --git a/frontend/src/hooks/useThemeMode.js b/frontend/src/hooks/useThemeMode.js new file mode 100644 index 0000000000000000000000000000000000000000..93030109e2b32281c05178cc4207cb5544e94e4f --- /dev/null +++ b/frontend/src/hooks/useThemeMode.js @@ -0,0 +1,28 @@ +import { useState, useEffect } from 'react';
+
+export const useThemeMode = () => {
+  // Get system preference
+  const getSystemPreference = () => {
+    return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
+  };
+
+  // Initialize theme mode from system preference
+  const [mode, setMode] = useState(getSystemPreference);
+
+  // Listen to system preference changes
+  useEffect(() => {
+    const mediaQuery = window.matchMedia('(prefers-color-scheme: dark)');
+    const handleChange = (e) => {
+      setMode(e.matches ? 'dark' : 'light');
+    };
+
+    mediaQuery.addEventListener('change', handleChange);
+    return () => mediaQuery.removeEventListener('change', handleChange);
+  }, []);
+
+  const toggleTheme = () => {
+    setMode((prevMode) => (prevMode === 'light' ? 
'dark' : 'light')); + }; + + return { mode, toggleTheme }; +}; \ No newline at end of file diff --git a/frontend/src/index.js b/frontend/src/index.js new file mode 100644 index 0000000000000000000000000000000000000000..8db5acb8fb94a08138a3901be0b5b810c9e50931 --- /dev/null +++ b/frontend/src/index.js @@ -0,0 +1,10 @@ +import React from "react"; +import ReactDOM from "react-dom/client"; +import App from "./App"; + +const root = ReactDOM.createRoot(document.getElementById("root")); +root.render( + + + +); diff --git a/frontend/src/pages/AddModelPage/AddModelPage.js b/frontend/src/pages/AddModelPage/AddModelPage.js new file mode 100644 index 0000000000000000000000000000000000000000..eae5f6da4ae95355664c4227b9d1040471a4ba3e --- /dev/null +++ b/frontend/src/pages/AddModelPage/AddModelPage.js @@ -0,0 +1,51 @@ +import React from "react"; +import { Box, CircularProgress } from "@mui/material"; +import { useAuth } from "../../hooks/useAuth"; +import PageHeader from "../../components/shared/PageHeader"; +import EvaluationQueues from "./components/EvaluationQueues/EvaluationQueues"; +import ModelSubmissionForm from "./components/ModelSubmissionForm/ModelSubmissionForm"; +import SubmissionGuide from "./components/SubmissionGuide/SubmissionGuide"; +import SubmissionLimitChecker from "./components/SubmissionLimitChecker/SubmissionLimitChecker"; + +function AddModelPage() { + const { isAuthenticated, loading, user } = useAuth(); + + if (loading) { + return ( + + + + ); + } + + return ( + + + Add your model to the Open + LLM Leaderboard + + } + /> + + + + + + + + + + ); +} + +export default AddModelPage; diff --git a/frontend/src/pages/AddModelPage/components/EvaluationQueues/EvaluationQueues.js b/frontend/src/pages/AddModelPage/components/EvaluationQueues/EvaluationQueues.js new file mode 100644 index 0000000000000000000000000000000000000000..c0b071d814c4fb7d8d1567221a086cf1396ff4ca --- /dev/null +++ b/frontend/src/pages/AddModelPage/components/EvaluationQueues/EvaluationQueues.js @@ -0,0 +1,787 @@ +import React, { useState, useEffect, useRef } from "react"; +import { + Box, + Typography, + Table, + TableBody, + TableCell, + TableContainer, + TableHead, + TableRow, + Chip, + Link, + CircularProgress, + Alert, + Accordion, + AccordionSummary, + AccordionDetails, + Stack, + Tooltip, + useTheme, + useMediaQuery, +} from "@mui/material"; +import AccessTimeIcon from "@mui/icons-material/AccessTime"; +import CheckCircleIcon from "@mui/icons-material/CheckCircle"; +import PendingIcon from "@mui/icons-material/Pending"; +import AutorenewIcon from "@mui/icons-material/Autorenew"; +import ExpandMoreIcon from "@mui/icons-material/ExpandMore"; +import OpenInNewIcon from "@mui/icons-material/OpenInNew"; +import { useVirtualizer } from "@tanstack/react-virtual"; + +// Function to format wait time +const formatWaitTime = (waitTimeStr) => { + const seconds = parseFloat(waitTimeStr.replace("s", "")); + + if (seconds < 60) { + return "just now"; + } + + const minutes = Math.floor(seconds / 60); + if (minutes < 60) { + return `${minutes}m ago`; + } + + const hours = Math.floor(minutes / 60); + if (hours < 24) { + return `${hours}h ago`; + } + + const days = Math.floor(hours / 24); + return `${days}d ago`; +}; + +// Column definitions with their properties +const columns = [ + { + id: "model", + label: "Model", + width: "35%", + align: "left", + }, + { + id: "submitter", + label: "Submitted by", + width: "15%", + align: "left", + }, + { + id: "wait_time", + label: "Submitted", + width: "12%", + align: "center", + }, + { + 
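+    // Percentage widths across the six columns sum to exactly 100%, so the table always fills its container.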
id: "precision", + label: "Precision", + width: "13%", + align: "center", + }, + { + id: "revision", + label: "Revision", + width: "12%", + align: "center", + }, + { + id: "status", + label: "Status", + width: "13%", + align: "center", + }, +]; + +const StatusChip = ({ status }) => { + const statusConfig = { + finished: { + icon: , + label: "Completed", + color: "success", + }, + evaluating: { + icon: , + label: "Evaluating", + color: "warning", + }, + pending: { icon: , label: "Pending", color: "info" }, + }; + + const config = statusConfig[status] || statusConfig.pending; + + return ( + + ); +}; + +const ModelTable = ({ models, emptyMessage, status }) => { + const parentRef = useRef(null); + const rowVirtualizer = useVirtualizer({ + count: models.length, + getScrollElement: () => parentRef.current, + estimateSize: () => 53, + overscan: 5, + }); + + if (models.length === 0) { + return ( + + {emptyMessage} + + ); + } + + return ( + + + + {columns.map((column) => ( + + ))} + + + + {columns.map((column, index) => ( + + {column.label} + + ))} + + + + + + <> + {rowVirtualizer.getVirtualItems().map((virtualRow) => { + const model = models[virtualRow.index]; + const waitTime = formatWaitTime(model.wait_time); + + return ( + + + + {model.name} + + + + + {model.submitter} + + + + + + {waitTime} + + + + + + {model.precision} + + + + {model.revision.substring(0, 7)} + + + + + + ); + })} + + + + +
+
+ ); +}; + +const QueueAccordion = ({ + title, + models, + status, + emptyMessage, + expanded, + onChange, + loading, +}) => { + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down("sm")); + + return ( + + } + sx={{ + px: { xs: 2, sm: 3 }, + py: { xs: 1.5, sm: 2 }, + alignItems: { xs: "flex-start", sm: "center" }, + "& .MuiAccordionSummary-expandIconWrapper": { + marginTop: { xs: "4px", sm: 0 }, + }, + }} + > + + + {title} + + + ({ + borderWidth: 2, + fontWeight: 600, + fontSize: { xs: "0.75rem", sm: "0.875rem" }, + height: { xs: "24px", sm: "32px" }, + width: { xs: "100%", sm: "auto" }, + bgcolor: + status === "finished" + ? theme.palette.success[100] + : status === "evaluating" + ? theme.palette.warning[100] + : theme.palette.info[100], + borderColor: + status === "finished" + ? theme.palette.success[400] + : status === "evaluating" + ? theme.palette.warning[400] + : theme.palette.info[400], + color: + status === "finished" + ? theme.palette.success[700] + : status === "evaluating" + ? theme.palette.warning[700] + : theme.palette.info[700], + "& .MuiChip-label": { + px: { xs: 1, sm: 1.2 }, + width: "100%", + }, + "&:hover": { + bgcolor: + status === "finished" + ? theme.palette.success[200] + : status === "evaluating" + ? theme.palette.warning[200] + : theme.palette.info[200], + }, + })} + /> + {loading && ( + + )} + + + + + + + + + + ); +}; + +const EvaluationQueues = ({ defaultExpanded = true }) => { + const [expanded, setExpanded] = useState(defaultExpanded); + const [expandedQueues, setExpandedQueues] = useState(new Set()); + const [models, setModels] = useState({ + pending: [], + evaluating: [], + finished: [], + }); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down("sm")); + + useEffect(() => { + const fetchModels = async () => { + try { + const response = await fetch("/api/models/status"); + if (!response.ok) { + throw new Error("Failed to fetch models"); + } + const data = await response.json(); + + // Sort models by submission date (most recent first) + const sortByDate = (models) => { + return [...models].sort((a, b) => { + const dateA = new Date(a.submission_time); + const dateB = new Date(b.submission_time); + return dateB - dateA; + }); + }; + + setModels({ + finished: sortByDate(data.finished), + evaluating: sortByDate(data.evaluating), + pending: sortByDate(data.pending), + }); + } catch (err) { + setError(err.message); + } finally { + setLoading(false); + } + }; + + fetchModels(); + const interval = setInterval(fetchModels, 30000); + return () => clearInterval(interval); + }, []); + + const handleMainAccordionChange = (panel) => (event, isExpanded) => { + setExpanded(isExpanded ? panel : false); + }; + + const handleQueueAccordionChange = (queueName) => (event, isExpanded) => { + setExpandedQueues((prev) => { + const newSet = new Set(prev); + if (isExpanded) { + newSet.add(queueName); + } else { + newSet.delete(queueName); + } + return newSet; + }); + }; + + if (error) { + return ( + + {error} + + ); + } + + return ( + + } + sx={{ + px: { xs: 2, sm: 3 }, + "& .MuiAccordionSummary-expandIconWrapper": { + color: "text.secondary", + transform: "rotate(0deg)", + transition: "transform 150ms", + marginTop: { xs: "4px", sm: 0 }, + "&.Mui-expanded": { + transform: "rotate(180deg)", + }, + }, + }} + > + + + Evaluation Status + + {!loading && ( + + + + + + )} + {loading && ( + + )} + + + + {loading ? 
( + + + + ) : ( + <> + + + + + + + )} + + + ); +}; + +export default EvaluationQueues; diff --git a/frontend/src/pages/AddModelPage/components/ModelSubmissionForm/ModelSubmissionForm.js b/frontend/src/pages/AddModelPage/components/ModelSubmissionForm/ModelSubmissionForm.js new file mode 100644 index 0000000000000000000000000000000000000000..31425f5e65d9c5e87709dc234464bbfb061448b9 --- /dev/null +++ b/frontend/src/pages/AddModelPage/components/ModelSubmissionForm/ModelSubmissionForm.js @@ -0,0 +1,599 @@ +import React, { useState } from "react"; +import { + Box, + Paper, + Typography, + TextField, + Button, + FormControl, + InputLabel, + Select, + MenuItem, + FormControlLabel, + Switch, + Stack, + Grid, + CircularProgress, + Alert, +} from "@mui/material"; +import RocketLaunchIcon from "@mui/icons-material/RocketLaunch"; +import CheckCircleOutlineIcon from "@mui/icons-material/CheckCircleOutline"; +import { alpha } from "@mui/material/styles"; +import InfoIconWithTooltip from "../../../../components/shared/InfoIconWithTooltip"; +import { MODEL_TYPES } from "../../../../pages/LeaderboardPage/components/Leaderboard/constants/modelTypes"; +import { SUBMISSION_PRECISIONS } from "../../../../pages/LeaderboardPage/components/Leaderboard/constants/defaults"; +import AuthContainer from "../../../../components/shared/AuthContainer"; + +const WEIGHT_TYPES = [ + { value: "Original", label: "Original" }, + { value: "Delta", label: "Delta" }, + { value: "Adapter", label: "Adapter" }, +]; + +const HELP_TEXTS = { + modelName: ( + + + Model Name on Hugging Face Hub + + + Your model must be public and loadable with AutoClasses without + trust_remote_code. The model should be in Safetensors format for better + safety and loading performance. Example: mistralai/Mistral-7B-v0.1 + + + ), + revision: ( + + + Model Revision + + + Git branch, tag or commit hash. The evaluation will be strictly tied to + this specific commit to ensure consistency. Make sure this version is + stable and contains all necessary files. + + + ), + modelType: ( + + + Model Category + + + 🟢 Pretrained: Base models trained on text using masked modeling 🟩 + Continuously Pretrained: Extended training on additional corpus 🔶 + Fine-tuned: Domain-specific optimization 💬 Chat: Models using RLHF, + DPO, or IFT for conversation 🤝 Merge: Combined weights without + additional training 🌸 Multimodal: Handles multiple input types + + + ), + baseModel: ( + + + Base Model Reference + + + Required for delta weights or adapters. This information is used to + identify the original model and calculate the total parameter count by + combining base model and adapter/delta parameters. + + + ), + precision: ( + + + Model Precision + + + Size limits vary by precision: • FP16/BF16: up to 100B parameters • + 8-bit: up to 280B parameters (2x) • 4-bit: up to 560B parameters (4x) + Choose carefully as incorrect precision can cause evaluation errors. + + + ), + weightsType: ( + + + Weights Format + + + Original: Complete model weights in safetensors format Delta: Weight + differences from base model (requires base model for size calculation) + Adapter: Lightweight fine-tuning layers (requires base model for size + calculation) + + + ), + chatTemplate: ( + + + Chat Template Support + + + Activates automatically for chat models. It uses the standardized + Hugging Face chat template for consistent prompt formatting during + evaluation. Required for models using RLHF, DPO, or instruction + fine-tuning. 
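+        {/* Roughly what evaluation applies when the switch is on (a sketch, not this repo's code):
+            tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) */}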
+ + + ), +}; + +// Convert MODEL_TYPES to format expected by Select component +const modelTypeOptions = Object.entries(MODEL_TYPES).map( + ([value, { icon, label }]) => ({ + value, + label: `${icon} ${label}`, + }) +); + +function ModelSubmissionForm({ user, isAuthenticated }) { + const [formData, setFormData] = useState({ + modelName: "", + revision: "main", + modelType: "fine-tuned", + isChatModel: false, + useChatTemplate: false, + precision: "float16", + weightsType: "Original", + baseModel: "", + }); + const [error, setError] = useState(null); + const [submitting, setSubmitting] = useState(false); + const [success, setSuccess] = useState(false); + const [submittedData, setSubmittedData] = useState(null); + + const handleChange = (event) => { + const { name, value, checked } = event.target; + setFormData((prev) => ({ + ...prev, + [name]: event.target.type === "checkbox" ? checked : value, + })); + }; + + const handleSubmit = async (e) => { + e.preventDefault(); + setError(null); + setSubmitting(true); + + try { + const response = await fetch("/api/models/submit", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model_id: formData.modelName, + revision: formData.revision, + model_type: formData.modelType, + precision: formData.precision, + weight_type: formData.weightsType, + base_model: formData.baseModel, + use_chat_template: formData.useChatTemplate, + user_id: user.username, + }), + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || "Failed to submit model"); + } + + setSubmittedData(formData); + setSuccess(true); + } catch (error) { + setError(error.message); + } finally { + setSubmitting(false); + } + }; + + if (success && submittedData) { + return ( + ({ + p: 6, + mb: 3, + bgcolor: alpha(theme.palette.success.main, 0.05), + borderColor: alpha(theme.palette.success.main, 0.2), + })} + > + + + + + Model submitted successfully! + + + + + Your model {submittedData.modelName} has been added + to the evaluation queue with the following parameters: + + + + + + + Model: + + + {submittedData.modelName} + + + + + Type: + + + {submittedData.modelType} + + + + + Revision: + + + {submittedData.revision} + + + + + Precision: + + + {submittedData.precision} + + + + + Weight type: + + + {submittedData.weightsType} + + + {submittedData.baseModel && ( + + + Base model: + + + {submittedData.baseModel} + + + )} + + + Chat template: + + + {submittedData.useChatTemplate ? "Yes" : "No"} + + + + + + + An automatic upvote has been added to your model to help with + prioritization. + + + + + + + + ); + } + + return ( + <> + {error && ( + + {error} + + )} + + {isAuthenticated && ( + + {/* Header */} + + theme.palette.mode === "dark" + ? alpha(theme.palette.divider, 0.1) + : "grey.200", + bgcolor: (theme) => + theme.palette.mode === "dark" + ? 
alpha(theme.palette.background.paper, 0.5) + : "grey.50", + }} + > + + Model Submission Form + + + + {/* Form Content */} + + + {/* Model Information */} + + + Model Information + + + + + + + ), + }} + /> + + + + + ), + }} + /> + + + {/* Model Configuration */} + + + Model Configuration + + + + + + Model Type + + + + + + + + } + label="Use Chat Template" + /> + + + + + + + Precision + + + + + + + Weights Type + + + + + {formData.weightsType !== "Original" && ( + + + ), + }} + /> + + )} + + {/* Submit Button */} + + + + All fields marked with * are required + + + + + + + + )} + + ); +} + +export default ModelSubmissionForm; diff --git a/frontend/src/pages/AddModelPage/components/SubmissionGuide/SubmissionGuide.js b/frontend/src/pages/AddModelPage/components/SubmissionGuide/SubmissionGuide.js new file mode 100644 index 0000000000000000000000000000000000000000..c023f0ba51929d7bc854df404d5766bba2d5a8ee --- /dev/null +++ b/frontend/src/pages/AddModelPage/components/SubmissionGuide/SubmissionGuide.js @@ -0,0 +1,274 @@ +import React, { useState, useEffect } from "react"; +import { useLocation, useNavigate } from "react-router-dom"; +import { Box, Paper, Typography, Button, Stack, Collapse } from "@mui/material"; +import ExpandMoreIcon from "@mui/icons-material/ExpandMore"; + +const DocLink = ({ href, children }) => ( + +); + +const StepNumber = ({ number }) => ( + + {number} + +); + +const TUTORIAL_STEPS = [ + { + title: "Model Information", + content: ( + + + Your model should be public on the Hub and follow the{" "} + username/model-id format (e.g. + mistralai/Mistral-7B-v0.1). Specify the revision{" "} + (commit hash or branch) and model type. + + + Model uploading guide + + + ), + }, + { + title: "Technical Details", + content: ( + + + Make sure your model can be loaded locally before + submitting: + + + theme.palette.mode === "dark" ? "grey.50" : "grey.900", + borderRadius: 1, + "& pre": { + m: 0, + p: 0, + fontFamily: "monospace", + fontSize: "0.875rem", + color: (theme) => + theme.palette.mode === "dark" ? "grey.900" : "grey.50", + }, + }} + > +
+            {`from transformers import AutoConfig, AutoModel, AutoTokenizer
+
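+# Each call must succeed without trust_remote_code for the submission to be accepted.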
+config = AutoConfig.from_pretrained("your-username/your-model", revision="main")
+model = AutoModel.from_pretrained("your-username/your-model", revision="main")
+tokenizer = AutoTokenizer.from_pretrained("your-username/your-model", revision="main")`}
+          
+
+ + Transformers documentation + +
+ ), + }, + { + title: "License Requirements", + content: ( + + + A license tag is required.{" "} + Open licenses (Apache, MIT, etc) are strongly + recommended. + + + About model licenses + + + ), + }, + { + title: "Model Card Requirements", + content: ( + + + Your model card must include: architecture,{" "} + training details,{" "} + dataset information, intended use, limitations, and{" "} + performance metrics. + + + Model cards guide + + + ), + }, + { + title: "Final Checklist", + content: ( + + + Ensure your model is public, uses{" "} + safetensors format, has a{" "} + license tag, and loads correctly{" "} + with the provided code. + + + Sharing best practices + + + ), + }, +]; + +function SubmissionGuide() { + const location = useLocation(); + const navigate = useNavigate(); + + // Initialize state directly with URL value + const initialExpanded = !new URLSearchParams(location.search).get("guide"); + const [expanded, setExpanded] = useState(initialExpanded); + + // Sync expanded state with URL changes after initial render + useEffect(() => { + const guideOpen = !new URLSearchParams(location.search).get("guide"); + if (guideOpen !== expanded) { + setExpanded(guideOpen); + } + }, [location.search, expanded]); + + const handleAccordionChange = () => { + const newExpanded = !expanded; + setExpanded(newExpanded); + const params = new URLSearchParams(location.search); + if (newExpanded) { + params.delete("guide"); + } else { + params.set("guide", "closed"); + } + navigate({ search: params.toString() }, { replace: true }); + }; + + return ( + + theme.palette.mode === "dark" ? "grey.800" : "grey.200", + overflow: "hidden", + }} + > + + theme.palette.mode === "dark" ? "grey.900" : "grey.50", + borderBottom: "1px solid", + borderColor: (theme) => + expanded + ? theme.palette.mode === "dark" + ? "grey.800" + : "grey.200" + : "transparent", + }} + > + + Submission Guide + + + + + + + {TUTORIAL_STEPS.map((step, index) => ( + + + + + + {step.title} + + + {step.content} + + {index < TUTORIAL_STEPS.length - 1 && ( + + theme.palette.mode === "dark" ? 
"grey.800" : "grey.100", + }} + /> + )} + + ))} + + + + + ); +} + +export default SubmissionGuide; diff --git a/frontend/src/pages/AddModelPage/components/SubmissionLimitChecker/SubmissionLimitChecker.js b/frontend/src/pages/AddModelPage/components/SubmissionLimitChecker/SubmissionLimitChecker.js new file mode 100644 index 0000000000000000000000000000000000000000..97f4a72884c5874e68a169ad3c9d6c1541c8852a --- /dev/null +++ b/frontend/src/pages/AddModelPage/components/SubmissionLimitChecker/SubmissionLimitChecker.js @@ -0,0 +1,85 @@ +import React, { useState, useEffect } from "react"; +import { Alert, Box, CircularProgress } from "@mui/material"; + +const MAX_SUBMISSIONS_PER_WEEK = 10; + +function SubmissionLimitChecker({ user, children }) { + const [loading, setLoading] = useState(true); + const [reachedLimit, setReachedLimit] = useState(false); + const [error, setError] = useState(false); + + useEffect(() => { + const checkSubmissionLimit = async () => { + if (!user?.username) { + setLoading(false); + return; + } + + try { + const response = await fetch( + `/api/models/organization/${user.username}/submissions?days=7` + ); + if (!response.ok) { + throw new Error("Failed to fetch submission data"); + } + + const submissions = await response.json(); + console.log(`Recent submissions for ${user.username}:`, submissions); + setReachedLimit(submissions.length >= MAX_SUBMISSIONS_PER_WEEK); + setError(false); + } catch (error) { + console.error("Error checking submission limit:", error); + setError(true); + } finally { + setLoading(false); + } + }; + + checkSubmissionLimit(); + }, [user?.username]); + + if (loading) { + return ( + + + + ); + } + + if (error) { + return ( + + Unable to verify submission limits. Please try again in a few minutes. + + ); + } + + if (reachedLimit) { + return ( + + For fairness reasons, you cannot submit more than{" "} + {MAX_SUBMISSIONS_PER_WEEK} models per week. Please try again later. 
+ + ); + } + + return children; +} + +export default SubmissionLimitChecker; diff --git a/frontend/src/pages/LeaderboardPage/LeaderboardPage.js b/frontend/src/pages/LeaderboardPage/LeaderboardPage.js new file mode 100644 index 0000000000000000000000000000000000000000..7394022a7082b38a9ad98e31318021c79512d936 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/LeaderboardPage.js @@ -0,0 +1,50 @@ +import { useEffect } from "react"; +import Leaderboard from "./components/Leaderboard/Leaderboard"; +import { Box } from "@mui/material"; +import PageHeader from "../../components/shared/PageHeader"; +import Logo from "../../components/Logo/Logo"; +import { useLeaderboardData } from "../../pages/LeaderboardPage/components/Leaderboard/hooks/useLeaderboardData"; +import { useLeaderboard } from "../../pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext"; + +function LeaderboardPage() { + const { data, isLoading, error } = useLeaderboardData(); + const { actions } = useLeaderboard(); + + useEffect(() => { + if (data) { + actions.setModels(data); + } + actions.setLoading(isLoading); + actions.setError(error); + }, [data, isLoading, error, actions]); + + return ( + + + + + + Comparing Large Language Models in an{" "} + open and{" "} + reproducible way + + } + /> + + + ); +} + +export default LeaderboardPage; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/Leaderboard.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/Leaderboard.js new file mode 100644 index 0000000000000000000000000000000000000000..5c41ce7fa5eeeb9b00bc657c174c9653c5d31503 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/Leaderboard.js @@ -0,0 +1,449 @@ +import React, { useMemo, useEffect, useCallback } from "react"; +import { Box, Typography } from "@mui/material"; +import { useSearchParams } from "react-router-dom"; + +import { TABLE_DEFAULTS } from "./constants/defaults"; +import { useLeaderboard } from "./context/LeaderboardContext"; +import { useLeaderboardProcessing } from "./hooks/useLeaderboardData"; +import { useLeaderboardData } from "./hooks/useLeaderboardData"; + +import LeaderboardFilters from "./components/Filters/Filters"; +import LeaderboardTable from "./components/Table/Table"; +import SearchBar, { SearchBarSkeleton } from "./components/Filters/SearchBar"; +import PerformanceMonitor from "./components/PerformanceMonitor"; +import QuickFilters, { + QuickFiltersSkeleton, +} from "./components/Filters/QuickFilters"; + +const FilterAccordion = ({ expanded, quickFilters, advancedFilters }) => { + const advancedFiltersRef = React.useRef(null); + const quickFiltersRef = React.useRef(null); + const [height, setHeight] = React.useState("auto"); + const resizeTimeoutRef = React.useRef(null); + + const updateHeight = React.useCallback(() => { + if (expanded && advancedFiltersRef.current) { + setHeight(`${advancedFiltersRef.current.scrollHeight}px`); + } else if (!expanded && quickFiltersRef.current) { + setHeight(`${quickFiltersRef.current.scrollHeight}px`); + } + }, [expanded]); + + React.useEffect(() => { + // Initial height calculation + const timer = setTimeout(updateHeight, 100); + + // Resize handler with debounce + const handleResize = () => { + if (resizeTimeoutRef.current) { + clearTimeout(resizeTimeoutRef.current); + } + resizeTimeoutRef.current = setTimeout(updateHeight, 150); + }; + + window.addEventListener("resize", handleResize); + + return () => { + clearTimeout(timer); + window.removeEventListener("resize", handleResize); + if 
(resizeTimeoutRef.current) { + clearTimeout(resizeTimeoutRef.current); + } + }; + }, [updateHeight]); + + // Update height when expanded state changes + React.useEffect(() => { + updateHeight(); + }, [expanded, updateHeight]); + + return ( + + + {quickFilters} + + + {advancedFilters} + + + ); +}; + +const Leaderboard = () => { + const { state, actions } = useLeaderboard(); + const [searchParams, setSearchParams] = useSearchParams(); + const { + data, + isLoading: dataLoading, + error: dataError, + } = useLeaderboardData(); + const { + table, + filteredData, + error: processingError, + } = useLeaderboardProcessing(); + + // Memoize filtered data + const memoizedFilteredData = useMemo(() => filteredData, [filteredData]); + const memoizedTable = useMemo(() => table, [table]); + + // Memoize table options + const hasTableOptionsChanges = useMemo(() => { + return ( + state.display.rowSize !== TABLE_DEFAULTS.ROW_SIZE || + JSON.stringify(state.display.scoreDisplay) !== + JSON.stringify(TABLE_DEFAULTS.SCORE_DISPLAY) || + state.display.averageMode !== TABLE_DEFAULTS.AVERAGE_MODE || + state.display.rankingMode !== TABLE_DEFAULTS.RANKING_MODE + ); + }, [state.display]); + + const hasColumnFilterChanges = useMemo(() => { + return ( + JSON.stringify([...state.display.visibleColumns].sort()) !== + JSON.stringify([...TABLE_DEFAULTS.COLUMNS.DEFAULT_VISIBLE].sort()) + ); + }, [state.display.visibleColumns]); + + // Memoize callbacks + const onToggleFilters = useCallback(() => { + actions.toggleFiltersExpanded(); + }, [actions]); + + const onColumnVisibilityChange = useCallback( + (newVisibility) => { + actions.setDisplayOption( + "visibleColumns", + Object.keys(newVisibility).filter((key) => newVisibility[key]) + ); + }, + [actions] + ); + + const onRowSizeChange = useCallback( + (size) => { + actions.setDisplayOption("rowSize", size); + }, + [actions] + ); + + const onScoreDisplayChange = useCallback( + (display) => { + actions.setDisplayOption("scoreDisplay", display); + }, + [actions] + ); + + const onAverageModeChange = useCallback( + (mode) => { + actions.setDisplayOption("averageMode", mode); + }, + [actions] + ); + + const onRankingModeChange = useCallback( + (mode) => { + actions.setDisplayOption("rankingMode", mode); + }, + [actions] + ); + + const onPrecisionsChange = useCallback( + (precisions) => { + actions.setFilter("precisions", precisions); + }, + [actions] + ); + + const onTypesChange = useCallback( + (types) => { + actions.setFilter("types", types); + }, + [actions] + ); + + const onParamsRangeChange = useCallback( + (range) => { + actions.setFilter("paramsRange", range); + }, + [actions] + ); + + const onBooleanFiltersChange = useCallback( + (filters) => { + actions.setFilter("booleanFilters", filters); + }, + [actions] + ); + + const onReset = useCallback(() => { + actions.resetFilters(); + }, [actions]); + + // Memoize loading states + const loadingStates = useMemo(() => { + const isInitialLoading = dataLoading || !data; + const isProcessingData = !memoizedTable || !memoizedFilteredData; + const isApplyingFilters = state.models.length > 0 && !memoizedFilteredData; + const hasValidFilterCounts = + state.countsReady && + state.filterCounts && + state.filterCounts.normal && + state.filterCounts.officialOnly; + + return { + isInitialLoading, + isProcessingData, + isApplyingFilters, + showSearchSkeleton: isInitialLoading || !hasValidFilterCounts, + showFiltersSkeleton: isInitialLoading || !hasValidFilterCounts, + showTableSkeleton: + isInitialLoading || + isProcessingData || + 
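+ // any stage still pending keeps the table skeleton visible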
isApplyingFilters || + !hasValidFilterCounts, + }; + }, [ + dataLoading, + data, + memoizedTable, + memoizedFilteredData, + state.models.length, + state.filterCounts, + state.countsReady, + ]); + + // Memoize child components + const memoizedSearchBar = useMemo( + () => ( + + ), + [ + onToggleFilters, + state.filtersExpanded, + loadingStates.showTableSkeleton, + memoizedFilteredData, + table, + ] + ); + + const memoizedQuickFilters = useMemo( + () => ( + + ), + [state.models.length, memoizedFilteredData, memoizedTable] + ); + + const memoizedLeaderboardFilters = useMemo( + () => ( + + ), + [ + memoizedFilteredData, + loadingStates.showFiltersSkeleton, + state.filters.precisions, + state.filters.types, + state.filters.paramsRange, + state.filters.booleanFilters, + onPrecisionsChange, + onTypesChange, + onParamsRangeChange, + onBooleanFiltersChange, + onReset, + ] + ); + + // No need to memoize LeaderboardTable as it handles its own sorting state + const tableComponent = ( + + ); + + // Update context with loaded data + useEffect(() => { + if (data) { + actions.setModels(data); + } + }, [data, actions]); + + // Log to understand loading state + useEffect(() => { + if (process.env.NODE_ENV === "development") { + console.log("Loading state:", { + dataLoading, + hasData: !!data, + hasTable: !!table, + hasFilteredData: !!filteredData, + filteredDataLength: filteredData?.length, + stateModelsLength: state.models.length, + hasFilters: Object.keys(state.filters).some((key) => { + if (Array.isArray(state.filters[key])) { + return state.filters[key].length > 0; + } + return !!state.filters[key]; + }), + }); + } + }, [ + dataLoading, + data, + table, + filteredData?.length, + state.models.length, + filteredData, + state.filters, + ]); + + // If an error occurred, display it + if (dataError || processingError) { + return ( + + + {(dataError || processingError)?.message || + "An error occurred while loading the data"} + + + ); + } + + return ( + + + + + {loadingStates.showSearchSkeleton ? ( + + ) : ( + memoizedSearchBar + )} + + {loadingStates.showFiltersSkeleton ? 
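+ // quick filters render as a skeleton until the filter counts are ready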
( + + ) : ( + + )} + + + + + + {tableComponent} + + + + + ); +}; + +export default Leaderboard; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/ColumnSelector/ColumnSelector.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/ColumnSelector/ColumnSelector.js new file mode 100644 index 0000000000000000000000000000000000000000..5a67cacd3d1d3343d22abcf7fd083440bcb94881 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/ColumnSelector/ColumnSelector.js @@ -0,0 +1,217 @@ +import React from "react"; +import { Box, Typography } from "@mui/material"; +import ViewColumnIcon from "@mui/icons-material/ViewColumn"; +import CloseIcon from "@mui/icons-material/Close"; +import FilterTag from "../../../../../../components/shared/FilterTag"; +import RestartAltIcon from "@mui/icons-material/RestartAlt"; +import { TABLE_DEFAULTS } from "../../constants/defaults"; +import DropdownButton from "../shared/DropdownButton"; +import InfoIconWithTooltip from "../../../../../../components/shared/InfoIconWithTooltip"; +import { UI_TOOLTIPS } from "../../constants/tooltips"; + +const FilterGroup = ({ title, children, count, total }) => ( + + + {title} + {count !== undefined && total !== undefined && ( + + ({count}/{total}) + + )} + + + {children} + + +); + +const ColumnSelector = ({ + table, + onReset, + hasChanges, + onColumnVisibilityChange, + loading = false, +}) => { + const { getState, setColumnVisibility } = table; + const { columnVisibility } = getState(); + + // Filter columns to only show filterable ones + const filterableColumns = [ + ...TABLE_DEFAULTS.COLUMNS.EVALUATION, + ...TABLE_DEFAULTS.COLUMNS.OPTIONAL, + ]; + + const handleReset = (e) => { + e.preventDefault(); + e.stopPropagation(); + + if (!hasChanges) return; + + // Call onReset first + onReset?.(); + + // Create object with all columns set to false by default + const defaultVisibility = {}; + + // Set to true all columns that should be visible by default + TABLE_DEFAULTS.COLUMNS.DEFAULT_VISIBLE.forEach((col) => { + defaultVisibility[col] = true; + }); + + onColumnVisibilityChange?.(defaultVisibility); + setColumnVisibility(defaultVisibility); + }; + + const toggleColumn = (columnId) => { + if (TABLE_DEFAULTS.COLUMNS.FIXED.includes(columnId)) return; + + const newVisibility = { + ...columnVisibility, + [columnId]: !columnVisibility[columnId], + }; + + setColumnVisibility(newVisibility); + onColumnVisibilityChange?.(newVisibility); + }; + + return ( + + + + + Column Visibility + + + + + + + Reset + + + + + {Object.entries(TABLE_DEFAULTS.COLUMNS.COLUMN_GROUPS).map( + ([groupTitle, columns]) => { + // Calculer le nombre de colonnes cochées pour les évaluations + const isEvalGroup = groupTitle === "Evaluation Scores"; + const filteredColumns = columns.filter((col) => + filterableColumns.includes(col) + ); + const checkedCount = isEvalGroup + ? filteredColumns.filter((col) => columnVisibility[col]).length + : undefined; + const totalCount = isEvalGroup ? 
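+ // the checked/total counter is only rendered for the Evaluation Scores group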
filteredColumns.length : undefined; + + return ( + + {filteredColumns.map((columnName) => { + const isFixed = + TABLE_DEFAULTS.COLUMNS.FIXED.includes(columnName); + return ( + toggleColumn(columnName)} + disabled={isFixed} + variant="tag" + /> + ); + })} + + ); + } + )} + + ); +}; + +export default ColumnSelector; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/DisplayOptions/DisplayOptions.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/DisplayOptions/DisplayOptions.js new file mode 100644 index 0000000000000000000000000000000000000000..8ec6c2bf0b68a6f2372d867a5a6487128956fb4c --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/DisplayOptions/DisplayOptions.js @@ -0,0 +1,238 @@ +import React from "react"; +import { Box, Typography } from "@mui/material"; +import TuneIcon from "@mui/icons-material/Tune"; +import CloseIcon from "@mui/icons-material/Close"; +import RestartAltIcon from "@mui/icons-material/RestartAlt"; +import FilterTag from "../../../../../../components/shared/FilterTag"; +import { + TABLE_DEFAULTS, + ROW_SIZES, + SCORE_DISPLAY_OPTIONS, + RANKING_MODE_OPTIONS, +} from "../../constants/defaults"; +import { UI_TOOLTIPS } from "../../constants/tooltips"; +import DropdownButton from "../shared/DropdownButton"; +import InfoIconWithTooltip from "../../../../../../components/shared/InfoIconWithTooltip"; + +const TableOptions = ({ + rowSize, + onRowSizeChange, + scoreDisplay = "normalized", + onScoreDisplayChange, + averageMode = "all", + onAverageModeChange, + rankingMode = "static", + onRankingModeChange, + hasChanges, + searchParams, + setSearchParams, + loading = false, +}) => { + const handleReset = () => { + onRowSizeChange(TABLE_DEFAULTS.ROW_SIZE); + onScoreDisplayChange(TABLE_DEFAULTS.SCORE_DISPLAY); + onAverageModeChange(TABLE_DEFAULTS.AVERAGE_MODE); + onRankingModeChange(TABLE_DEFAULTS.RANKING_MODE); + + const newParams = new URLSearchParams(searchParams); + ["rowSize", "scoreDisplay", "averageMode", "rankingMode"].forEach( + (param) => { + newParams.delete(param); + } + ); + setSearchParams(newParams); + }; + + return ( + + + + + Table Options + + + + + + + Reset + + + + + + + + + + {UI_TOOLTIPS.ROW_SIZE.title} + + + + + {Object.keys(ROW_SIZES).map((size) => ( + onRowSizeChange(size)} + variant="tag" + /> + ))} + + + + + + + {UI_TOOLTIPS.SCORE_DISPLAY.title} + + + + + {SCORE_DISPLAY_OPTIONS.map(({ value, label }) => ( + onScoreDisplayChange(value)} + variant="tag" + /> + ))} + + + + + + + {UI_TOOLTIPS.RANKING_MODE.title} + + + + + {RANKING_MODE_OPTIONS.map(({ value, label }) => ( + onRankingModeChange(value)} + variant="tag" + /> + ))} + + + + + + + {UI_TOOLTIPS.AVERAGE_SCORE.title} + + + + + onAverageModeChange("all")} + variant="tag" + /> + onAverageModeChange("visible")} + variant="tag" + /> + + + + + + ); +}; + +export default TableOptions; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/FilteredModelCount.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/FilteredModelCount.js new file mode 100644 index 0000000000000000000000000000000000000000..f35223166eb572d3d09527bd60129a006d85f7c8 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/FilteredModelCount.js @@ -0,0 +1,246 @@ +import React from "react"; +import { Box, Typography, Skeleton } from "@mui/material"; +import { useMemo } from "react"; +import { useLeaderboard } from 
"../../context/LeaderboardContext"; + +const useModelCount = ({ totalCount, filteredCount, data, table, loading }) => { + const { state } = useLeaderboard(); + const isOfficialProviderActive = state.filters.isOfficialProviderActive; + const { officialOnly: officialOnlyCounts } = state.filterCounts; + + return useMemo(() => { + if (loading) { + return { + displayCount: 0, + currentFilteredCount: 0, + totalPinnedCount: 0, + filteredPinnedCount: 0, + isOfficialProviderActive, + }; + } + const displayCount = isOfficialProviderActive + ? officialOnlyCounts.officialProviders + : totalCount; + + // Calculate total number of pinned models + const totalPinnedCount = + data?.filter((model) => model.isPinned)?.length || 0; + + // Get current filter criteria + const filterConfig = { + selectedPrecisions: state.filters.precisions, + selectedTypes: state.filters.types, + paramsRange: state.filters.paramsRange, + searchValue: state.filters.search, + selectedBooleanFilters: state.filters.booleanFilters, + isOfficialProviderActive: state.filters.isOfficialProviderActive, + }; + + // Check each pinned model if it would pass filters without its pinned status + const filteredPinnedCount = + data?.filter((model) => { + if (!model.isPinned) return false; + + // Check each filter criteria + + // Filter by official providers + if (filterConfig.isOfficialProviderActive) { + if ( + !model.features?.is_official_provider && + !model.metadata?.is_official_provider + ) { + return false; + } + } + + // Filter by precision + if (filterConfig.selectedPrecisions.length > 0) { + if ( + !filterConfig.selectedPrecisions.includes(model.model.precision) + ) { + return false; + } + } + + // Filter by type + if (filterConfig.selectedTypes.length > 0) { + const modelType = model.model.type?.toLowerCase().trim(); + if ( + !filterConfig.selectedTypes.some((type) => + modelType?.includes(type) + ) + ) { + return false; + } + } + + // Filter by parameters + const params = model.metadata.params_billions; + if ( + params < filterConfig.paramsRange[0] || + params >= filterConfig.paramsRange[1] + ) { + return false; + } + + // Filter by search + if (filterConfig.searchValue) { + const searchLower = filterConfig.searchValue.toLowerCase(); + const modelName = model.model.name.toLowerCase(); + if (!modelName.includes(searchLower)) { + return false; + } + } + + // Filter by boolean flags + if (filterConfig.selectedBooleanFilters.length > 0) { + if ( + !filterConfig.selectedBooleanFilters.every((filter) => { + const filterValue = + typeof filter === "object" ? 
filter.value : filter; + + // Maintainer's Highlight keeps positive logic + if (filterValue === "is_official_provider") { + return model.features[filterValue]; + } + + // For all other filters, invert the logic + if (filterValue === "is_not_available_on_hub") { + return model.features[filterValue]; + } + + return !model.features[filterValue]; + }) + ) { + return false; + } + } + + // If we get here, the model passes all filters + return true; + })?.length || 0; + + return { + displayCount, + currentFilteredCount: filteredCount, + totalPinnedCount, + filteredPinnedCount, + isOfficialProviderActive, + }; + }, [ + loading, + totalCount, + filteredCount, + data, + state.filters, + isOfficialProviderActive, + officialOnlyCounts.officialProviders, + ]); +}; + +const CountTypography = ({ + value, + color = "text.primary", + loading = false, + pinnedCount = 0, + filteredPinnedCount = 0, + showPinned = false, +}) => { + if (loading) { + return ( + + ); + } + + return ( + + + {value} + + {showPinned && pinnedCount > 0 && ( + + {`+${pinnedCount}`} + + )} + + ); +}; + +const FilteredModelCount = React.memo( + ({ + totalCount = 0, + filteredCount = 0, + hasFilterChanges = false, + loading = false, + data = [], + table = null, + }) => { + const { + displayCount, + currentFilteredCount, + totalPinnedCount, + filteredPinnedCount, + isOfficialProviderActive, + } = useModelCount({ + totalCount, + filteredCount, + data, + table, + loading, + }); + + const shouldHighlight = + !loading && hasFilterChanges && currentFilteredCount !== displayCount; + + // Always show pinned models when they exist + const pinnedToShow = totalPinnedCount; + + return ( + + 0} + /> + + + + ); + } +); + +FilteredModelCount.displayName = "FilteredModelCount"; + +export default FilteredModelCount; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/Filters.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/Filters.js new file mode 100644 index 0000000000000000000000000000000000000000..1fa0572d69fee9212d4bcd01058fc7acdb4d1de2 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/Filters.js @@ -0,0 +1,850 @@ +import React, { + useState, + useEffect, + useMemo, + useRef, + forwardRef, + useCallback, +} from "react"; +import { + Box, + Typography, + Collapse, + Slider, + Grid, + Accordion, + AccordionDetails, + alpha, + useTheme, + TextField, +} from "@mui/material"; +import { + TABLE_DEFAULTS, + BOOLEAN_FILTER_OPTIONS, + FILTER_PRECISIONS, +} from "../../constants/defaults"; +import FilterTag from "../../../../../../components/shared/FilterTag"; +import { MODEL_TYPE_ORDER, MODEL_TYPES } from "../../constants/modelTypes"; +import { useLeaderboard } from "../../context/LeaderboardContext"; +import InfoIconWithTooltip from "../../../../../../components/shared/InfoIconWithTooltip"; +import { COLUMN_TOOLTIPS } from "../../constants/tooltips"; + +const getTooltipContent = (title) => { + switch (title) { + case "Model Type": + return COLUMN_TOOLTIPS.ARCHITECTURE; + case "Precision format": + return COLUMN_TOOLTIPS.PRECISION; + case "Flags": + return COLUMN_TOOLTIPS.FLAGS; + case "Parameters": + return COLUMN_TOOLTIPS.PARAMETERS; + default: + return null; + } +}; + +const FilterGroup = ({ + title, + tooltip, + children, + paramsRange, + onParamsRangeChange, +}) => { + const theme = useTheme(); + const [localParamsRange, setLocalParamsRange] = useState(paramsRange); + const stableTimerRef = useRef(null); + + // Handle local range change 
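+ // (slider values stay local while dragging; the debounced effect below commits them to the parent after 300 ms)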
+ const handleLocalRangeChange = useCallback((event, newValue) => { + setLocalParamsRange(newValue); + }, []); + + // Handle input change + const handleInputChange = useCallback( + (index) => (event) => { + const value = event.target.value === "" ? "" : Number(event.target.value); + if (value === "" || (value >= -1 && value <= 140)) { + const newRange = [...localParamsRange]; + newRange[index] = value; + setLocalParamsRange(newRange); + } + }, + [localParamsRange] + ); + + // Sync local state with props + useEffect(() => { + setLocalParamsRange(paramsRange); + }, [paramsRange]); + + // Propagate changes to parent after delay + useEffect(() => { + if (stableTimerRef.current) { + clearTimeout(stableTimerRef.current); + } + + stableTimerRef.current = setTimeout(() => { + if (Array.isArray(localParamsRange) && localParamsRange.length === 2) { + onParamsRangeChange(localParamsRange); + } + }, 300); + + return () => { + if (stableTimerRef.current) { + clearTimeout(stableTimerRef.current); + } + }; + }, [localParamsRange, onParamsRangeChange]); + + const renderContent = () => { + if (title === "Parameters") { + return ( + + + + + + + + (value === -1 ? "All" : `${value}B`)} + sx={{ + "& .MuiSlider-rail": { + height: 10, + backgroundColor: "background.paper", + border: "1px solid", + borderColor: "divider", + opacity: 1, + }, + "& .MuiSlider-track": { + height: 10, + border: "1px solid", + borderColor: (theme) => + alpha( + theme.palette.primary.main, + theme.palette.mode === "light" ? 0.3 : 0.5 + ), + backgroundColor: (theme) => + alpha( + theme.palette.primary.main, + theme.palette.mode === "light" ? 0.1 : 0.2 + ), + }, + "& .MuiSlider-thumb": { + width: 20, + height: 20, + backgroundColor: "background.paper", + border: "1px solid", + borderColor: "primary.main", + "&:hover, &.Mui-focusVisible": { + boxShadow: (theme) => + `0 0 0 8px ${alpha( + theme.palette.primary.main, + theme.palette.mode === "light" ? 0.08 : 0.16 + )}`, + }, + "&.Mui-active": { + boxShadow: (theme) => + `0 0 0 12px ${alpha( + theme.palette.primary.main, + theme.palette.mode === "light" ? 0.08 : 0.16 + )}`, + }, + }, + "& .MuiSlider-valueLabel": { + backgroundColor: theme.palette.primary.main, + }, + "& .MuiSlider-mark": { + width: 2, + height: 10, + backgroundColor: "divider", + }, + "& .MuiSlider-markLabel": { + fontSize: "0.875rem", + "&::after": { + content: '"B"', + marginLeft: "1px", + opacity: 0.5, + }, + '&[data-index="0"]::after': { + content: '""', + }, + }, + }} + /> + + ); + } + return ( + + {children} + + ); + }; + + return ( + + + + {title} + + + + {renderContent()} + + ); +}; + +const CustomCollapse = forwardRef((props, ref) => { + const { children, style = {}, ...other } = props; + const collapsedHeight = "0px"; + const timeout = 300; + + const wrapperRef = useRef(null); + const [animatedHeight, setAnimatedHeight] = useState( + props.in ? "auto" : collapsedHeight + ); + + useEffect(() => { + if (!wrapperRef.current) return; + + if (props.in) { + const contentHeight = wrapperRef.current.scrollHeight; + setAnimatedHeight(`${contentHeight}px`); + } else { + setAnimatedHeight(collapsedHeight); + } + }, [props.in, children]); + + const handleEntered = (node) => { + setAnimatedHeight("auto"); + if (props.onEntered) { + props.onEntered(node); + } + }; + + return ( + +
{children}
+
+ ); +}); + +const LeaderboardFilters = ({ + selectedPrecisions = FILTER_PRECISIONS, + onPrecisionsChange = () => {}, + selectedTypes = MODEL_TYPE_ORDER, + onTypesChange = () => {}, + paramsRange = [-1, 140], + onParamsRangeChange = () => {}, + selectedBooleanFilters = [], + onBooleanFiltersChange = () => {}, + data = [], + expanded, + onToggleExpanded, + loading = false, +}) => { + const [localParamsRange, setLocalParamsRange] = useState(paramsRange); + const stableTimerRef = useRef(null); + const { state, actions } = useLeaderboard(); + const { normal: filterCounts, officialOnly: officialOnlyCounts } = + state.filterCounts; + const isOfficialProviderActive = state.filters.isOfficialProviderActive; + const currentCounts = useMemo( + () => (isOfficialProviderActive ? officialOnlyCounts : filterCounts), + [isOfficialProviderActive, officialOnlyCounts, filterCounts] + ); + + useEffect(() => { + setLocalParamsRange(paramsRange); + }, [paramsRange]); + + // Clean up timer when component unmounts + useEffect(() => { + return () => { + if (stableTimerRef.current) { + clearTimeout(stableTimerRef.current); + } + }; + }, []); + + const handleParamsRangeChange = (event, newValue) => { + setLocalParamsRange(newValue); + }; + + const handleParamsRangeChangeCommitted = (event, newValue) => { + // Reset timer on each change + if (stableTimerRef.current) { + clearTimeout(stableTimerRef.current); + } + + // Update URL immediately + onParamsRangeChange(newValue); + + // Trigger data update after debounce + stableTimerRef.current = setTimeout(() => { + actions.updateFilteredData(); + }, TABLE_DEFAULTS.DEBOUNCE.SEARCH); + }; + + const handlePrecisionToggle = (precision) => { + const newPrecisions = selectedPrecisions.includes(precision) + ? selectedPrecisions.filter((p) => p !== precision) + : [...selectedPrecisions, precision]; + onPrecisionsChange(newPrecisions); + }; + + const handleBooleanFilterToggle = (filter) => { + const newFilters = selectedBooleanFilters.includes(filter) + ? selectedBooleanFilters.filter((f) => f !== filter) + : [...selectedBooleanFilters, filter]; + onBooleanFiltersChange(newFilters); + }; + + // Filter options based on their hide property + const showFilterOptions = BOOLEAN_FILTER_OPTIONS.filter( + (option) => !option.hide + ); + const hideFilterOptions = BOOLEAN_FILTER_OPTIONS.filter( + (option) => option.hide + ); + + const handleOfficialProviderToggle = () => { + actions.toggleOfficialProvider(); + }; + + return loading ? null : ( + + + + + + + + alpha(theme.palette.primary.main, 0.02), + border: "1px solid", + borderColor: (theme) => + alpha(theme.palette.primary.main, 0.2), + borderRadius: 1, + p: 3, + position: "relative", + width: "100%", + display: "flex", + flexDirection: "column", + "&:hover": { + borderColor: (theme) => + alpha(theme.palette.primary.main, 0.3), + backgroundColor: (theme) => + alpha(theme.palette.primary.main, 0.03), + }, + transition: (theme) => + theme.transitions.create( + ["border-color", "background-color"], + { + duration: theme.transitions.duration.short, + } + ), + }} + > + + Advanced Filters + + + + + + + {FILTER_PRECISIONS.map((precision) => ( + + handlePrecisionToggle(precision) + } + count={currentCounts.precisions[precision]} + showCheckbox={true} + /> + ))} + + + + + + + + + + + alpha( + theme.palette.primary.main, + theme.palette.mode === "light" + ? 0.3 + : 0.5 + ), + backgroundColor: (theme) => + alpha( + theme.palette.primary.main, + theme.palette.mode === "light" + ? 
0.1 + : 0.2 + ), + }, + "& .MuiSlider-thumb": { + width: 20, + height: 20, + backgroundColor: "background.paper", + border: "1px solid", + borderColor: "primary.main", + "&:hover, &.Mui-focusVisible": { + boxShadow: (theme) => + `0 0 0 8px ${alpha( + theme.palette.primary.main, + theme.palette.mode === "light" + ? 0.08 + : 0.16 + )}`, + }, + "&.Mui-active": { + boxShadow: (theme) => + `0 0 0 12px ${alpha( + theme.palette.primary.main, + theme.palette.mode === "light" + ? 0.08 + : 0.16 + )}`, + }, + }, + "& .MuiSlider-mark": { + backgroundColor: "text.disabled", + height: 2, + width: 2, + borderRadius: "50%", + }, + "& .MuiSlider-markLabel": { + color: "text.secondary", + }, + }} + /> + + + + + + + {/* Deuxième ligne */} + + + + {MODEL_TYPE_ORDER.sort( + (a, b) => + MODEL_TYPES[a].order - MODEL_TYPES[b].order + ).map((type) => ( + { + const newTypes = selectedTypes.includes(type) + ? selectedTypes.filter((t) => t !== type) + : [...selectedTypes, type]; + onTypesChange(newTypes); + }} + count={currentCounts.modelTypes[type]} + variant="tag" + showCheckbox={true} + /> + ))} + + + + + + + + {hideFilterOptions.map((filter) => ( + { + const newFilters = + selectedBooleanFilters.includes( + filter.value + ) + ? selectedBooleanFilters.filter( + (f) => f !== filter.value + ) + : [ + ...selectedBooleanFilters, + filter.value, + ]; + onBooleanFiltersChange(newFilters); + }} + count={ + filter.value === "is_moe" + ? currentCounts.mixtureOfExperts + : filter.value === "is_flagged" + ? currentCounts.flagged + : filter.value === "is_merged" + ? currentCounts.merged + : filter.value === "is_not_available_on_hub" + ? currentCounts.notOnHub + : 0 + } + isHideFilter={false} + totalCount={data.length} + showCheckbox={true} + /> + ))} + + + + + + + + + + + alpha(theme.palette.secondary.main, 0.02), + border: "1px solid", + borderColor: (theme) => + alpha(theme.palette.secondary.main, 0.15), + borderRadius: 1, + p: 3, + position: "relative", + width: "100%", + display: "flex", + flexDirection: "column", + alignItems: "center", + justifyContent: "center", + textAlign: "center", + minHeight: "100%", + "&:hover": { + borderColor: (theme) => + alpha(theme.palette.secondary.main, 0.25), + backgroundColor: (theme) => + alpha(theme.palette.secondary.main, 0.03), + }, + transition: (theme) => + theme.transitions.create( + ["border-color", "background-color"], + { + duration: theme.transitions.duration.short, + } + ), + }} + > + + + Official Models + + + Show only models that are officially provided and + maintained by their original creators. + + + {showFilterOptions.map((filter) => ( + + handleBooleanFilterToggle(filter.value) + } + count={ + filter.value === "is_official_provider" + ? currentCounts.officialProviders + : 0 + } + showCheckbox={true} + variant="secondary" + /> + + + {( + filter.value === "is_official_provider" + ? isOfficialProviderActive + : selectedBooleanFilters.includes(filter.value) + ) + ? 
"Filter active" + : "Filter inactive"} + + + ))} + + + + + + + + + + ); +}; + +export default LeaderboardFilters; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/QuickFilters.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/QuickFilters.js new file mode 100644 index 0000000000000000000000000000000000000000..91d074c6375e8129eda09cea299b6aa36e26c3f9 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/QuickFilters.js @@ -0,0 +1,226 @@ +import React, { useCallback, useMemo } from "react"; +import { Box, Typography, Skeleton } from "@mui/material"; +import { alpha } from "@mui/material/styles"; +import { QUICK_FILTER_PRESETS } from "../../constants/quickFilters"; +import FilterTag from "../../../../../../components/shared/FilterTag"; +import { useLeaderboard } from "../../context/LeaderboardContext"; +import InfoIconWithTooltip from "../../../../../../components/shared/InfoIconWithTooltip"; +import { UI_TOOLTIPS } from "../../constants/tooltips"; + +const QuickFiltersTitle = ({ sx = {} }) => ( + + + Quick Filters + + + +); + +export const QuickFiltersSkeleton = () => ( + + ({ + xs: alpha(theme.palette.primary.main, 0.02), + lg: "transparent", + }), + borderColor: (theme) => ({ + xs: alpha(theme.palette.primary.main, 0.2), + lg: "transparent", + }), + border: "1px solid", + borderRadius: 1, + p: 3, + display: "flex", + flexDirection: { xs: "column", md: "column", lg: "row" }, + gap: 2, + mb: 2, + width: "100%", + }} + > + + + {[1, 2, 3, 4].map((i) => ( + + ))} + + + +); + +const QuickFilters = ({ totalCount = 0, loading = false }) => { + const { state, actions } = useLeaderboard(); + const { normal: filterCounts, officialOnly: officialOnlyCounts } = + state.filterCounts; + const isOfficialProviderActive = state.filters.isOfficialProviderActive; + const currentParams = state.filters.paramsRange; + + const currentCounts = useMemo( + () => (isOfficialProviderActive ? 
officialOnlyCounts : filterCounts), + [isOfficialProviderActive, officialOnlyCounts, filterCounts] + ); + + const modelSizePresets = useMemo( + () => + QUICK_FILTER_PRESETS.filter( + (preset) => preset.id !== "official_providers" + ), + [] + ); + + const officialProvidersPreset = useMemo( + () => + QUICK_FILTER_PRESETS.find((preset) => preset.id === "official_providers"), + [] + ); + + const handleSizePresetClick = useCallback( + (preset) => { + const isActive = + currentParams[0] === preset.filters.paramsRange[0] && + currentParams[1] === preset.filters.paramsRange[1]; + + if (isActive) { + actions.setFilter("paramsRange", [-1, 140]); // Reset to default + } else { + actions.setFilter("paramsRange", preset.filters.paramsRange); + } + }, + [currentParams, actions] + ); + + const getPresetCount = useCallback( + (preset) => { + const range = preset.id.split("_")[0]; + return currentCounts.parameterRanges[range] || 0; + }, + [currentCounts] + ); + + const handleOfficialProviderToggle = useCallback(() => { + actions.toggleOfficialProvider(); + }, [actions]); + + if (loading) { + return ; + } + + return ( + + ({ + xs: alpha(theme.palette.primary.main, 0.02), + lg: "transparent", + }), + borderColor: (theme) => ({ + xs: alpha(theme.palette.primary.main, 0.2), + lg: "transparent", + }), + border: "1px solid", + borderRadius: 1, + p: 3, + display: "flex", + flexDirection: { xs: "column", lg: "row" }, + alignItems: "center", + gap: 2, + width: "100%", + }} + > + + + + div": { + width: { xs: "100%", md: 0, lg: "auto" }, + flex: { + xs: "auto", + md: "1 1 0", + lg: "0 0 auto", + }, + }, + }} + > + {modelSizePresets.map((preset) => ( + handleSizePresetClick(preset)} + count={getPresetCount(preset)} + totalCount={totalCount} + /> + ))} + + + + {officialProvidersPreset && ( + + )} + + + + ); +}; + +QuickFilters.displayName = "QuickFilters"; + +export default React.memo(QuickFilters); diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/SearchBar.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/SearchBar.js new file mode 100644 index 0000000000000000000000000000000000000000..c32cd8f8640b0d2e8fa7c1928f76fcd8d53fe494 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/SearchBar.js @@ -0,0 +1,329 @@ +import React, { useState, useEffect } from "react"; +import { Box, InputBase, Typography, Paper, Skeleton } from "@mui/material"; + +import SearchIcon from "@mui/icons-material/Search"; +import FilterListIcon from "@mui/icons-material/FilterList"; +import RestartAltIcon from "@mui/icons-material/RestartAlt"; +import { useTheme } from "@mui/material/styles"; +import { generateSearchDescription } from "../../utils/searchUtils"; +import { + HIGHLIGHT_COLORS, + TABLE_DEFAULTS, + FILTER_PRECISIONS, +} from "../../constants/defaults"; +import { MODEL_TYPE_ORDER } from "../../constants/modelTypes"; +import { alpha } from "@mui/material/styles"; +import FilteredModelCount from "./FilteredModelCount"; +import { useLeaderboard } from "../../context/LeaderboardContext"; +import InfoIconWithTooltip from "../../../../../../components/shared/InfoIconWithTooltip"; +import { UI_TOOLTIPS } from "../../constants/tooltips"; + +export const SearchBarSkeleton = () => ( + + alpha(theme.palette.background.paper, 0.8), + borderRadius: 1, + border: (theme) => + `1px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 
0.05 : 0.1 + )}`, + display: "flex", + alignItems: "center", + px: 2, + gap: 2, + }} + > + + + + + + + + + + + Supports strict search and regex • Use semicolons for multiple terms + + + +); + +const SearchDescription = ({ searchValue }) => { + const searchGroups = generateSearchDescription(searchValue); + + if (!searchGroups || searchGroups.length === 0) return null; + + return ( + + + Showing models matching: + + {searchGroups.map(({ text, index }, i) => ( + + {i > 0 && ( + + and + + )} + + theme.palette.getContrastText( + HIGHLIGHT_COLORS[index % HIGHLIGHT_COLORS.length] + ), + padding: "2px 4px", + borderRadius: "4px", + fontSize: "0.85rem", + fontWeight: 500, + }} + > + {text} + + + ))} + + ); +}; + +const SearchBar = ({ + onToggleFilters, + filtersOpen, + loading = false, + data = [], + table = null, +}) => { + const theme = useTheme(); + const { state, actions } = useLeaderboard(); + const [localValue, setLocalValue] = useState(state.filters.search); + + useEffect(() => { + setLocalValue(state.filters.search); + }, [state.filters.search]); + + useEffect(() => { + const timer = setTimeout(() => { + if (localValue !== state.filters.search) { + actions.setFilter("search", localValue); + } + }, TABLE_DEFAULTS.DEBOUNCE.SEARCH); + + return () => clearTimeout(timer); + }, [localValue, state.filters.search, actions]); + + const handleLocalChange = (e) => { + setLocalValue(e.target.value); + }; + + const hasActiveFilters = + Object.values(state.filters.booleanFilters).some((value) => value) || + state.filters.precisions.length !== FILTER_PRECISIONS.length || + state.filters.types.length !== MODEL_TYPE_ORDER.length || + state.filters.paramsRange[0] !== -1 || + state.filters.paramsRange[1] !== 140 || + state.filters.isOfficialProviderActive; + + const shouldShowReset = localValue || hasActiveFilters; + + return ( + + + + + {!loading && ( + + )} + + {shouldShowReset && ( + { + setLocalValue(""); + actions.resetFilters(); + }} + sx={{ + display: "flex", + alignItems: "center", + gap: 0.5, + cursor: "pointer", + color: "text.secondary", + backgroundColor: "transparent", + border: "1px solid", + borderColor: "divider", + borderRadius: 1, + padding: "4px 8px", + "&:hover": { + backgroundColor: "action.hover", + color: "text.primary", + }, + userSelect: "none", + transition: "all 0.2s ease", + }} + > + + + Reset + + + )} + + + + Advanced Filters + + + + + + + {localValue ? 
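+ // with an active query, describe the parsed search groups; otherwise show the syntax hint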
( + + ) : ( + + + Supports strict search and regex • Use semicolons for multiple + terms + + + )} + + + ); +}; + +export default SearchBar; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/hooks/useOfficialProvidersMode.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/hooks/useOfficialProvidersMode.js new file mode 100644 index 0000000000000000000000000000000000000000..729129cb3081bb525bcae2fc707f70658f74e778 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/hooks/useOfficialProvidersMode.js @@ -0,0 +1,130 @@ +import { useCallback, useState, useEffect, useRef } from "react"; +import { useSearchParams } from "react-router-dom"; + +const useRouterSearchParams = () => { + try { + return useSearchParams(); + } catch { + return [null, () => {}]; + } +}; + +export const useOfficialProvidersMode = () => { + const [isOfficialProviderActive, setIsOfficialProviderActive] = + useState(false); + const [searchParams, setSearchParams] = useRouterSearchParams(); + const normalFiltersRef = useRef(null); + const isInitialLoadRef = useRef(true); + const lastToggleSourceRef = useRef(null); + + // Effect to handle initial state and updates + useEffect(() => { + if (!searchParams) return; + + const filters = searchParams.get("filters"); + const isHighlighted = + filters?.includes("is_official_provider") || false; + + // On initial load + if (isInitialLoadRef.current) { + isInitialLoadRef.current = false; + + // If official mode is active at start, store filters without the highlightFilter + if (isHighlighted && filters) { + const initialNormalFilters = filters + .split(",") + .filter((f) => f !== "is_official_provider" && f !== "") + .filter(Boolean); + if (initialNormalFilters.length > 0) { + normalFiltersRef.current = initialNormalFilters.join(","); + } + } + + // Update state without triggering URL change + setIsOfficialProviderActive(isHighlighted); + return; + } + + // For subsequent changes + if (!isHighlighted && filters) { + normalFiltersRef.current = filters; + } + + setIsOfficialProviderActive(isHighlighted); + }, [searchParams]); + + const toggleOfficialProviderMode = useCallback( + (source = null) => { + if (!searchParams || !setSearchParams) return; + + // If source is the same as last time and last change was less than 100ms ago, ignore + const now = Date.now(); + if ( + source && + source === lastToggleSourceRef.current?.source && + now - (lastToggleSourceRef.current?.timestamp || 0) < 100 + ) { + return; + } + + const currentFiltersStr = searchParams.get("filters"); + const currentFilters = + currentFiltersStr?.split(",").filter(Boolean) || []; + const highlightFilter = "is_official_provider"; + const newSearchParams = new URLSearchParams(searchParams); + + if (currentFilters.includes(highlightFilter)) { + // Deactivating official provider mode + if (normalFiltersRef.current) { + const normalFilters = normalFiltersRef.current + .split(",") + .filter((f) => f !== highlightFilter && f !== "") + .filter(Boolean); + + if (normalFilters.length > 0) { + newSearchParams.set("filters", normalFilters.join(",")); + } else { + newSearchParams.delete("filters"); + } + } else { + const newFilters = currentFilters.filter( + (f) => f !== highlightFilter && f !== "" + ); + if (newFilters.length === 0) { + newSearchParams.delete("filters"); + } else { + newSearchParams.set("filters", newFilters.join(",")); + } + } + } else { + // Activating official provider mode + if (currentFiltersStr) { + 
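+ // stash the plain filters so they can be restored when official mode is toggled off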
normalFiltersRef.current = currentFiltersStr; + } + + const filtersToSet = [ + ...new Set([...currentFilters, highlightFilter]), + ].filter(Boolean); + newSearchParams.set("filters", filtersToSet.join(",")); + } + + // Update state immediately + setIsOfficialProviderActive(!currentFilters.includes(highlightFilter)); + + // Save source and timestamp of last change + lastToggleSourceRef.current = { + source, + timestamp: now, + }; + + // Update search params and let HashRouter handle the URL + setSearchParams(newSearchParams); + }, + [searchParams, setSearchParams] + ); + + return { + isOfficialProviderActive, + toggleOfficialProviderMode, + }; +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/hooks/usePresets.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/hooks/usePresets.js new file mode 100644 index 0000000000000000000000000000000000000000..35e17e54b0e1978635440908d3de6c742b37a856 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/hooks/usePresets.js @@ -0,0 +1,98 @@ +import { useCallback } from "react"; +import { QUICK_FILTER_PRESETS } from "../../../constants/quickFilters"; +import { TABLE_DEFAULTS } from "../../../constants/defaults"; + +const DEFAULT_FILTERS = { + searchValue: "", + selectedPrecisions: TABLE_DEFAULTS.SEARCH.PRECISIONS, + selectedTypes: TABLE_DEFAULTS.SEARCH.TYPES, + paramsRange: TABLE_DEFAULTS.SEARCH.PARAMS_RANGE, + selectedBooleanFilters: [], +}; + +export const usePresets = (searchFilters) => { + const handlePresetChange = useCallback( + (preset) => { + if (!searchFilters?.batchUpdateState) return; + + if (preset === null) { + // Reset with default values + searchFilters.batchUpdateState(DEFAULT_FILTERS, true); + return; + } + + // Apply preset with default values as base + const updates = { + ...DEFAULT_FILTERS, + ...preset.filters, + }; + + // Apply all changes at once + searchFilters.batchUpdateState(updates, true); + }, + [searchFilters] + ); + + const resetPreset = useCallback(() => { + handlePresetChange(null); + }, [handlePresetChange]); + + const getActivePreset = useCallback(() => { + // If searchFilters is not initialized yet, return null + if (!searchFilters) return null; + + // Dynamic detection of preset matching current filters + const currentParamsRange = Array.isArray(searchFilters.paramsRange) + ? searchFilters.paramsRange + : DEFAULT_FILTERS.paramsRange; + const currentBooleanFilters = Array.isArray( + searchFilters.selectedBooleanFilters + ) + ? searchFilters.selectedBooleanFilters + : DEFAULT_FILTERS.selectedBooleanFilters; + const currentPrecisions = Array.isArray(searchFilters.selectedPrecisions) + ? searchFilters.selectedPrecisions + : DEFAULT_FILTERS.selectedPrecisions; + const currentTypes = Array.isArray(searchFilters.selectedTypes) + ? searchFilters.selectedTypes + : DEFAULT_FILTERS.selectedTypes; + + return ( + QUICK_FILTER_PRESETS.find((preset) => { + const presetParamsRange = Array.isArray(preset.filters.paramsRange) + ? preset.filters.paramsRange + : DEFAULT_FILTERS.paramsRange; + const presetBooleanFilters = Array.isArray( + preset.filters.selectedBooleanFilters + ) + ? 
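+ // presets may omit selectedBooleanFilters; fall back to the defaults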
preset.filters.selectedBooleanFilters + : DEFAULT_FILTERS.selectedBooleanFilters; + + const paramsMatch = + JSON.stringify(presetParamsRange) === + JSON.stringify(currentParamsRange); + const booleanFiltersMatch = + JSON.stringify(presetBooleanFilters.sort()) === + JSON.stringify(currentBooleanFilters.sort()); + + // Check if other filters match default values + const precisionMatch = + JSON.stringify(currentPrecisions.sort()) === + JSON.stringify(DEFAULT_FILTERS.selectedPrecisions.sort()); + const typesMatch = + JSON.stringify(currentTypes.sort()) === + JSON.stringify(DEFAULT_FILTERS.selectedTypes.sort()); + + return ( + paramsMatch && booleanFiltersMatch && precisionMatch && typesMatch + ); + })?.id || null + ); + }, [searchFilters]); + + return { + activePreset: getActivePreset(), + handlePresetChange, + resetPreset, + }; +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/PerformanceMonitor.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/PerformanceMonitor.js new file mode 100644 index 0000000000000000000000000000000000000000..d3a20d28639f0d84835d854fe405795e14499d01 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/PerformanceMonitor.js @@ -0,0 +1,570 @@ +import React, { useEffect, useState, useRef } from "react"; +import { Box, Typography, Tooltip, useTheme } from "@mui/material"; +import NetworkCheckIcon from "@mui/icons-material/NetworkCheck"; +import MemoryIcon from "@mui/icons-material/Memory"; +import SpeedIcon from "@mui/icons-material/Speed"; +import GpuIcon from "@mui/icons-material/Memory"; +import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined"; + +const getGPUStats = () => { + try { + const canvas = document.createElement("canvas"); + const gl = + canvas.getContext("webgl") || canvas.getContext("experimental-webgl"); + + if (!gl) { + canvas.remove(); + return null; + } + + // Try to get GPU info extensions + const debugInfo = gl.getExtension("WEBGL_debug_renderer_info"); + + // Estimate GPU memory usage (very approximate) + let usedMemoryEstimate = 0; + + try { + // Create test texture + const testTexture = gl.createTexture(); + gl.bindTexture(gl.TEXTURE_2D, testTexture); + + // Test size: 1024x1024 RGBA + const testSize = 1024; + const pixels = new Uint8Array(testSize * testSize * 4); + gl.texImage2D( + gl.TEXTURE_2D, + 0, + gl.RGBA, + testSize, + testSize, + 0, + gl.RGBA, + gl.UNSIGNED_BYTE, + pixels + ); + + // Estimate memory usage (very approximate) + usedMemoryEstimate = (testSize * testSize * 4) / (1024 * 1024); // In MB + + gl.deleteTexture(testTexture); + gl.getExtension("WEBGL_lose_context")?.loseContext(); + } catch (e) { + console.warn("GPU memory estimation failed:", e); + } finally { + // Cleanup WebGL resources + const loseContext = gl.getExtension("WEBGL_lose_context"); + if (loseContext) loseContext.loseContext(); + gl.canvas.remove(); + } + + return { + vendor: debugInfo + ? gl.getParameter(debugInfo.UNMASKED_VENDOR_WEBGL) + : "Unknown", + renderer: debugInfo + ? gl.getParameter(debugInfo.UNMASKED_RENDERER_WEBGL) + : "Unknown", + usedMemory: Math.round(usedMemoryEstimate), + }; + } catch (e) { + return null; + } +}; + +const MetricBox = ({ icon, label, value, tooltip }) => { + const theme = useTheme(); + return ( + + {icon} + + + {label} + + + {React.isValidElement(value) ? 
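+ // accept a ready-made element as-is; raw values get the default wrapper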
value : {value}} + + {tooltip && ( + + + + + + )} + + ); +}; + +const formatNumber = (num) => { + return num.toString().replace(/\B(?=(\d{3})+(?!\d))/g, " "); +}; + +const PerformanceMonitor = () => { + const theme = useTheme(); + + const [stats, setStats] = useState({ + fps: 0, + memory: { + usedJSHeapSize: 0, + totalJSHeapSize: 0, + }, + renders: 0, + network: { + transferSize: 0, + decodedBodySize: 0, + compressionRatio: 0, + }, + gpu: getGPUStats(), + fcp: null, + }); + const [isVisible, setIsVisible] = useState( + process.env.NODE_ENV === "development" + ); + const renderCountRef = useRef(0); + const originalCreateElementRef = useRef(null); + + useEffect(() => { + const handleKeyDown = (event) => { + // Ignore if user is in an input field + if ( + event.target.tagName === "INPUT" || + event.target.tagName === "TEXTAREA" + ) { + return; + } + + if (event.key === "p" || event.key === "P") { + setIsVisible((prev) => !prev); + } + }; + + window.addEventListener("keydown", handleKeyDown); + return () => window.removeEventListener("keydown", handleKeyDown); + }, []); + + useEffect(() => { + let frameCount = 0; + let lastTime = performance.now(); + let animationFrameId; + + const getNetworkStats = () => { + const resources = performance.getEntriesByType("resource"); + const navigation = performance.getEntriesByType("navigation")[0]; + + let totalTransferSize = navigation ? navigation.transferSize : 0; + let totalDecodedSize = navigation ? navigation.decodedBodySize : 0; + + resources.forEach((resource) => { + totalTransferSize += resource.transferSize || 0; + totalDecodedSize += resource.decodedBodySize || 0; + }); + + const compressionRatio = totalDecodedSize + ? Math.round((1 - totalTransferSize / totalDecodedSize) * 100) + : 0; + + return { + transferSize: Math.round(totalTransferSize / 1024), + decodedBodySize: Math.round(totalDecodedSize / 1024), + compressionRatio, + }; + }; + + // Save original function + originalCreateElementRef.current = React.createElement; + + // Replace createElement + React.createElement = function (...args) { + renderCountRef.current++; + return originalCreateElementRef.current.apply(this, args); + }; + + const updateStats = () => { + frameCount++; + const now = performance.now(); + const delta = now - lastTime; + + if (delta >= 1000) { + const fps = Math.round((frameCount * 1000) / delta); + + const memory = window.performance?.memory + ? 
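+          // window.performance.memory is a non-standard, Chromium-only API
+          // (hence the null fallback); values are in bytes, so dividing by
+          // 1048576 (1024 * 1024) converts them to MB. The FPS above is
+          // frameCount * 1000 / delta, e.g. 60 frames over 1000 ms -> 60 FPS.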
{ + usedJSHeapSize: Math.round( + window.performance.memory.usedJSHeapSize / 1048576 + ), + totalJSHeapSize: Math.round( + window.performance.memory.totalJSHeapSize / 1048576 + ), + } + : null; + + const network = getNetworkStats(); + const gpu = getGPUStats(); + + setStats((prev) => ({ + ...prev, + fps, + memory: memory || prev.memory, + renders: renderCountRef.current, + network, + gpu, + })); + + frameCount = 0; + lastTime = now; + } + + animationFrameId = requestAnimationFrame(updateStats); + }; + + updateStats(); + + return () => { + cancelAnimationFrame(animationFrameId); + // Restore original function + if (originalCreateElementRef.current) { + React.createElement = originalCreateElementRef.current; + } + // Clean up counters + renderCountRef.current = 0; + delete window.__REACT_RENDERS__; + }; + }, []); + + useEffect(() => { + // Add FCP observer + if (window.PerformanceObserver) { + try { + const fcpObserver = new PerformanceObserver((entryList) => { + const entries = entryList.getEntries(); + if (entries.length > 0) { + const fcp = entries[0].startTime; + setStats((prev) => ({ + ...prev, + fcp, + })); + } + }); + + fcpObserver.observe({ entryTypes: ["paint"] }); + return () => fcpObserver.disconnect(); + } catch (e) { + console.warn("FCP observation failed:", e); + } + } + }, []); + + const getFpsColor = (fps) => { + if (fps >= 55) return "#4CAF50"; + if (fps >= 30) return "#FFC107"; + return "#F44336"; + }; + + return isVisible ? ( + + + + Performances{" "} + dev only + + + {/* Performance Metrics */} + + + } + label="FPS" + value={ + + {stats.fps} + + } + tooltip="Frames Per Second - Indicates how smooth the UI is running" + /> + + {stats.fcp !== null && ( + + } + label="FCP" + value={ + + {Math.round(stats.fcp)}ms + + } + tooltip="First Contentful Paint - Time until first content is rendered" + /> + )} + + ⚛️ + + } + label="React" + value={ + + {formatNumber(stats.renders)} + cycles + + } + tooltip="Total number of React render cycles" + /> + + + {/* Memory Metrics */} + + {window.performance?.memory && ( + } + label="Mem" + value={ + + {stats.memory.usedJSHeapSize} + / + {stats.memory.totalJSHeapSize} + MB + + } + tooltip="JavaScript heap memory usage (Used / Total)" + /> + )} + {stats.gpu && ( + } + label="GPU" + value={ + + {stats.gpu.usedMemory} + MB + + } + tooltip="Estimated GPU memory usage" + /> + )} + + + {/* Network Metrics */} + + + } + label="Net" + value={ + + {stats.network.transferSize} + KB + + } + tooltip="Network data transferred" + /> + } + label="Size" + value={ + + {formatNumber(stats.network.decodedBodySize)} + KB + 0 ? "#81C784" : "inherit", + fontSize: "0.7rem", + opacity: 0.8, + ml: 1, + }} + > + (-{stats.network.compressionRatio}%) + + + } + tooltip="Total decoded size and compression ratio" + /> + + + Press "P" to show/hide + + + +
+ ) : null; +}; + +export default React.memo(PerformanceMonitor); diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/Table.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/Table.js new file mode 100644 index 0000000000000000000000000000000000000000..b9279247881135a2d4cf2122ed542474fc20f6be --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/Table.js @@ -0,0 +1,720 @@ +import React, { useRef, useCallback, useMemo } from "react"; +import { + Paper, + Table, + TableContainer, + TableHead, + TableBody, + TableRow, + TableCell, + Box, + Typography, + Skeleton, +} from "@mui/material"; +import { flexRender } from "@tanstack/react-table"; +import { useVirtualizer } from "@tanstack/react-virtual"; +import KeyboardArrowUpIcon from "@mui/icons-material/KeyboardArrowUp"; +import KeyboardArrowDownIcon from "@mui/icons-material/KeyboardArrowDown"; +import UnfoldMoreIcon from "@mui/icons-material/UnfoldMore"; +import SearchOffIcon from "@mui/icons-material/SearchOff"; +import { + TABLE_DEFAULTS, + ROW_SIZES, + SKELETON_COLUMNS, +} from "../../constants/defaults"; +import { alpha } from "@mui/material/styles"; +import TableOptions from "../DisplayOptions/DisplayOptions"; +import ColumnSelector from "../ColumnSelector/ColumnSelector"; + +const NoResultsFound = () => ( + + + + No models found + + + Try modifying your filters or search to see more models. + + +); + +const TableSkeleton = ({ rowSize = "normal" }) => { + const currentRowHeight = Math.floor(ROW_SIZES[rowSize]); + const headerHeight = Math.floor(currentRowHeight * 1.25); + const skeletonRows = 10; + + return ( + + `1px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.05 : 0.1 + )}`, + borderRadius: 1, + }} + > + + + + {SKELETON_COLUMNS.map((width, index) => ( + 3 ? "right" : "left", + borderRight: (theme) => `1px solid ${theme.palette.divider}`, + "&:last-child": { + borderRight: "none", + }, + position: "sticky", + top: 0, + backgroundColor: (theme) => theme.palette.background.paper, + zIndex: 2, + }} + /> + ))} + + + + {[...Array(skeletonRows)].map((_, index) => ( + + index % 2 === 0 ? "transparent" : theme.palette.action.hover, + }} + > + {SKELETON_COLUMNS.map((width, cellIndex) => ( + + `1px solid ${theme.palette.divider}`, + "&:last-child": { + borderRight: "none", + }, + }} + > + 3 ? "auto" : 0, + backgroundColor: (theme) => + alpha(theme.palette.text.primary, 0.11), + "&::after": { + background: (theme) => + `linear-gradient(90deg, ${alpha( + theme.palette.text.primary, + 0.11 + )}, ${alpha( + theme.palette.text.primary, + 0.14 + )}, ${alpha(theme.palette.text.primary, 0.11)})`, + }, + }} + /> + + ))} + + ))} + +
+
+ ); +}; + +const TableControls = React.memo( + ({ + loading, + rowSize, + onRowSizeChange, + scoreDisplay, + onScoreDisplayChange, + averageMode, + onAverageModeChange, + rankingMode, + onRankingModeChange, + hasTableOptionsChanges, + searchParams, + setSearchParams, + table, + handleColumnReset, + hasColumnFilterChanges, + onColumnVisibilityChange, + }) => ( + + + + + ) +); + +TableControls.displayName = "TableControls"; + +const LeaderboardTable = ({ + table, + rowSize = "normal", + loading = false, + hasTableOptionsChanges, + hasColumnFilterChanges, + onColumnVisibilityChange, + scoreDisplay, + onScoreDisplayChange, + averageMode, + onAverageModeChange, + rankingMode, + onRankingModeChange, + onRowSizeChange, + searchParams, + setSearchParams, + pinnedModels = [], +}) => { + const { rows } = table.getRowModel(); + const parentRef = useRef(null); + + const currentRowHeight = useMemo(() => ROW_SIZES[rowSize], [rowSize]); + const headerHeight = useMemo( + () => Math.floor(currentRowHeight * 1.25), + [currentRowHeight] + ); + + // Separate pinned rows from normal rows while preserving original order + const pinnedRows = useMemo(() => { + const pinnedModelRows = rows.filter((row) => row.original.isPinned); + // Sort pinned models according to their original order in pinnedModels + return pinnedModelRows.sort((a, b) => { + const aIndex = pinnedModels.indexOf(a.original.id); + const bIndex = pinnedModels.indexOf(b.original.id); + return aIndex - bIndex; + }); + }, [rows, pinnedModels]); + + const unpinnedRows = useMemo( + () => rows.filter((row) => !row.original.isPinned), + [rows] + ); + const pinnedHeight = useMemo( + () => pinnedRows.length * currentRowHeight, + [pinnedRows.length, currentRowHeight] + ); + + const virtualizerOptions = useMemo( + () => ({ + count: unpinnedRows.length, + getScrollElement: () => parentRef.current, + estimateSize: () => currentRowHeight, + overscan: 15, + scrollMode: "sync", + scrollPaddingStart: pinnedHeight, + scrollPaddingEnd: 0, + initialRect: { width: 0, height: currentRowHeight * 15 }, + }), + [currentRowHeight, unpinnedRows.length, pinnedHeight] + ); + + const rowVirtualizer = useVirtualizer(virtualizerOptions); + + const virtualRows = rowVirtualizer.getVirtualItems(); + + // Adjust paddings to account for pinned rows + const paddingTop = virtualRows.length > 0 ? virtualRows[0].start : 0; + const paddingBottom = + virtualRows.length > 0 + ? unpinnedRows.length * currentRowHeight - + virtualRows[virtualRows.length - 1].end + : 0; + + // Handle column reset + const handleColumnReset = useCallback(() => { + onColumnVisibilityChange(TABLE_DEFAULTS.COLUMNS.DEFAULT_VISIBLE); + }, [onColumnVisibilityChange]); + + const cellStyles = (theme) => ({ + borderRight: `1px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.05 : 0.1 + )}`, + "&:last-child": { + borderRight: "none", + }, + whiteSpace: "nowrap", + overflow: "hidden", + textOverflow: "ellipsis", + padding: "8px 16px", + }); + + const headerCellStyles = (theme) => ({ + ...cellStyles(theme), + padding: "6px 16px", + height: "36px", + position: "sticky !important", + top: 0, + zIndex: 10, + "& > .header-content": { + display: "flex", + alignItems: "center", + width: "100%", + gap: "4px", + flexDirection: "row", + }, + }); + + const getSortingIcon = (column) => { + if ( + column.id === "rank" || + column.id === "model_type" || + column.id === "isPinned" + ) { + return null; + } + + if (!column.getIsSorted()) { + return ; + } + return column.getIsSorted() === "desc" ? 
( + + ) : ( + + ); + }; + + const renderHeaderContent = (header) => { + const sortIcon = getSortingIcon(header.column); + return ( + + {flexRender(header.column.columnDef.header, header.getContext())} + + {sortIcon || } + + + ); + }; + + const renderRow = (row, isSticky = false, stickyIndex = 0) => { + // Get row index in the sorted data model + const sortedIndex = table + .getSortedRowModel() + .rows.findIndex((r) => r.id === row.id); + + return ( + ({ + height: `${currentRowHeight}px !important`, + backgroundColor: isSticky + ? theme.palette.background.paper + : (sortedIndex + 1) % 2 === 0 + ? "transparent" + : alpha(theme.palette.mode === "dark" ? "#fff" : "#000", 0.02), + position: isSticky ? "sticky" : "relative", + top: isSticky + ? `${headerHeight + stickyIndex * currentRowHeight}px` + : "auto", + zIndex: isSticky ? 2 : 1, + boxShadow: isSticky + ? `0 1px 1px ${alpha( + theme.palette.common.black, + theme.palette.mode === "dark" ? 0.1 : 0.05 + )}` + : "none", + "&::after": isSticky + ? { + content: '""', + position: "absolute", + left: 0, + right: 0, + height: "1px", + bottom: -1, + backgroundColor: alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + ), + zIndex: 1, + } + : {}, + })} + > + {row.getVisibleCells().map((cell) => ( + ({ + width: `${cell.column.columnDef.size}px !important`, + minWidth: `${cell.column.columnDef.size}px !important`, + height: `${currentRowHeight}px`, + backgroundColor: isSticky + ? theme.palette.background.paper + : "inherit", + borderBottom: isSticky + ? "none" + : `1px solid ${theme.palette.divider}`, + ...cellStyles(theme), + ...(cell.column.columnDef.meta?.cellStyle?.(cell.getValue()) || + {}), + "& .MuiBox-root": { + overflow: "visible", + }, + })} + > + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ))} + + ); + }; + + if (!loading && (!rows || rows.length === 0)) { + return ( + + + + + + + ); + } + + if (loading) { + return ( + + + + + + + ); + } + + return ( + + + + ({ + height: "100%", + overflow: "auto", + border: "none", + boxShadow: "none", + "&::-webkit-scrollbar": { + width: "8px", + height: "8px", + }, + "&::-webkit-scrollbar-thumb": { + backgroundColor: alpha( + theme.palette.common.black, + theme.palette.mode === "dark" ? 0.4 : 0.2 + ), + borderRadius: "4px", + }, + "&::-webkit-scrollbar-corner": { + backgroundColor: theme.palette.background.paper, + }, + willChange: "transform", + transform: "translateZ(0)", + WebkitOverflowScrolling: "touch", + scrollBehavior: "auto", + })} + > + 0 ? "fixed" : "fixed", + border: "none", + "& td, & th": + pinnedRows.length > 0 + ? { + width: `${100 / table.getAllColumns().length}%`, + } + : {}, + }} + > + + {table.getAllColumns().map((column, index) => ( + + ))} + + + theme.palette.background.paper, + "& th": { + backgroundColor: (theme) => theme.palette.background.paper, + }, + }} + > + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => ( + ({ + cursor: header.column.getCanSort() + ? "pointer" + : "default", + width: header.column.columnDef.size, + minWidth: header.column.columnDef.size, + ...headerCellStyles(theme), + textAlign: "left", + fontWeight: header.column.getIsSorted() ? 
700 : 400, + userSelect: "none", + height: `${headerHeight}px`, + padding: `${headerHeight * 0.25}px 16px`, + backgroundColor: theme.palette.background.paper, + })} + > + {renderHeaderContent(header)} + + ))} + + ))} + + + + {/* Pinned rows */} + {pinnedRows.map((row, index) => renderRow(row, true, index))} + + {/* Padding for virtualized rows */} + {paddingTop > 0 && ( + + + + )} + + {/* Virtualized unpinned rows */} + {virtualRows.map((virtualRow) => { + const row = unpinnedRows[virtualRow.index]; + if (!row) return null; + return renderRow(row); + })} + + {/* Bottom padding */} + {paddingBottom > 0 && ( + + + + )} + +
+
+
+
+ ); +}; + +export { TableSkeleton }; +export default LeaderboardTable; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/hooks/useDataProcessing.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/hooks/useDataProcessing.js new file mode 100644 index 0000000000000000000000000000000000000000..6f5463755578ae260d6639403706e5f6071eb614 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/hooks/useDataProcessing.js @@ -0,0 +1,161 @@ +import { useMemo } from "react"; +import { + useReactTable, + getSortedRowModel, + getCoreRowModel, + getFilteredRowModel, +} from "@tanstack/react-table"; +import { createColumns } from "../../../utils/columnUtils"; +import { + useAverageRange, + useColorGenerator, + useProcessedData, + useFilteredData, + useColumnVisibility, +} from "../../../hooks/useDataUtils"; + +export const useDataProcessing = ( + data, + searchValue, + selectedPrecisions, + selectedTypes, + paramsRange, + selectedBooleanFilters, + sorting, + rankingMode, + averageMode, + visibleColumns, + scoreDisplay, + pinnedModels, + onTogglePin, + setSorting, + isOfficialProviderActive +) => { + // Call hooks directly at root level + const { minAverage, maxAverage } = useAverageRange(data); + const getColorForValue = useColorGenerator(minAverage, maxAverage); + const processedData = useProcessedData(data, averageMode, visibleColumns); + const columnVisibility = useColumnVisibility(visibleColumns); + + // Memoize filters + const filterConfig = useMemo( + () => ({ + selectedPrecisions, + selectedTypes, + paramsRange, + searchValue, + selectedBooleanFilters, + rankingMode, + pinnedModels, + isOfficialProviderActive, + }), + [ + selectedPrecisions, + selectedTypes, + paramsRange, + searchValue, + selectedBooleanFilters, + rankingMode, + pinnedModels, + isOfficialProviderActive, + ] + ); + + // Call useFilteredData at root level + const filteredData = useFilteredData( + processedData, + filterConfig.selectedPrecisions, + filterConfig.selectedTypes, + filterConfig.paramsRange, + filterConfig.searchValue, + filterConfig.selectedBooleanFilters, + filterConfig.rankingMode, + filterConfig.pinnedModels, + filterConfig.isOfficialProviderActive + ); + + // Memoize columns creation + const columns = useMemo( + () => + createColumns( + getColorForValue, + scoreDisplay, + columnVisibility, + data.length, + averageMode, + searchValue, + rankingMode, + onTogglePin + ), + [ + getColorForValue, + scoreDisplay, + columnVisibility, + data.length, + averageMode, + searchValue, + rankingMode, + onTogglePin, + ] + ); + + // Memoize table configuration + const tableConfig = useMemo( + () => ({ + data: filteredData, + columns, + state: { + sorting: Array.isArray(sorting) ? 
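+        // Defensive: sorting may briefly hold a non-array value while state
+        // hydrates, so it is normalized here. The defaultColumn.sortingFn
+        // below always floats pinned rows to the top (in pinnedModels order)
+        // before falling back to numeric or string comparison.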
sorting : [], + columnVisibility, + }, + getCoreRowModel: getCoreRowModel(), + getFilteredRowModel: getFilteredRowModel(), + getSortedRowModel: getSortedRowModel(), + onSortingChange: setSorting, + enableColumnVisibility: true, + defaultColumn: { + sortingFn: (rowA, rowB, columnId) => { + const isDesc = sorting?.[0]?.desc; + + if (rowA.original.isPinned && rowB.original.isPinned) { + return ( + pinnedModels.indexOf(rowA.original.id) - + pinnedModels.indexOf(rowB.original.id) + ); + } + + if (isDesc) { + if (rowA.original.isPinned) return -1; + if (rowB.original.isPinned) return 1; + } else { + if (rowA.original.isPinned) return -1; + if (rowB.original.isPinned) return 1; + } + + const aValue = rowA.getValue(columnId); + const bValue = rowB.getValue(columnId); + + if (typeof aValue === "number" && typeof bValue === "number") { + return aValue - bValue; + } + + return String(aValue).localeCompare(String(bValue)); + }, + }, + }), + [filteredData, columns, sorting, columnVisibility, pinnedModels, setSorting] + ); + + const table = useReactTable(tableConfig); + + return { + table, + minAverage, + maxAverage, + getColorForValue, + processedData, + filteredData, + columns, + columnVisibility, + }; +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/hooks/useSorting.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/hooks/useSorting.js new file mode 100644 index 0000000000000000000000000000000000000000..b6e24b528b4938ecd52e2a61624e028d3ffc8dc0 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/hooks/useSorting.js @@ -0,0 +1,16 @@ +export const typeColumnSort = (rowA, rowB) => { + const aValue = rowA.getValue("model_type"); + const bValue = rowB.getValue("model_type"); + + // If both values are arrays, compare their first elements + if (Array.isArray(aValue) && Array.isArray(bValue)) { + return String(aValue[0] || "").localeCompare(String(bValue[0] || "")); + } + + // If one is array and other isn't, array comes first + if (Array.isArray(aValue)) return -1; + if (Array.isArray(bValue)) return 1; + + // If neither is array, compare as strings + return String(aValue || "").localeCompare(String(bValue || "")); +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/shared/DropdownButton.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/shared/DropdownButton.js new file mode 100644 index 0000000000000000000000000000000000000000..2badebd0fb115b1a0f78ff81abd41c2b384c9233 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/shared/DropdownButton.js @@ -0,0 +1,137 @@ +import React, { useState } from "react"; +import { Box, Popover, Portal, Typography, Skeleton } from "@mui/material"; +import { useTheme } from "@mui/material/styles"; +import { commonStyles } from "../../styles/common"; + +const DropdownButton = ({ + label, + icon: Icon, + closeIcon: CloseIcon, + hasChanges = false, + children, + defaultWidth = 340, + paperProps = {}, + buttonSx = {}, + loading = false, +}) => { + const theme = useTheme(); + const [anchorEl, setAnchorEl] = useState(null); + + const handleClick = (event) => { + event.stopPropagation(); + setAnchorEl(event.currentTarget); + }; + + const handleClose = (event) => { + if (event) { + event.stopPropagation(); + } + setAnchorEl(null); + }; + + if (loading) { + return ( + + ); + } + + return ( + + + {Boolean(anchorEl) && CloseIcon ? 
( + + ) : ( + + )} + + {label} + + + + + theme.palette.mode === "light" + ? "rgba(0, 0, 0, 0.12)" + : "rgba(255, 255, 255, 0.12)", + borderRadius: 1, + position: "relative", + boxShadow: (theme) => + `0px 4px 20px ${ + theme.palette.mode === "light" + ? "rgba(0, 0, 0, 0.1)" + : "rgba(255, 255, 255, 0.1)" + }`, + ...paperProps.sx, + }, + ...paperProps, + }} + anchorOrigin={{ + vertical: "bottom", + horizontal: "right", + }} + transformOrigin={{ + vertical: "top", + horizontal: "right", + }} + slotProps={{ + backdrop: { + sx: { + backgroundColor: "transparent", + }, + }, + }} + > + {children} + + + + ); +}; + +export default DropdownButton; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/defaults.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/defaults.js new file mode 100644 index 0000000000000000000000000000000000000000..edfa7be429120862896260cdecd4b7a752e90e5a --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/defaults.js @@ -0,0 +1,380 @@ +import { MODEL_TYPE_ORDER } from "./modelTypes"; + +// Time constants (in milliseconds) +const TIME = { + CACHE_DURATION: 5 * 60 * 1000, // 5 minutes + DEBOUNCE: { + URL_PARAMS: 100, + SEARCH: 150, + RANGE_PICKER: 350, + }, +}; + +// Display constants +const DISPLAY = { + ROW_SIZES: { + normal: 45, + large: 60, + }, + SCORE_DISPLAY_OPTIONS: [ + { value: "normalized", label: "Normalized" }, + { value: "raw", label: "Raw" }, + ], + RANKING_MODE_OPTIONS: [ + { value: "static", label: "Static" }, + { value: "dynamic", label: "Dynamic" }, + ], +}; + +// Filter constants +const FILTERS = { + PRECISIONS: ["bfloat16", "float16", "4bit"], + SUBMISSION_PRECISIONS: [ + { value: "float16", label: "float16" }, + { value: "bfloat16", label: "bfloat16" }, + { value: "8bit", label: "8-bit" }, + { value: "4bit", label: "4-bit" }, + { value: "gptq", label: "GPTQ" }, + ], + PARAMS_RANGE: [-1, 140], + BOOLEAN_OPTIONS: [ + { + value: "is_moe", + label: "Mixture of Experts", + hide: true, + }, + { + value: "is_merged", + label: "Merged model", + hide: true, + }, + { + value: "is_flagged", + label: "Potentially contaminated model", + hide: true, + }, + { + value: "is_not_available_on_hub", + label: "Unavailable model", + hide: true, + }, + { + value: "is_official_provider", + label: "Only Official Providers", + hide: false, + }, + ], + HIGHLIGHT_OPTIONS: [ + { + value: "is_official_provider", + label: "Only Official Providers", + }, + ], +}; + +// Column size constants +const COLUMN_SIZES = { + RANK: 65, + TYPE_ICON: 65, + MODEL: 400, + AVERAGE_SCORE: 150, + BENCHMARK: 110, + CO2_COST: 140, + HUB_HEARTS: 140, + ARCHITECTURE: 210, + PRECISION: 140, + PARAMS: 160, + LICENSE: 160, + UPLOAD_DATE: 160, + SUBMISSION_DATE: 200, + GENERATION: 160, + BASE_MODEL: 390, + HUB_AVAILABILITY: 180, + OFFICIAL_PROVIDER: 240, + MOE: 200, + FLAG_STATUS: 160, + CHAT_TEMPLATE: 140, +}; + +// Column definitions with organized structure +const COLUMNS = { + FIXED: { + rank: { + group: "fixed", + size: COLUMN_SIZES.RANK, + defaultVisible: true, + label: "Rank", + }, + "model.type_icon": { + group: "fixed", + size: COLUMN_SIZES.TYPE_ICON, + defaultVisible: true, + label: "Type", + }, + id: { + group: "fixed", + size: COLUMN_SIZES.MODEL, + defaultVisible: true, + label: "Model", + }, + "model.average_score": { + group: "fixed", + size: COLUMN_SIZES.AVERAGE_SCORE, + defaultVisible: true, + label: "Average Score", + }, + }, + EVALUATION: { + "evaluations.ifeval.normalized_score": { + group: "evaluation", + 
size: COLUMN_SIZES.BENCHMARK, + defaultVisible: true, + label: "IFEval", + }, + "evaluations.bbh.normalized_score": { + group: "evaluation", + size: COLUMN_SIZES.BENCHMARK, + defaultVisible: true, + label: "BBH", + }, + "evaluations.math.normalized_score": { + group: "evaluation", + size: COLUMN_SIZES.BENCHMARK, + defaultVisible: true, + label: "MATH", + }, + "evaluations.gpqa.normalized_score": { + group: "evaluation", + size: COLUMN_SIZES.BENCHMARK, + defaultVisible: true, + label: "GPQA", + }, + "evaluations.musr.normalized_score": { + group: "evaluation", + size: COLUMN_SIZES.BENCHMARK, + defaultVisible: true, + label: "MUSR", + }, + "evaluations.mmlu_pro.normalized_score": { + group: "evaluation", + size: COLUMN_SIZES.BENCHMARK, + defaultVisible: true, + label: "MMLU-PRO", + }, + }, + MODEL_INFO: { + "metadata.co2_cost": { + group: "model_info", + size: COLUMN_SIZES.CO2_COST, + defaultVisible: true, + label: "CO₂ Cost (kg)", + }, + "metadata.hub_hearts": { + group: "model_info", + size: COLUMN_SIZES.HUB_HEARTS, + defaultVisible: false, + label: "Hub ❤️", + }, + "model.architecture": { + group: "model_info", + size: COLUMN_SIZES.ARCHITECTURE, + defaultVisible: false, + label: "Architecture", + }, + "model.precision": { + group: "model_info", + size: COLUMN_SIZES.PRECISION, + defaultVisible: false, + label: "Precision", + }, + "metadata.params_billions": { + group: "model_info", + size: COLUMN_SIZES.PARAMS, + defaultVisible: false, + label: "Parameters (B)", + }, + "metadata.hub_license": { + group: "model_info", + size: COLUMN_SIZES.LICENSE, + defaultVisible: false, + label: "License", + }, + "model.has_chat_template": { + group: "model_info", + size: COLUMN_SIZES.CHAT_TEMPLATE, + defaultVisible: false, + label: "Chat Template", + }, + }, + ADDITIONAL_INFO: { + "metadata.upload_date": { + group: "additional_info", + size: COLUMN_SIZES.UPLOAD_DATE, + defaultVisible: false, + label: "Upload Date", + }, + "metadata.submission_date": { + group: "additional_info", + size: COLUMN_SIZES.SUBMISSION_DATE, + defaultVisible: false, + label: "Submission Date", + }, + "metadata.generation": { + group: "additional_info", + size: COLUMN_SIZES.GENERATION, + defaultVisible: false, + label: "Generation", + }, + "metadata.base_model": { + group: "additional_info", + size: COLUMN_SIZES.BASE_MODEL, + defaultVisible: false, + label: "Base Model", + }, + "features.is_not_available_on_hub": { + group: "additional_info", + size: COLUMN_SIZES.HUB_AVAILABILITY, + defaultVisible: false, + label: "Hub Availability", + }, + "features.is_official_provider": { + group: "additional_info", + size: COLUMN_SIZES.OFFICIAL_PROVIDER, + defaultVisible: false, + label: "Only Official Providers", + }, + "features.is_moe": { + group: "additional_info", + size: COLUMN_SIZES.MOE, + defaultVisible: false, + label: "Mixture of Experts", + }, + "features.is_flagged": { + group: "additional_info", + size: COLUMN_SIZES.FLAG_STATUS, + defaultVisible: false, + label: "Flag Status", + }, + }, +}; + +// Combine all columns for backward compatibility +const ALL_COLUMNS = { + ...COLUMNS.FIXED, + ...COLUMNS.EVALUATION, + ...COLUMNS.MODEL_INFO, + ...COLUMNS.ADDITIONAL_INFO, +}; + +// Column definitions for external use (maintaining the same interface) +const COLUMN_DEFINITIONS = { + ALL_COLUMNS, + COLUMN_GROUPS: { + "Evaluation Scores": Object.keys(COLUMNS.EVALUATION), + "Model Information": Object.keys(COLUMNS.MODEL_INFO), + "Additional Information": Object.keys(COLUMNS.ADDITIONAL_INFO), + }, + COLUMN_LABELS: 
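+  // Flat { columnKey: label } map derived from ALL_COLUMNS, e.g.
+  //   { rank: "Rank", id: "Model", "metadata.co2_cost": "CO₂ Cost (kg)", ... }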
Object.entries(ALL_COLUMNS).reduce((acc, [key, value]) => {
+    acc[key] = value.label;
+    return acc;
+  }, {}),
+  DEFAULT_VISIBLE: Object.entries(ALL_COLUMNS)
+    .filter(([_, value]) => value.defaultVisible)
+    .map(([key]) => key),
+
+  // Put back the getters that are still needed
+  get FIXED() {
+    return Object.entries(ALL_COLUMNS)
+      .filter(([_, def]) => def.group === "fixed")
+      .map(([key]) => key);
+  },
+
+  get EVALUATION() {
+    return Object.entries(ALL_COLUMNS)
+      .filter(([_, def]) => def.group === "evaluation")
+      .map(([key]) => key);
+  },
+
+  get OPTIONAL() {
+    return Object.entries(ALL_COLUMNS)
+      .filter(([_, def]) => def.group !== "fixed" && def.group !== "evaluation")
+      .map(([key]) => key);
+  },
+
+  get COLUMN_SIZES() {
+    return Object.entries(ALL_COLUMNS).reduce(
+      (acc, [key, def]) => ({
+        ...acc,
+        [key]: def.size,
+      }),
+      {}
+    );
+  },
+};
+
+// Export constants maintaining the same interface
+export const FILTER_PRECISIONS = FILTERS.PRECISIONS;
+export const SUBMISSION_PRECISIONS = FILTERS.SUBMISSION_PRECISIONS;
+export const PARAMS_RANGE = FILTERS.PARAMS_RANGE;
+export const CACHE_SETTINGS = { DURATION: TIME.CACHE_DURATION };
+export const PINNED_MODELS = [];
+export const DEBOUNCE_TIMINGS = TIME.DEBOUNCE;
+export const ROW_SIZES = DISPLAY.ROW_SIZES;
+export const SCORE_DISPLAY_OPTIONS = DISPLAY.SCORE_DISPLAY_OPTIONS;
+export const RANKING_MODE_OPTIONS = DISPLAY.RANKING_MODE_OPTIONS;
+export const BOOLEAN_FILTER_OPTIONS = FILTERS.BOOLEAN_OPTIONS;
+export const HIGHLIGHT_FILTER_OPTIONS = FILTERS.HIGHLIGHT_OPTIONS;
+export { COLUMN_DEFINITIONS };
+
+// Export defaults for backward compatibility
+export const TABLE_DEFAULTS = {
+  ROW_SIZE: "normal",
+  SCORE_DISPLAY: "normalized",
+  AVERAGE_MODE: "all",
+  RANKING_MODE: "static",
+  SEARCH: {
+    PRECISIONS: FILTERS.PRECISIONS,
+    TYPES: MODEL_TYPE_ORDER,
+    PARAMS_RANGE: FILTERS.PARAMS_RANGE,
+  },
+  DEFAULT_SELECTED: {
+    searchValue: "",
+    selectedPrecisions: FILTERS.PRECISIONS,
+    selectedTypes: MODEL_TYPE_ORDER,
+    paramsRange: FILTERS.PARAMS_RANGE,
+    selectedBooleanFilters: [],
+  },
+  DEBOUNCE: TIME.DEBOUNCE,
+  COLUMNS: COLUMN_DEFINITIONS,
+  PINNED_MODELS: [],
+  CACHE_DURATION: TIME.CACHE_DURATION,
+};
+
+// Highlight colors for search and table
+export const HIGHLIGHT_COLORS = [
+  "#1f77b4", // blue
+  "#ff7f0e", // orange
+  "#2ca02c", // green
+  "#d62728", // red
+  "#9467bd", // purple
+  "#8c564b", // brown
+  "#e377c2", // pink
+  "#7f7f7f", // gray
+  "#bcbd22", // olive
+  "#17becf", // cyan
+];
+
+// Skeleton column widths (in pixels)
+export const SKELETON_COLUMNS = [
+  40, // Checkbox
+  COLUMN_SIZES.RANK, // Rank
+  COLUMN_SIZES.TYPE_ICON, // Type icon
+  COLUMN_SIZES.MODEL, // Model name
+  COLUMN_SIZES.AVERAGE_SCORE, // Average score
+  COLUMN_SIZES.BENCHMARK, // Benchmark 1
+  COLUMN_SIZES.BENCHMARK, // Benchmark 2
+  COLUMN_SIZES.BENCHMARK, // Benchmark 3
+  COLUMN_SIZES.BENCHMARK, // Benchmark 4
+  COLUMN_SIZES.BENCHMARK, // Benchmark 5
+  COLUMN_SIZES.BENCHMARK, // Benchmark 6
+];
diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/modelTypes.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/modelTypes.js
new file mode 100644
index 0000000000000000000000000000000000000000..46683b1e6d3a8b20e364260f579ce559a71e3e8b
--- /dev/null
+++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/modelTypes.js
@@ -0,0 +1,79 @@
+export const MODEL_TYPE_ORDER = [
+  'pretrained',
+  'continuously pretrained',
+  'fine-tuned',
+  'chat',
+  'merge',
+  'multimodal'
+];
+
+export const MODEL_TYPES
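+// The get* helpers at the end of this file match by substring
+// (cleanType.includes(key)) over these entries in insertion order, e.g.
+// getModelTypeIcon("fine-tuned") -> "🔶", with "❓" as the unknown fallback.
+// Beware: "continuously pretrained" also contains "pretrained", so the first
+// matching entry wins for such values.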
= { + 'pretrained': { + icon: '🟢', + label: 'Pretrained', + description: 'Base models trained on raw text data using self-supervised learning objectives', + order: 0 + }, + 'continuously pretrained': { + icon: '🟩', + label: 'Continuously Pretrained', + description: 'Base models with extended pretraining on additional data while maintaining original architecture', + order: 1 + }, + 'fine-tuned': { + icon: '🔶', + label: 'Fine-tuned', + description: 'Models specialized through task-specific training on curated datasets', + order: 2 + }, + 'chat': { + icon: '💬', + label: 'Chat', + description: 'Models optimized for conversation using various techniques: RLHF, DPO, IFT, SFT', + order: 3 + }, + 'merge': { + icon: '🤝', + label: 'Merge', + description: 'Models created by combining weights from multiple models', + order: 4 + }, + 'multimodal': { + icon: '🌸', + label: 'Multimodal', + description: 'Models capable of processing multiple types of input', + order: 5 + } +}; + +export const getModelTypeIcon = (type) => { + const cleanType = type.toLowerCase().trim(); + const matchedType = Object.entries(MODEL_TYPES).find(([key]) => + cleanType.includes(key) + ); + return matchedType ? matchedType[1].icon : '❓'; +}; + +export const getModelTypeLabel = (type) => { + const cleanType = type.toLowerCase().trim(); + const matchedType = Object.entries(MODEL_TYPES).find(([key]) => + cleanType.includes(key) + ); + return matchedType ? matchedType[1].label : type; +}; + +export const getModelTypeDescription = (type) => { + const cleanType = type.toLowerCase().trim(); + const matchedType = Object.entries(MODEL_TYPES).find(([key]) => + cleanType.includes(key) + ); + return matchedType ? matchedType[1].description : 'Unknown model type'; +}; + +export const getModelTypeOrder = (type) => { + const cleanType = type.toLowerCase().trim(); + const matchedType = Object.entries(MODEL_TYPES).find(([key]) => + cleanType.includes(key) + ); + return matchedType ? matchedType[1].order : Infinity; +}; \ No newline at end of file diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/quickFilters.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/quickFilters.js new file mode 100644 index 0000000000000000000000000000000000000000..de74e7065becab032d996c91b858746f9e38247f --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/quickFilters.js @@ -0,0 +1,51 @@ +export const QUICK_FILTER_PRESETS = [ + { + id: 'edge_device', + label: 'For Edge Devices', + shortDescription: 'Tiny models: Up to 3B parameters', + description: 'Lightweight models optimized for edge devices with limited resources. Ideal for mobile deployment or edge computing environments.', + filters: { + paramsRange: [0, 3], + selectedBooleanFilters: ['is_for_edge_devices'] + } + }, + { + id: 'small_models', + label: 'For Consumers', + shortDescription: 'Smol-LMs: 3-7B parameters', + description: 'Lightweight models optimized for consumer hardware with up to one GPU. 
Ideal for private use.',
+    filters: {
+      paramsRange: [3, 7],
+      selectedBooleanFilters: ['is_for_edge_devices']
+    }
+  },
+  {
+    id: 'medium_models',
+    label: 'Mid-range',
+    shortDescription: 'Medium-sized models: 7B-65B parameters',
+    description: 'Overall balance between performance and required resources.',
+    filters: {
+      paramsRange: [7, 65],
+      selectedBooleanFilters: []
+    }
+  },
+  {
+    id: 'large_models',
+    label: 'For the GPU-rich',
+    shortDescription: 'Large models: 65B+ parameters',
+    description: 'Large-scale models offering (in theory) the best performance, but requiring significant resources. They require suitable infrastructure.',
+    filters: {
+      paramsRange: [65, 141],
+      selectedBooleanFilters: []
+    }
+  },
+  {
+    id: 'official_providers',
+    label: 'Only Official Providers',
+    shortDescription: 'Officially provided models',
+    description: 'Models officially provided and maintained by their original creators or organizations.',
+    filters: {
+      selectedBooleanFilters: ['is_official_provider']
+    }
+  }
+];
\ No newline at end of file
diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/tooltips.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/tooltips.js
new file mode 100644
index 0000000000000000000000000000000000000000..06f311739cb6a46f0477746bf47ae59732252b44
--- /dev/null
+++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/tooltips.js
@@ -0,0 +1,386 @@
+import { Box, Typography } from "@mui/material";
+
+const createTooltipContent = (title, items) => (
+  <Box sx={{ p: 1 }}>
+    <Typography variant="body2" sx={{ fontWeight: 600, mb: 1 }}>
+      {title}
+    </Typography>
+    {items.map(({ label, description, subItems }, index) => (
+      <Typography key={index} variant="body2" sx={{ mb: 0.5 }}>
+        • {label}: {description}
+        {subItems && (
+          <Box component="span" sx={{ display: "block", pl: 2 }}>
+            {subItems.map((item, subIndex) => (
+              <Box component="span" key={subIndex} sx={{ display: "block" }}>
+                • {item}
+              </Box>
+            ))}
+          </Box>
+        )}
+      </Typography>
+    ))}
+  </Box>
+);
+
+export const COLUMN_TOOLTIPS = {
+  AVERAGE: createTooltipContent("Average score across all benchmarks:", [
+    {
+      label: "Calculation",
+      description: "Weighted average of normalized scores from all benchmarks",
+      subItems: [
+        "Each benchmark is normalized to a 0-100 scale",
+        "All normalized benchmarks are then averaged together",
+      ],
+    },
+  ]),
+
+  IFEVAL: createTooltipContent("Instruction-Following Evaluation (IFEval):", [
+    {
+      label: "Purpose",
+      description:
+        "Tests a model's ability to follow explicit formatting instructions",
+      subItems: ["Instruction following", "Formatting", "Generation"],
+    },
+    {
+      label: "Scoring: Accuracy",
+      description: "Was the requested format strictly respected.",
+    },
+  ]),
+
+  BBH: createTooltipContent("Big Bench Hard (BBH):", [
+    {
+      label: "Overview",
+      description:
+        "Collection of tasks that are challenging for LLMs, across domains such as",
+      subItems: [
+        "Language understanding",
+        "Mathematical reasoning",
+        "Common sense and world knowledge",
+      ],
+    },
+    {
+      label: "Scoring: Accuracy",
+      description: "Was the correct choice selected among the options.",
+    },
+  ]),
+
+  MATH: createTooltipContent(
+    "Mathematics Aptitude Test of Heuristics (MATH), level 5:",
+    [
+      {
+        label: "Content",
+        description: "High-school-level competition mathematics problems",
+        subItems: ["Complex algebra", "Geometry problems", "Advanced calculus"],
+      },
+      {
+        label: "Scoring: Exact match",
+        description:
+          "Was the generated solution correct and in the expected format",
+      },
+    ]
+  ),
+
+  GPQA: createTooltipContent("Graduate-Level Google-Proof Q&A (GPQA):", [
+    {
+      label: "Focus",
+      description: "PhD-level multiple-choice questions in science",
+      subItems: [
+        "Chemistry",
+        "Biology",
+        "Physics",
+      ],
+    },
+    {
+      label: "Scoring: Accuracy",
+      description: "Was the correct choice selected among the options.",
+    },
+  ]),
+
+  MUSR: createTooltipContent("Multistep Soft Reasoning (MuSR):", [
+    {
+      label: "Scope",
+      description: "Reasoning about and understanding of long texts",
+      subItems: [
+        "Language understanding",
+        "Reasoning capabilities",
+        "Long context reasoning",
+      ],
+    },
+    {
+      label: "Scoring: Accuracy",
+      description: "Was the correct choice selected among the options.",
+    },
+  ]),
+
+  MMLU_PRO: createTooltipContent(
+    "Massive Multitask Language Understanding - Professional (MMLU-Pro):",
+    [
+      {
+        label: "Coverage",
+        description:
+          "Expert-reviewed multiple-choice questions across domains, for example:",
+        subItems: [
+          "Medicine and healthcare",
+          "Law and ethics",
+          "Engineering",
+          "Mathematics",
+        ],
+      },
+      {
+        label: "Scoring: Accuracy",
+        description: "Was the correct choice selected among the options.",
+      },
+    ]
+  ),
+
+  ARCHITECTURE: createTooltipContent("Model Architecture Information:", [
+    {
+      label: "Definition",
+      description: "The fundamental structure and design of the model",
+      subItems: [
+        "Pretrained: Foundational models, initially trained on large datasets without task-specific tuning, serving as a versatile base for further development.",
+        "Continuously Pretrained: Base models trained with a data mix evolving as the model is trained, with the addition of specialized data during the last training steps.",
+        "Fine-tuned: Base models, fine-tuned on specialized domain data (legal, medical, ...), and optimized for particular tasks.",
+        "Chat: Models fine-tuned with IFT, RLHF, DPO, and other techniques, to handle conversational contexts effectively.",
+        "Merged: Combining multiple models through weights averaging or similar methods.",
+        "Multimodal: Models that can handle several modalities (text & image/audio/video/...). We only evaluate the text capabilities.",
+      ],
+    },
+    {
+      label: "Impact",
+      description: "How architecture affects model capabilities",
+      subItems: [
+        "Base models are expected to perform less well on instruction-following evaluations, like IFEval.",
+        "Fine-tuned and chat models can be more verbose and more chatty than base models.",
+        "Merged models tend to exhibit good performance on benchmarks, which does not translate to real-world situations.",
+      ],
+    },
+  ]),
+
+  PRECISION: createTooltipContent("Numerical Precision Format:", [
+    {
+      label: "Overview",
+      description:
+        "Data format used to store model weights and perform computations",
+      subItems: [
+        "bfloat16: Half precision (Brain Float format), good for stability",
+        "float16: Half precision",
+        "8bit/4bit: Quantized formats, for efficiency",
+        "GPTQ/AWQ: Quantization methods",
+      ],
+    },
+    {
+      label: "Impact",
+      description: "How precision affects model deployment",
+      subItems: [
+        "Higher precision = better accuracy but more memory usage",
+        "Lower precision = faster inference and smaller size",
+        "Trade-off between model quality and resource usage",
+      ],
+    },
+  ]),
+
+  FLAGS: createTooltipContent("Model Flags and Special Features:", [
+    {
+      label: "Filters",
+      subItems: [
+        "Mixture of Experts: Uses a MoE architecture",
+        "Merged models: Created by averaging other models",
+        "Contaminated: Flagged by users from the community for (possibly accidental) cheating",
+        "Unavailable: No longer on the hub (private, deleted) or missing a license tag",
+      ],
+    },
+    {
+      label: "Purpose",
+      description: "Why do people want to hide these models?",
+      subItems: [
+        "Mixture of Experts: These models can be too parameter heavy",
+        "Merged models: Performance on benchmarks tends to be inflated compared to real-world usage",
+        "Contaminated: Performance on benchmarks is inflated and does not reflect real-world usage",
+      ],
+    },
+  ]),
+
+  PARAMETERS: createTooltipContent("Model Parameters:", [
+    {
+      label: "Measurement",
+      description: "Total number of trainable parameters in billions",
+      subItems: [
+        "Indicates model capacity and complexity",
+        "Correlates with computational requirements",
+        "Influences memory usage and inference speed",
+      ],
+    },
+  ]),
+
+  LICENSE: createTooltipContent("Model License Information:", [
+    {
+      label: "Importance",
+      description: "Legal terms governing model usage and distribution",
+      subItems: [
+        "Commercial vs non-commercial use",
+        "Attribution requirements",
+        "Modification and redistribution rights",
+        "Liability and warranty terms",
+      ],
+    },
+  ]),
+
+  CO2_COST: createTooltipContent("Carbon Dioxide Emissions:", [
+    {
+      label: "What is it?",
+      description: "CO₂ emissions of the model evaluation",
+      subItems: [
+        "Only focuses on model inference for our specific setup",
+        "Considers data center location and energy mix",
+        "Allows equivalent comparison of models on our use case",
+      ],
+    },
+    {
+      label: "Why it matters",
+      description: "Environmental impact of AI model training",
+      subItems: [
+        "Large models can have significant carbon footprints",
+        "Helps make informed choices about model selection",
+      ],
+    },
+    {
+      label: "Learn more",
+      description:
+        "For detailed information about our CO₂ calculation methodology, visit:",
+      subItems: [
+        "Carbon Emissions Documentation ↗",
+      ],
+    },
+  ]),
+};
+
+export const UI_TOOLTIPS = {
+  COLUMN_SELECTOR: "Choose which columns to display in the table",
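+  // Values here are either plain strings or elements prebuilt with
+  // createTooltipContent(title, items); a minimal illustrative call (not one
+  // of the real tooltips) looks like:
+  //   createTooltipContent("Example", [
+  //     { label: "Item", description: "What it means", subItems: ["Detail"] },
+  //   ])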
+  DISPLAY_OPTIONS: createTooltipContent("Table Display Options", [
+    {
+      label: "Overview",
+      description: "Configure how the table displays data and information",
+      subItems: [
+        "Row size and layout",
+        "Score display format",
+        "Ranking calculation",
+        "Average score computation",
+      ],
+    },
+  ]),
+  SEARCH_BAR: createTooltipContent("Advanced Model Search", [
+    {
+      label: "Name Search",
+      description: "Search directly by model name",
+      subItems: [
+        "Supports regular expressions (e.g., ^mistral.*7b)",
+        "Case sensitive",
+      ],
+    },
+    {
+      label: "Field Search",
+      description: "Use @field:value syntax for precise filtering",
+      subItems: [
+        "@architecture:llama - Filter by architecture",
+        "@license:mit - Filter by license",
+        "@precision:float16 - Filter by precision",
+        "@type:chat - Filter by model type",
+      ],
+    },
+    {
+      label: "Multiple Searches",
+      description: "Combine multiple criteria using semicolons",
+      subItems: [
+        "meta @license:mit; @architecture:llama",
+        "^mistral.*7b; @precision:float16",
+      ],
+    },
+  ]),
+  QUICK_FILTERS: createTooltipContent(
+    "Filter models based on their size and applicable hardware:",
+    [
+      {
+        label: "Edge devices (Up to 3B)",
+        description:
+          "Efficient models for edge devices, optimized for blazing fast inference.",
+      },
+      {
+        label: "Smol Models (3B-7B)",
+        description:
+          "Efficient models for consumer hardware, optimized for fast inference.",
+      },
+      {
+        label: "Mid-range models (7B-65B)",
+        description:
+          "A bit of everything here, with overall balanced performance and resource usage around 30B.",
+      },
+      {
+        label: "GPU-rich models (65B+)",
+        description:
+          "State-of-the-art performance for complex tasks, requires significant computing power.",
+      },
+      {
+        label: "Official Providers",
+        description:
+          "Models directly maintained by their original creators, ensuring reliability and up-to-date performance.",
+      },
+    ]
+  ),
+  ROW_SIZE: {
+    title: "Row Size",
+    description:
+      "Adjust the height of table rows. Compact is ideal for viewing more data at once, while Large provides better readability and touch targets.",
+  },
+  SCORE_DISPLAY: {
+    title: "Score Display",
+    description:
+      "Choose between normalized scores (0-100% scale for easy comparison) or raw scores (actual benchmark results). Normalized scores help compare performance across different benchmarks, while raw scores show actual benchmark outputs.",
+  },
+  RANKING_MODE: {
+    title: "Ranking Mode",
+    description:
+      "Choose between static ranking (original position in the full leaderboard) or dynamic ranking (position based on current filters and sorting).",
+  },
+  AVERAGE_SCORE: {
+    title: "Average Score Calculation",
+    description:
+      "Define how the average score is calculated. 'All Scores' uses all benchmarks, while 'Visible Only' calculates the average using only the visible benchmark columns.",
+  },
+};
+
+export const getTooltipStyle = {};
+
+export const TABLE_TOOLTIPS = {
+  HUB_LINK: (modelName) => `View ${modelName} on Hugging Face Hub`,
+  EVAL_RESULTS: (modelName) =>
+    `View detailed evaluation results for ${modelName}`,
+  POSITION_CHANGE: (change) =>
+    `${Math.abs(change)} position${Math.abs(change) > 1 ? "s" : ""} ${
+      change > 0 ?
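+      // e.g. POSITION_CHANGE(3) -> "3 positions up",
+      //      POSITION_CHANGE(-1) -> "1 position down"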
"up" : "down" + }`, + METADATA: { + TYPE: (type) => type || "-", + ARCHITECTURE: (arch) => arch || "-", + PRECISION: (precision) => precision || "-", + LICENSE: (license) => license || "-", + UPLOAD_DATE: (date) => date || "-", + SUBMISSION_DATE: (date) => date || "-", + BASE_MODEL: (model) => model || "-", + }, +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext.js new file mode 100644 index 0000000000000000000000000000000000000000..e41599900865195536321cd8ae121d0d794cb94a --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext.js @@ -0,0 +1,760 @@ +import React, { + createContext, + useContext, + useReducer, + useEffect, + useMemo, + useCallback, +} from "react"; +import { useSearchParams, useLocation } from "react-router-dom"; +import { MODEL_TYPE_ORDER } from "../constants/modelTypes"; +import { FILTER_PRECISIONS, TABLE_DEFAULTS } from "../constants/defaults"; + +// Create context +const LeaderboardContext = createContext(); + +// Define default filter values +const DEFAULT_FILTERS = { + search: "", + precisions: FILTER_PRECISIONS, + types: MODEL_TYPE_ORDER, + paramsRange: [-1, 140], + booleanFilters: [], + isOfficialProviderActive: false, +}; + +// Define default display values +const DEFAULT_DISPLAY = { + rowSize: TABLE_DEFAULTS.ROW_SIZE, + scoreDisplay: TABLE_DEFAULTS.SCORE_DISPLAY, + averageMode: TABLE_DEFAULTS.AVERAGE_MODE, + rankingMode: TABLE_DEFAULTS.RANKING_MODE, + visibleColumns: TABLE_DEFAULTS.COLUMNS.DEFAULT_VISIBLE, +}; + +// Create initial counter structure +const createInitialCounts = () => { + const modelTypes = {}; + MODEL_TYPE_ORDER.forEach((type) => { + modelTypes[type] = 0; + }); + + const precisions = {}; + FILTER_PRECISIONS.forEach((precision) => { + precisions[precision] = 0; + }); + + return { + modelTypes, + precisions, + officialProviders: 0, + mixtureOfExperts: 0, + flagged: 0, + merged: 0, + notOnHub: 0, + parameterRanges: { + edge: 0, + small: 0, + medium: 0, + large: 0, + }, + }; +}; + +// Define initial state +const initialState = { + models: [], + loading: true, + countsReady: false, + error: null, + filters: DEFAULT_FILTERS, + display: DEFAULT_DISPLAY, + filtersExpanded: false, + pinnedModels: [], + filterCounts: { + normal: createInitialCounts(), + officialOnly: createInitialCounts(), + }, +}; + +// Function to normalize parameter value +const normalizeParams = (params) => { + const numParams = Number(params); + if (isNaN(numParams)) return null; + return Math.round(numParams * 100) / 100; +}; + +// Function to check if a parameter count is within a range +const isInParamRange = (params, range) => { + if (range[0] === -1 && range[1] === 140) return true; + const normalizedParams = normalizeParams(params); + if (normalizedParams === null) return false; + return normalizedParams >= range[0] && normalizedParams < range[1]; +}; + +// Function to check if a model matches filter criteria +const modelMatchesFilters = (model, filters) => { + // Filter by precision + if ( + filters.precisions.length > 0 && + !filters.precisions.includes(model.model.precision) + ) { + return false; + } + + // Filter by type + if (filters.types.length > 0) { + const modelType = model.model.type?.toLowerCase().trim(); + if (!filters.types.some((type) => modelType?.includes(type))) { + return false; + } + } + + // Filter by parameters + const params = Number( + model.metadata?.params_billions || 
model.features?.params_billions + ); + if (!isInParamRange(params, filters.paramsRange)) return false; + + // Filter by search + if (filters.search) { + const searchLower = filters.search.toLowerCase(); + const modelName = model.model.name.toLowerCase(); + if (!modelName.includes(searchLower)) return false; + } + + // Boolean filters + if (filters.booleanFilters.length > 0) { + return filters.booleanFilters.every((filter) => { + const filterValue = typeof filter === "object" ? filter.value : filter; + + // Maintainer's Highlight keeps positive logic + if (filterValue === "is_official_provider") { + return model.features[filterValue]; + } + + // For all other filters, invert the logic + if (filterValue === "is_not_available_on_hub") { + return model.features[filterValue]; + } + + return !model.features[filterValue]; + }); + } + + return true; +}; + +// Function to calculate filtered model counts +const calculateFilteredCounts = ( + allRows, + totalPinnedCount, + filters, + filteredCount +) => { + // If no table, use raw filteredCount + if (!allRows) { + return { + currentFilteredCount: + typeof filteredCount === "number" ? filteredCount : 0, + totalPinnedCount: totalPinnedCount || 0, + }; + } + + // 1. Total number of rows (models matching filters) + const totalFilteredCount = allRows.length; + + // 2. Number of pinned models that also match filters + // These models are already included in totalFilteredCount, so we need to subtract them + // to avoid counting them twice + const pinnedMatchingFilters = allRows.filter((row) => { + const model = row.original; + return model.isPinned && modelMatchesFilters(model, filters); + }).length; + + return { + // Subtract pinned models that match filters + // as they are already displayed separately with "+X" + currentFilteredCount: totalFilteredCount - pinnedMatchingFilters, + totalPinnedCount: totalPinnedCount || 0, + }; +}; + +// Function to calculate counters +const calculateModelCounts = (models) => { + const normalCounts = createInitialCounts(); + const officialOnlyCounts = createInitialCounts(); + + models.forEach((model) => { + const isOfficial = + model.features?.is_official_provider || + model.metadata?.is_official_provider; + const countsToUpdate = [normalCounts]; + + if (isOfficial) { + countsToUpdate.push(officialOnlyCounts); + } + + countsToUpdate.forEach((counts) => { + // Model type + if (model.model?.type) { + const cleanType = model.model.type.toLowerCase().trim(); + const matchedType = MODEL_TYPE_ORDER.find((key) => + cleanType.includes(key) + ); + if (matchedType) { + counts.modelTypes[matchedType]++; + } + } + + // Precision + if (model.model?.precision) { + counts.precisions[model.model.precision]++; + } + + // Boolean filters + if ( + model.features?.is_official_provider || + model.metadata?.is_official_provider + ) + counts.officialProviders++; + if (model.features?.is_moe || model.metadata?.is_moe) + counts.mixtureOfExperts++; + if (model.features?.is_flagged || model.metadata?.is_flagged) + counts.flagged++; + if (model.features?.is_merged || model.metadata?.is_merged) + counts.merged++; + if ( + !( + model.features?.is_not_available_on_hub || + model.metadata?.is_not_available_on_hub + ) + ) + counts.notOnHub++; + + // Parameter ranges + const params = Number( + model.metadata?.params_billions || model.features?.params_billions + ); + if (!isNaN(params)) { + if (isInParamRange(params, [0, 3])) counts.parameterRanges.edge++; + if (isInParamRange(params, [3, 7])) counts.parameterRanges.small++; + if (isInParamRange(params, [7, 
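+          // These size buckets mirror the quick-filter presets and, via
+          // isInParamRange, are half-open ([0, 3), [3, 7), [7, 65), [65, 141)),
+          // so e.g. a 7B model is counted as "medium", never "small".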
65])) counts.parameterRanges.medium++; + if (isInParamRange(params, [65, 141])) counts.parameterRanges.large++; + } + }); + }); + + return { + normal: normalCounts, + officialOnly: officialOnlyCounts, + }; +}; + +// Define reducer +const reducer = (state, action) => { + switch (action.type) { + case "SET_MODELS": + const newCounts = calculateModelCounts(action.payload); + return { + ...state, + models: action.payload, + filterCounts: newCounts, + countsReady: true, + loading: false, + }; + + case "SET_LOADING": + return { + ...state, + loading: action.payload, + ...(action.payload ? { countsReady: false } : {}), + }; + + case "SET_ERROR": + return { + ...state, + error: action.payload, + loading: false, + }; + + case "SET_FILTER": + return { + ...state, + filters: { + ...state.filters, + [action.key]: action.value, + }, + }; + + case "SET_DISPLAY_OPTION": + return { + ...state, + display: { + ...state.display, + [action.key]: action.value, + }, + }; + + case "TOGGLE_PINNED_MODEL": + const modelKey = action.payload; + const pinnedModels = [...state.pinnedModels]; + const modelIndex = pinnedModels.indexOf(modelKey); + + if (modelIndex === -1) { + pinnedModels.push(modelKey); + } else { + pinnedModels.splice(modelIndex, 1); + } + + return { + ...state, + pinnedModels, + }; + + case "SET_PINNED_MODELS": + return { + ...state, + pinnedModels: action.payload, + }; + + case "TOGGLE_FILTERS_EXPANDED": + return { + ...state, + filtersExpanded: !state.filtersExpanded, + }; + + case "TOGGLE_OFFICIAL_PROVIDER": + return { + ...state, + filters: { + ...state.filters, + isOfficialProviderActive: !state.filters.isOfficialProviderActive, + }, + }; + + case "RESET_FILTERS": + return { + ...state, + filters: DEFAULT_FILTERS, + }; + + case "RESET_ALL": + return { + ...state, + filters: DEFAULT_FILTERS, + display: DEFAULT_DISPLAY, + pinnedModels: [], + }; + + default: + return state; + } +}; + +// Provider component +const LeaderboardProvider = ({ children }) => { + const [state, dispatch] = useReducer(reducer, initialState); + const [searchParams, setSearchParams] = useSearchParams(); + const location = useLocation(); + + // Effect to load initial values from URL + useEffect(() => { + // Skip URL sync if we're resetting + if (location.state?.skipUrlSync) return; + + const loadFromUrl = () => { + // Load filters + const searchFromUrl = searchParams.get("search"); + if (searchFromUrl) { + dispatch({ type: "SET_FILTER", key: "search", value: searchFromUrl }); + } + + const paramsFromUrl = searchParams.get("params")?.split(",").map(Number); + if (paramsFromUrl?.length === 2) { + dispatch({ + type: "SET_FILTER", + key: "paramsRange", + value: paramsFromUrl, + }); + } + + const filtersFromUrl = + searchParams.get("filters")?.split(",").filter(Boolean) || []; + if (filtersFromUrl.length > 0) { + dispatch({ + type: "SET_FILTER", + key: "booleanFilters", + value: filtersFromUrl, + }); + } + + const precisionsFromUrl = searchParams + .get("precision") + ?.split(",") + .filter(Boolean); + if (precisionsFromUrl) { + dispatch({ + type: "SET_FILTER", + key: "precisions", + value: precisionsFromUrl, + }); + } + + const typesFromUrl = searchParams + .get("types") + ?.split(",") + .filter(Boolean); + if (typesFromUrl) { + dispatch({ type: "SET_FILTER", key: "types", value: typesFromUrl }); + } + + const officialFromUrl = searchParams.get("official") === "true"; + if (officialFromUrl) { + dispatch({ + type: "SET_FILTER", + key: "isOfficialProviderActive", + value: true, + }); + } + + // Load pinned models + const 
pinnedFromUrl = + searchParams.get("pinned")?.split(",").filter(Boolean) || []; + if (pinnedFromUrl.length > 0) { + dispatch({ type: "SET_PINNED_MODELS", payload: pinnedFromUrl }); + } + + // Load visible columns + const columnsFromUrl = searchParams + .get("columns") + ?.split(",") + .filter(Boolean); + if (columnsFromUrl) { + dispatch({ + type: "SET_DISPLAY_OPTION", + key: "visibleColumns", + value: columnsFromUrl, + }); + } + + // Load table options + const rowSizeFromUrl = searchParams.get("rowSize"); + if (rowSizeFromUrl) { + dispatch({ + type: "SET_DISPLAY_OPTION", + key: "rowSize", + value: rowSizeFromUrl, + }); + } + + const scoreDisplayFromUrl = searchParams.get("scoreDisplay"); + if (scoreDisplayFromUrl) { + dispatch({ + type: "SET_DISPLAY_OPTION", + key: "scoreDisplay", + value: scoreDisplayFromUrl, + }); + } + + const averageModeFromUrl = searchParams.get("averageMode"); + if (averageModeFromUrl) { + dispatch({ + type: "SET_DISPLAY_OPTION", + key: "averageMode", + value: averageModeFromUrl, + }); + } + + const rankingModeFromUrl = searchParams.get("rankingMode"); + if (rankingModeFromUrl) { + dispatch({ + type: "SET_DISPLAY_OPTION", + key: "rankingMode", + value: rankingModeFromUrl, + }); + } + }; + + loadFromUrl(); + }, [searchParams, location.state]); + + // Effect to synchronize filters with URL + useEffect(() => { + // Skip URL sync if we're resetting + if (location.state?.skipUrlSync) return; + + const newSearchParams = new URLSearchParams(searchParams); + const currentParams = searchParams.get("params")?.split(",").map(Number); + const currentFilters = + searchParams.get("filters")?.split(",").filter(Boolean) || []; + const currentSearch = searchParams.get("search"); + const currentPinned = + searchParams.get("pinned")?.split(",").filter(Boolean) || []; + const currentColumns = + searchParams.get("columns")?.split(",").filter(Boolean) || []; + const currentRowSize = searchParams.get("rowSize"); + const currentScoreDisplay = searchParams.get("scoreDisplay"); + const currentAverageMode = searchParams.get("averageMode"); + const currentRankingMode = searchParams.get("rankingMode"); + const currentOfficialProvider = searchParams.get("official") === "true"; + const currentPrecisions = + searchParams.get("precision")?.split(",").filter(Boolean) || []; + const currentTypes = + searchParams.get("types")?.split(",").filter(Boolean) || []; + + // Only update URL if values have changed + const paramsChanged = + !currentParams || + currentParams[0] !== state.filters.paramsRange[0] || + currentParams[1] !== state.filters.paramsRange[1]; + + const filtersChanged = + state.filters.booleanFilters.length !== currentFilters.length || + state.filters.booleanFilters.some((f) => !currentFilters.includes(f)); + + const searchChanged = state.filters.search !== currentSearch; + + const pinnedChanged = + state.pinnedModels.length !== currentPinned.length || + state.pinnedModels.some((m) => !currentPinned.includes(m)); + + const columnsChanged = + state.display.visibleColumns.length !== currentColumns.length || + state.display.visibleColumns.some((c) => !currentColumns.includes(c)); + + const rowSizeChanged = state.display.rowSize !== currentRowSize; + const scoreDisplayChanged = + state.display.scoreDisplay !== currentScoreDisplay; + const averageModeChanged = state.display.averageMode !== currentAverageMode; + const rankingModeChanged = state.display.rankingMode !== currentRankingMode; + const officialProviderChanged = + state.filters.isOfficialProviderActive !== currentOfficialProvider; + const 
precisionsChanged = + state.filters.precisions.length !== currentPrecisions.length || + state.filters.precisions.some((p) => !currentPrecisions.includes(p)); + const typesChanged = + state.filters.types.length !== currentTypes.length || + state.filters.types.some((t) => !currentTypes.includes(t)); + + if (paramsChanged) { + if ( + state.filters.paramsRange[0] !== -1 || + state.filters.paramsRange[1] !== 140 + ) { + newSearchParams.set("params", state.filters.paramsRange.join(",")); + } else { + newSearchParams.delete("params"); + } + } + + if (filtersChanged) { + if (state.filters.booleanFilters.length > 0) { + newSearchParams.set("filters", state.filters.booleanFilters.join(",")); + } else { + newSearchParams.delete("filters"); + } + } + + if (searchChanged) { + if (state.filters.search) { + newSearchParams.set("search", state.filters.search); + } else { + newSearchParams.delete("search"); + } + } + + if (pinnedChanged) { + if (state.pinnedModels.length > 0) { + newSearchParams.set("pinned", state.pinnedModels.join(",")); + } else { + newSearchParams.delete("pinned"); + } + } + + if (columnsChanged) { + if ( + JSON.stringify([...state.display.visibleColumns].sort()) !== + JSON.stringify([...TABLE_DEFAULTS.COLUMNS.DEFAULT_VISIBLE].sort()) + ) { + newSearchParams.set("columns", state.display.visibleColumns.join(",")); + } else { + newSearchParams.delete("columns"); + } + } + + if (rowSizeChanged) { + if (state.display.rowSize !== TABLE_DEFAULTS.ROW_SIZE) { + newSearchParams.set("rowSize", state.display.rowSize); + } else { + newSearchParams.delete("rowSize"); + } + } + + if (scoreDisplayChanged) { + if (state.display.scoreDisplay !== TABLE_DEFAULTS.SCORE_DISPLAY) { + newSearchParams.set("scoreDisplay", state.display.scoreDisplay); + } else { + newSearchParams.delete("scoreDisplay"); + } + } + + if (averageModeChanged) { + if (state.display.averageMode !== TABLE_DEFAULTS.AVERAGE_MODE) { + newSearchParams.set("averageMode", state.display.averageMode); + } else { + newSearchParams.delete("averageMode"); + } + } + + if (rankingModeChanged) { + if (state.display.rankingMode !== TABLE_DEFAULTS.RANKING_MODE) { + newSearchParams.set("rankingMode", state.display.rankingMode); + } else { + newSearchParams.delete("rankingMode"); + } + } + + if (officialProviderChanged) { + if (state.filters.isOfficialProviderActive) { + newSearchParams.set("official", "true"); + } else { + newSearchParams.delete("official"); + } + } + + if (precisionsChanged) { + if ( + JSON.stringify([...state.filters.precisions].sort()) !== + JSON.stringify([...FILTER_PRECISIONS].sort()) + ) { + newSearchParams.set("precision", state.filters.precisions.join(",")); + } else { + newSearchParams.delete("precision"); + } + } + + if (typesChanged) { + if ( + JSON.stringify([...state.filters.types].sort()) !== + JSON.stringify([...MODEL_TYPE_ORDER].sort()) + ) { + newSearchParams.set("types", state.filters.types.join(",")); + } else { + newSearchParams.delete("types"); + } + } + + if ( + paramsChanged || + filtersChanged || + searchChanged || + pinnedChanged || + columnsChanged || + rowSizeChanged || + scoreDisplayChanged || + averageModeChanged || + rankingModeChanged || + officialProviderChanged || + precisionsChanged || + typesChanged + ) { + // Update search params and let HashRouter handle the URL + setSearchParams(newSearchParams); + } + }, [state, searchParams, location.state]); + + const actions = useMemo( + () => ({ + setModels: (models) => dispatch({ type: "SET_MODELS", payload: models }), + setLoading: (loading) => + dispatch({ 
type: "SET_LOADING", payload: loading }), + setError: (error) => dispatch({ type: "SET_ERROR", payload: error }), + setFilter: (key, value) => dispatch({ type: "SET_FILTER", key, value }), + setDisplayOption: (key, value) => + dispatch({ type: "SET_DISPLAY_OPTION", key, value }), + togglePinnedModel: (modelKey) => + dispatch({ type: "TOGGLE_PINNED_MODEL", payload: modelKey }), + toggleOfficialProvider: () => + dispatch({ type: "TOGGLE_OFFICIAL_PROVIDER" }), + toggleFiltersExpanded: () => + dispatch({ type: "TOGGLE_FILTERS_EXPANDED" }), + resetFilters: () => { + dispatch({ type: "RESET_FILTERS" }); + const newParams = new URLSearchParams(searchParams); + [ + "filters", + "params", + "precision", + "types", + "official", + "search", + ].forEach((param) => { + newParams.delete(param); + }); + setSearchParams(newParams); + }, + resetAll: () => { + // Reset all state + dispatch({ type: "RESET_ALL" }); + // Clear all URL params with skipUrlSync flag + setSearchParams({}, { state: { skipUrlSync: true } }); + }, + }), + [searchParams, setSearchParams] + ); + + // Function to calculate counts (exposed via context) + const getFilteredCounts = useCallback( + (allRows, totalPinnedCount, filteredCount) => { + return calculateFilteredCounts( + allRows, + totalPinnedCount, + state.filters, + filteredCount + ); + }, + [state.filters] + ); + + // Also expose filtering function for reuse elsewhere + const checkModelMatchesFilters = useCallback( + (model) => { + return modelMatchesFilters(model, state.filters); + }, + [state.filters] + ); + + const value = useMemo( + () => ({ + state: { + ...state, + loading: state.loading || !state.countsReady, + }, + actions, + utils: { + getFilteredCounts, + checkModelMatchesFilters, + }, + }), + [state, actions, getFilteredCounts, checkModelMatchesFilters] + ); + + return ( + + {children} + + ); +}; + +// Hook to use context +const useLeaderboard = () => { + const context = useContext(LeaderboardContext); + if (!context) { + throw new Error("useLeaderboard must be used within a LeaderboardProvider"); + } + return context; +}; + +export { useLeaderboard }; +export default LeaderboardProvider; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useBatchedState.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useBatchedState.js new file mode 100644 index 0000000000000000000000000000000000000000..ad11c91393ca9e413853ae440154b948293103e9 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useBatchedState.js @@ -0,0 +1,31 @@ +import { useState, useCallback, useTransition } from 'react'; + +export const useBatchedState = (initialState, options = {}) => { + const { batchDelay = 0, useTransitions = false } = options; + const [state, setState] = useState(typeof initialState === 'function' ? 
diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useBatchedState.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useBatchedState.js
new file mode 100644
index 0000000000000000000000000000000000000000..ad11c91393ca9e413853ae440154b948293103e9
--- /dev/null
+++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useBatchedState.js
@@ -0,0 +1,31 @@
+import { useState, useCallback, useTransition } from 'react';
+
+// setState wrapper that can defer updates behind a timeout and/or a React transition
+export const useBatchedState = (initialState, options = {}) => {
+  const { batchDelay = 0, useTransitions = false } = options;
+  const [state, setState] = useState(
+    typeof initialState === 'function' ? initialState() : initialState
+  );
+  const [isPending, startTransition] = useTransition();
+
+  const setBatchedState = useCallback(
+    (newState) => {
+      if (useTransitions) {
+        startTransition(() => {
+          if (batchDelay > 0) {
+            setTimeout(() => {
+              setState(newState);
+            }, batchDelay);
+          } else {
+            setState(newState);
+          }
+        });
+      } else {
+        if (batchDelay > 0) {
+          setTimeout(() => {
+            setState(newState);
+          }, batchDelay);
+        } else {
+          setState(newState);
+        }
+      }
+    },
+    [batchDelay, useTransitions]
+  );
+
+  return [state, setBatchedState, isPending];
+};
diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useDataUtils.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useDataUtils.js
new file mode 100644
index 0000000000000000000000000000000000000000..5313812d5341be56ff4539c6be0cc1f98efcd2d4
--- /dev/null
+++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useDataUtils.js
@@ -0,0 +1,310 @@
+import { useMemo } from "react";
+import {
+  looksLikeRegex,
+  parseSearchQuery,
+  getValueByPath,
+} from "../utils/searchUtils";
+import { MODEL_TYPE_ORDER } from "../constants/modelTypes";
+
+// Calculate min/max averages
+export const useAverageRange = (data) => {
+  return useMemo(() => {
+    const averages = data.map((item) => item.model.average_score);
+    return {
+      minAverage: Math.min(...averages),
+      maxAverage: Math.max(...averages),
+    };
+  }, [data]);
+};
+
+// Generate colors for scores
+export const useColorGenerator = (minAverage, maxAverage) => {
+  return useMemo(() => {
+    const colorCache = new Map();
+    return (value) => {
+      const cached = colorCache.get(value);
+      if (cached) return cached;
+
+      const normalizedValue = (value - minAverage) / (maxAverage - minAverage);
+      const red = Math.round(255 * (1 - normalizedValue));
+      const green = Math.round(255 * normalizedValue);
+      const color = `rgba(${red}, ${green}, 0, 1)`;
+      colorCache.set(value, color);
+      return color;
+    };
+  }, [minAverage, maxAverage]);
+};
+
+// Process data with boolean standardization
+export const useProcessedData = (data, averageMode, visibleColumns) => {
+  return useMemo(() => {
+    let processed = data.map((item) => {
+      const evaluationScores = Object.entries(item.evaluations)
+        .filter(([key]) => {
+          if (averageMode === "all") return true;
+          return visibleColumns.includes(`evaluations.${key}.normalized_score`);
+        })
+        .map(([, value]) => value.normalized_score);
+
+      const average =
+        evaluationScores.length > 0
+          ? evaluationScores.reduce((a, b) => a + b, 0) /
+            evaluationScores.length
+          : averageMode === "visible"
+          ?
null + : 0; + + // Boolean standardization + const standardizedFeatures = { + ...item.features, + is_moe: Boolean(item.features.is_moe), + is_flagged: Boolean(item.features.is_flagged), + is_official_provider: Boolean(item.features.is_official_provider), + is_merged: Boolean(item.features.is_merged), + is_not_available_on_hub: Boolean(item.features.is_not_available_on_hub), + }; + + return { + ...item, + features: standardizedFeatures, + model: { + ...item.model, + has_chat_template: Boolean(item.model.has_chat_template), + average_score: average, + }, + }; + }); + + processed.sort((a, b) => { + if (a.model.average_score === null && b.model.average_score === null) + return 0; + if (a.model.average_score === null) return 1; + if (b.model.average_score === null) return -1; + return b.model.average_score - a.model.average_score; + }); + + const result = processed.map((item, index) => ({ + ...item, + static_rank: index + 1, + })); + + return result; + }, [data, averageMode, visibleColumns]); +}; + +// Common filtering logic +export const useFilteredData = ( + processedData, + selectedPrecisions, + selectedTypes, + paramsRange, + searchValue, + selectedBooleanFilters, + rankingMode, + pinnedModels = [], + isOfficialProviderActive = false +) => { + return useMemo(() => { + const pinnedData = processedData.filter((row) => { + return pinnedModels.includes(row.id); + }); + + const unpinnedData = processedData.filter((row) => { + return !pinnedModels.includes(row.id); + }); + + let filteredUnpinned = unpinnedData; + + // Filter by official providers + if (isOfficialProviderActive) { + filteredUnpinned = filteredUnpinned.filter( + (row) => + row.features?.is_official_provider || + row.metadata?.is_official_provider + ); + } + + // Filter by precision + if (selectedPrecisions.length > 0) { + filteredUnpinned = filteredUnpinned.filter((row) => + selectedPrecisions.includes(row.model.precision) + ); + } + + // Filter by type + if ( + selectedTypes.length > 0 && + selectedTypes.length < MODEL_TYPE_ORDER.length + ) { + filteredUnpinned = filteredUnpinned.filter((row) => { + const modelType = row.model.type?.toLowerCase().trim(); + return selectedTypes.some((type) => modelType?.includes(type)); + }); + } + + // Filter by parameters + if (!(paramsRange[0] === -1 && paramsRange[1] === 140)) { + filteredUnpinned = filteredUnpinned.filter((row) => { + const params = + row.metadata?.params_billions || row.features?.params_billions; + if (params === undefined || params === null) return false; + return params >= paramsRange[0] && params < paramsRange[1]; + }); + } + + // Filter by search + if (searchValue) { + const searchQueries = searchValue + .split(";") + .map((q) => q.trim()) + .filter((q) => q); + if (searchQueries.length > 0) { + filteredUnpinned = filteredUnpinned.filter((row) => { + return searchQueries.some((query) => { + const { specialSearches, textSearch } = parseSearchQuery(query); + + const specialSearchMatch = specialSearches.every( + ({ field, value }) => { + const fieldValue = getValueByPath(row, field) + ?.toString() + .toLowerCase(); + return fieldValue?.includes(value.toLowerCase()); + } + ); + + if (!specialSearchMatch) return false; + if (!textSearch) return true; + + const modelName = row.model.name.toLowerCase(); + const searchLower = textSearch.toLowerCase(); + + if (looksLikeRegex(textSearch)) { + try { + const regex = new RegExp(textSearch, "i"); + return regex.test(modelName); + } catch (e) { + return modelName.includes(searchLower); + } + } else { + return 
modelName.includes(searchLower);
+            }
+          });
+        });
+      }
+    }
+
+    // Filter by booleans
+    if (selectedBooleanFilters.length > 0) {
+      filteredUnpinned = filteredUnpinned.filter((row) => {
+        return selectedBooleanFilters.every((filter) => {
+          const filterValue =
+            typeof filter === "object" ? filter.value : filter;
+
+          // Maintainer's Highlight keeps positive logic
+          if (filterValue === "is_official_provider") {
+            return row.features[filterValue];
+          }
+
+          // For all other filters, invert the logic
+          if (filterValue === "is_not_available_on_hub") {
+            return row.features[filterValue];
+          }
+
+          return !row.features[filterValue];
+        });
+      });
+    }
+
+    // Create ordered array of pinned models respecting pinnedModels order
+    const orderedPinnedData = pinnedModels
+      .map((pinnedModelId) =>
+        pinnedData.find((item) => item.id === pinnedModelId)
+      )
+      .filter(Boolean);
+
+    // Combine all filtered data
+    const allFilteredData = [...filteredUnpinned, ...orderedPinnedData];
+
+    // Sort all data by average_score for dynamic_rank
+    const sortedByScore = [...allFilteredData].sort((a, b) => {
+      // If the average scores differ, sort by score
+      if (a.model.average_score !== b.model.average_score) {
+        if (a.model.average_score === null && b.model.average_score === null)
+          return 0;
+        if (a.model.average_score === null) return 1;
+        if (b.model.average_score === null) return -1;
+        return b.model.average_score - a.model.average_score;
+      }
+
+      // If scores are equal, compare model name and submission date
+      if (a.model.name === b.model.name) {
+        // Same name: sort by submission date, most recent first
+        const dateA = new Date(a.metadata?.submission_date || 0);
+        const dateB = new Date(b.metadata?.submission_date || 0);
+        return dateB - dateA;
+      }
+
+      // Different names: sort alphabetically by name
+      return a.model.name.localeCompare(b.model.name);
+    });
+
+    // Create Map to store dynamic_ranks
+    const dynamicRankMap = new Map();
+    sortedByScore.forEach((item, index) => {
+      dynamicRankMap.set(item.id, index + 1);
+    });
+
+    // Add ranks to final data
+    const finalData = [...orderedPinnedData, ...filteredUnpinned].map(
+      (item) => {
+        return {
+          ...item,
+          dynamic_rank: dynamicRankMap.get(item.id),
+          rank: item.isPinned
+            ? pinnedModels.indexOf(item.id) + 1
+            : rankingMode === "static"
+            ? item.static_rank
+            : dynamicRankMap.get(item.id),
+          isPinned: pinnedModels.includes(item.id),
+        };
+      }
+    );
+
+    return finalData;
+  }, [
+    processedData,
+    selectedPrecisions,
+    selectedTypes,
+    paramsRange,
+    searchValue,
+    selectedBooleanFilters,
+    rankingMode,
+    pinnedModels,
+    isOfficialProviderActive,
+  ]);
+};
+
+// Column visibility management
+export const useColumnVisibility = (visibleColumns = []) => {
+  // Create secure visibility object
+  const columnVisibility = useMemo(() => {
+    // Check visible columns
+    const safeVisibleColumns = Array.isArray(visibleColumns)
+      ?
visibleColumns + : []; + + const visibility = {}; + try { + safeVisibleColumns.forEach((columnKey) => { + if (typeof columnKey === "string") { + visibility[columnKey] = true; + } + }); + } catch (error) { + console.warn("Error in useColumnVisibility:", error); + } + return visibility; + }, [visibleColumns]); + + return columnVisibility; +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useLeaderboardData.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useLeaderboardData.js new file mode 100644 index 0000000000000000000000000000000000000000..8642fe72c7aa443ac6b6cae2637b24a3b884432b --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useLeaderboardData.js @@ -0,0 +1,127 @@ +import { useMemo, useRef, useState } from "react"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useSearchParams } from "react-router-dom"; +import { useLeaderboard } from "../context/LeaderboardContext"; +import { useDataProcessing } from "../components/Table/hooks/useDataProcessing"; + +export const useLeaderboardData = () => { + const queryClient = useQueryClient(); + const [searchParams] = useSearchParams(); + const isInitialLoadRef = useRef(true); + + const { data, isLoading, error } = useQuery({ + queryKey: ["leaderboard"], + queryFn: async () => { + console.log("🔄 Starting API fetch attempt..."); + try { + console.log("🌐 Fetching from API..."); + const response = await fetch("/api/leaderboard/formatted"); + console.log("📡 API Response status:", response.status); + + if (!response.ok) { + const errorText = await response.text(); + console.error("🚨 API Error:", { + status: response.status, + statusText: response.statusText, + body: errorText, + }); + throw new Error(`HTTP error! 
status: ${response.status}`); + } + + const newData = await response.json(); + console.log("📥 Received data size:", JSON.stringify(newData).length); + return newData; + } catch (error) { + console.error("🔥 Detailed error:", { + name: error.name, + message: error.message, + stack: error.stack, + }); + throw error; + } + }, + refetchOnWindowFocus: false, + enabled: isInitialLoadRef.current || !!searchParams.toString(), + }); + + useMemo(() => { + if (data && isInitialLoadRef.current) { + console.log("🎯 Initial load complete"); + isInitialLoadRef.current = false; + } + }, [data]); + + return { + data, + isLoading, + error, + refetch: () => queryClient.invalidateQueries(["leaderboard"]), + }; +}; + +export const useLeaderboardProcessing = () => { + const { state, actions } = useLeaderboard(); + const [sorting, setSorting] = useState([ + { id: "model.average_score", desc: true }, + ]); + + const memoizedData = useMemo(() => state.models, [state.models]); + const memoizedFilters = useMemo( + () => ({ + search: state.filters.search, + precisions: state.filters.precisions, + types: state.filters.types, + paramsRange: state.filters.paramsRange, + booleanFilters: state.filters.booleanFilters, + isOfficialProviderActive: state.filters.isOfficialProviderActive, + }), + [ + state.filters.search, + state.filters.precisions, + state.filters.types, + state.filters.paramsRange, + state.filters.booleanFilters, + state.filters.isOfficialProviderActive, + ] + ); + + const { + table, + minAverage, + maxAverage, + getColorForValue, + processedData, + filteredData, + columns, + columnVisibility, + } = useDataProcessing( + memoizedData, + memoizedFilters.search, + memoizedFilters.precisions, + memoizedFilters.types, + memoizedFilters.paramsRange, + memoizedFilters.booleanFilters, + sorting, + state.display.rankingMode, + state.display.averageMode, + state.display.visibleColumns, + state.display.scoreDisplay, + state.pinnedModels, + actions.togglePinnedModel, + setSorting, + memoizedFilters.isOfficialProviderActive + ); + + return { + table, + minAverage, + maxAverage, + getColorForValue, + processedData, + filteredData, + columns, + columnVisibility, + loading: state.loading, + error: state.error, + }; +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/styles/common.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/styles/common.js new file mode 100644 index 0000000000000000000000000000000000000000..06648e526979fd7c992ea3f3721468e261448593 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/styles/common.js @@ -0,0 +1,153 @@ +import { alpha } from "@mui/material"; + +export const commonStyles = { + // Tooltips + tooltip: { + sx: { + bgcolor: "background.tooltip", + "& .MuiTooltip-arrow": { + color: "background.tooltip", + }, + padding: "12px 16px", + maxWidth: 300, + fontSize: "0.875rem", + lineHeight: 1.4, + }, + }, + + // Progress bars + progressBar: { + position: "absolute", + left: -16, + top: -8, + height: "calc(100% + 16px)", + opacity: (theme) => (theme.palette.mode === "light" ? 0.1 : 0.2), + transition: "width 0.3s ease", + zIndex: 0, + }, + + // Cell containers + cellContainer: { + display: "flex", + alignItems: "center", + height: "100%", + width: "100%", + position: "relative", + }, + + // Hover effects + hoverEffect: (theme, isActive = false) => ({ + backgroundColor: isActive + ? alpha( + theme.palette.primary.main, + theme.palette.mode === "light" ? 
0.08 : 0.16
+        )
+      : theme.palette.action.hover,
+    "& .MuiTypography-root": {
+      color: isActive ? "primary.main" : "text.primary",
+    },
+    "& .MuiSvgIcon-root": {
+      color: isActive ? "primary.main" : "text.primary",
+    },
+  }),
+
+  // Filter groups
+  filterGroup: {
+    title: {
+      mb: 1,
+      fontSize: "0.8rem",
+      fontWeight: 700,
+      color: "text.primary",
+      display: "flex",
+      alignItems: "center",
+      gap: 0.5,
+    },
+    container: {
+      display: "flex",
+      flexWrap: "wrap",
+      gap: 0.5,
+      alignItems: "center",
+    },
+  },
+
+  // Option buttons (like in DisplayOptions)
+  optionButton: {
+    display: "flex",
+    alignItems: "center",
+    gap: 0.8,
+    cursor: "pointer",
+    padding: "4px 10px",
+    borderRadius: 1,
+    height: "32px",
+    "& .MuiSvgIcon-root": {
+      fontSize: "0.9rem",
+    },
+    "& .MuiTypography-root": {
+      fontSize: "0.85rem",
+    },
+  },
+
+  // Score indicators
+  scoreIndicator: {
+    dot: {
+      width: 10,
+      height: 10,
+      borderRadius: "50%",
+      marginLeft: -1,
+    },
+    bar: {
+      position: "absolute",
+      left: -16,
+      top: -8,
+      height: "calc(100% + 16px)",
+      opacity: (theme) => (theme.palette.mode === "light" ? 0.1 : 0.2),
+      transition: "width 0.3s ease",
+    },
+  },
+
+  // Popover content
+  popoverContent: {
+    p: 3,
+    width: 280,
+    maxHeight: 400,
+    overflowY: "auto",
+  },
+};
+
+// Component styles
+export const componentStyles = {
+  // Table header cell
+  headerCell: {
+    borderRight: (theme) =>
+      `1px solid ${alpha(
+        theme.palette.divider,
+        theme.palette.mode === "dark" ? 0.05 : 0.1
+      )}`,
+    "&:last-child": {
+      borderRight: "none",
+    },
+    whiteSpace: "nowrap",
+    overflow: "hidden",
+    textOverflow: "ellipsis",
+    padding: "8px 16px",
+    backgroundColor: (theme) => theme.palette.background.paper,
+    position: "sticky !important",
+    top: 0,
+    zIndex: 10,
+  },
+
+  // Table cell
+  tableCell: {
+    borderRight: (theme) =>
+      `1px solid ${alpha(
+        theme.palette.divider,
+        theme.palette.mode === "dark"
+          ?
0.05 : 0.1 + )}`, + "&:last-child": { + borderRight: "none", + }, + whiteSpace: "nowrap", + overflow: "hidden", + textOverflow: "ellipsis", + }, +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/columnUtils.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/columnUtils.js new file mode 100644 index 0000000000000000000000000000000000000000..526c015c6684569bc6581e1931b28e7dd09b3bac --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/columnUtils.js @@ -0,0 +1,1073 @@ +import React from "react"; +import { Box, Typography, Link, Tooltip, IconButton } from "@mui/material"; +import { getModelTypeIcon } from "../constants/modelTypes"; +import TrendingUpIcon from "@mui/icons-material/TrendingUp"; +import TrendingDownIcon from "@mui/icons-material/TrendingDown"; +import RemoveIcon from "@mui/icons-material/Remove"; +import PushPinIcon from "@mui/icons-material/PushPin"; +import PushPinOutlinedIcon from "@mui/icons-material/PushPinOutlined"; +import { TABLE_DEFAULTS, HIGHLIGHT_COLORS } from "../constants/defaults"; +import { looksLikeRegex, extractTextSearch } from "./searchUtils"; +import { commonStyles } from "../styles/common"; +import { typeColumnSort } from "../components/Table/hooks/useSorting"; +import { + COLUMN_TOOLTIPS, + getTooltipStyle, + TABLE_TOOLTIPS, +} from "../constants/tooltips"; +import OpenInNewIcon from "@mui/icons-material/OpenInNew"; +import { alpha } from "@mui/material/styles"; +import InfoIconWithTooltip from "../../../../../components/shared/InfoIconWithTooltip"; + +const DatabaseIcon = () => ( + +); + +const HighlightedText = ({ text, searchValue }) => { + if (!searchValue) return text; + + const searches = searchValue + .split(";") + .map((s) => s.trim()) + .filter(Boolean); + let result = text; + let fragments = [{ text: result, isMatch: false }]; + + searches.forEach((search, searchIndex) => { + if (!search) return; + + try { + let regex; + if (looksLikeRegex(search)) { + regex = new RegExp(search, "gi"); + } else { + regex = new RegExp(search.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"), "gi"); + } + + const newFragments = []; + fragments.forEach((fragment) => { + if (fragment.isMatch) { + newFragments.push(fragment); + return; + } + + const parts = fragment.text.split(regex); + const matches = fragment.text.match(regex); + + if (!matches) { + newFragments.push(fragment); + return; + } + + parts.forEach((part, i) => { + if (part) newFragments.push({ text: part, isMatch: false }); + if (i < parts.length - 1) { + newFragments.push({ + text: matches[i], + isMatch: true, + colorIndex: searchIndex % HIGHLIGHT_COLORS.length, + }); + } + }); + }); + + fragments = newFragments; + } catch (e) { + console.warn("Invalid regex:", search); + } + }); + + return ( + <> + {fragments.map((fragment, i) => + fragment.isMatch ? 
( + + theme.palette.getContrastText( + HIGHLIGHT_COLORS[fragment.colorIndex] + ), + fontWeight: 500, + px: 0.5, + py: "2px", + borderRadius: "3px", + mx: "1px", + overflow: "visible", + display: "inline-block", + }} + > + {fragment.text} + + ) : ( + {fragment.text} + ) + )} + + ); +}; + +const MEDAL_STYLES = { + 1: { + color: "#B58A1B", + background: "linear-gradient(135deg, #FFF7E0 0%, #FFD700 100%)", + borderColor: "rgba(212, 160, 23, 0.35)", + shadowColor: "rgba(212, 160, 23, 0.8)", + }, + 2: { + color: "#667380", + background: "linear-gradient(135deg, #FFFFFF 0%, #D8E3ED 100%)", + borderColor: "rgba(124, 139, 153, 0.35)", + shadowColor: "rgba(124, 139, 153, 0.8)", + }, + 3: { + color: "#B85C2F", + background: "linear-gradient(135deg, #FDF0E9 0%, #FFBC8C 100%)", + borderColor: "rgba(204, 108, 61, 0.35)", + shadowColor: "rgba(204, 108, 61, 0.8)", + }, +}; + +const getMedalStyle = (rank) => { + if (rank <= 3) { + const medalStyle = MEDAL_STYLES[rank]; + return { + color: medalStyle.color, + fontWeight: 900, + fontStretch: "150%", + fontFamily: '"Inter", -apple-system, sans-serif', + width: "24px", + height: "24px", + background: medalStyle.background, + border: "1px solid", + borderColor: medalStyle.borderColor, + borderRadius: "50%", + display: "flex", + alignItems: "center", + justifyContent: "center", + fontSize: "0.95rem", + lineHeight: 1, + padding: 0, + boxShadow: `1px 1px 0 ${medalStyle.shadowColor}`, + position: "relative", + }; + } + return { + color: "inherit", + fontWeight: rank <= 10 ? 600 : 400, + }; +}; + +const getRankStyle = (rank) => getMedalStyle(rank); + +const RankIndicator = ({ rank, previousRank, mode }) => { + const rankChange = previousRank ? previousRank - rank : 0; + + const RankChangeIndicator = ({ change }) => { + if (!change || mode === "dynamic") return null; + + const getChangeColor = (change) => { + if (change > 0) return "success.main"; + if (change < 0) return "error.main"; + return "grey.500"; + }; + + const getChangeIcon = (change) => { + if (change > 0) return ; + if (change < 0) return ; + return ; + }; + + return ( + 1 ? "s" : "" + } ${change > 0 ? "up" : "down"}`} + arrow + placement="right" + > + + {getChangeIcon(change)} + + + ); + }; + + return ( + + + {rank <= 3 ? ( + <> + + {rank} + + + + ) : ( + <> + + {rank} + + + + )} + + + ); +}; + +const getDetailsUrl = (modelName) => { + const formattedName = modelName.replace("/", "__"); + return `https://huggingface.co/datasets/open-llm-leaderboard/${formattedName}-details`; +}; + +const HeaderLabel = ({ label, tooltip, className, isSorted }) => ( + + + {label} + + +); + +const InfoIcon = ({ tooltip }) => ( + + + +); + +const createHeaderCell = (label, tooltip) => (header) => + ( + + + + + {tooltip && } + + + ); + +const createModelHeader = + (totalModels, officialProvidersCount = 0, isOfficialProviderActive = false) => + ({ table }) => { + return ( + + + + Model + + + + ); + }; + +const BooleanValue = ({ value }) => { + if (value === null || value === undefined) + return -; + + return ( + ({ + display: "flex", + alignItems: "center", + justifyContent: "center", + borderRadius: "4px", + px: 1, + py: 0.5, + backgroundColor: value + ? theme.palette.mode === "dark" + ? alpha(theme.palette.success.main, 0.1) + : alpha(theme.palette.success.main, 0.1) + : theme.palette.mode === "dark" + ? alpha(theme.palette.error.main, 0.1) + : alpha(theme.palette.error.main, 0.1), + })} + > + ({ + color: value + ? theme.palette.mode === "dark" + ? 
theme.palette.success.light + : theme.palette.success.dark + : theme.palette.mode === "dark" + ? theme.palette.error.light + : theme.palette.error.dark, + })} + > + {value ? "Yes" : "No"} + + + ); +}; + +export const createColumns = ( + getColorForValue, + scoreDisplay = "normalized", + columnVisibility = {}, + totalModels, + averageMode = "all", + searchValue = "", + rankingMode = "static", + onTogglePin, + hasPinnedRows = false +) => { + // Ajuster les tailles des colonnes en fonction de la présence de lignes épinglées + const getColumnSize = (defaultSize) => + hasPinnedRows ? "auto" : `${defaultSize}px`; + + const baseColumns = [ + { + accessorKey: "isPinned", + header: () => null, + cell: ({ row }) => ( + + { + e.stopPropagation(); + e.preventDefault(); + onTogglePin(row.original.id); + }} + sx={{ + padding: 0.5, + color: row.original.isPinned ? "primary.main" : "grey.400", + "&:hover": { + color: "primary.main", + }, + }} + > + {row.original.isPinned ? ( + + ) : ( + + )} + + + ), + enableSorting: false, + size: getColumnSize(40), + }, + { + accessorKey: "rank", + header: createHeaderCell("Rank"), + cell: ({ row }) => { + const rank = + rankingMode === "static" + ? row.original.static_rank + : row.original.dynamic_rank; + + return ( + + ); + }, + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["rank"], + }, + { + id: "model_type", + accessorFn: (row) => row.model.type, + header: createHeaderCell("Type"), + sortingFn: typeColumnSort, + cell: ({ row }) => ( + + + + {getModelTypeIcon(row.original.model.type)} + + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["model.type_icon"], + }, + { + accessorKey: "id", + header: createModelHeader(totalModels), + cell: ({ row }) => { + const textSearch = extractTextSearch(searchValue); + const modelName = row.original.model.name; + + return ( + + + + theme.palette.mode === "dark" + ? theme.palette.info.light + : theme.palette.info.dark, + "& svg": { + opacity: 0.8, + }, + }, + overflow: "hidden", + textOverflow: "ellipsis", + whiteSpace: "nowrap", + flex: 1, + minWidth: 0, + fontWeight: row.original.static_rank <= 3 ? 600 : "inherit", + }} + > + + + + + + + + + ); + }, + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["id"], + }, + { + accessorKey: "model.average_score", + header: createHeaderCell("Average", COLUMN_TOOLTIPS.AVERAGE), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "model.average_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["model.average_score"], + meta: { + headerStyle: { + borderLeft: (theme) => + `2px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + )}`, + borderRight: (theme) => + `2px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + )}`, + }, + cellStyle: (value) => ({ + position: "relative", + overflow: "hidden", + padding: "8px 16px", + borderLeft: (theme) => + `2px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + )}`, + borderRight: (theme) => + `2px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + )}`, + }), + }, + }, + ]; + const createScoreCell = (getValue, row, field) => { + const value = getValue(); + const rawValue = field.includes("normalized") + ? row.original.evaluations[field.split(".")[1]]?.value + : value; + + const isAverageColumn = field === "model.average_score"; + const hasNoValue = value === null || value === undefined; + + return ( + + {!hasNoValue && (scoreDisplay === "normalized" || isAverageColumn) && ( + (theme.palette.mode === "light" ? 
0.1 : 0.2), + transition: "width 0.3s ease", + zIndex: 0, + }} + /> + )} + + {isAverageColumn && !hasNoValue && ( + + )} + + {hasNoValue ? ( + "-" + ) : ( + <> + {isAverageColumn ? ( + <> + {value.toFixed(2)} + % + + ) : scoreDisplay === "normalized" ? ( + <> + {value.toFixed(2)} + % + + ) : ( + <>{rawValue.toFixed(2)} + )} + + )} + + + + ); + }; + + const evaluationColumns = [ + { + accessorKey: "evaluations.ifeval.normalized_score", + header: createHeaderCell("IFEval", COLUMN_TOOLTIPS.IFEVAL), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.ifeval.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.ifeval.normalized_score" + ], + }, + { + accessorKey: "evaluations.bbh.normalized_score", + header: createHeaderCell("BBH", COLUMN_TOOLTIPS.BBH), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.bbh.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.bbh.normalized_score" + ], + }, + { + accessorKey: "evaluations.math.normalized_score", + header: createHeaderCell("MATH", COLUMN_TOOLTIPS.MATH), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.math.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.math.normalized_score" + ], + }, + { + accessorKey: "evaluations.gpqa.normalized_score", + header: createHeaderCell("GPQA", COLUMN_TOOLTIPS.GPQA), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.gpqa.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.gpqa.normalized_score" + ], + }, + { + accessorKey: "evaluations.musr.normalized_score", + header: createHeaderCell("MUSR", COLUMN_TOOLTIPS.MUSR), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.musr.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.musr.normalized_score" + ], + }, + { + accessorKey: "evaluations.mmlu_pro.normalized_score", + header: createHeaderCell("MMLU-PRO", COLUMN_TOOLTIPS.MMLU_PRO), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.mmlu_pro.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.mmlu_pro.normalized_score" + ], + }, + ]; + + const optionalColumns = [ + { + accessorKey: "model.architecture", + header: createHeaderCell("Architecture", COLUMN_TOOLTIPS.ARCHITECTURE), + accessorFn: (row) => row.model.architecture, + cell: ({ row }) => ( + + {row.original.model.architecture || "-"} + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["model.architecture"], + }, + { + accessorKey: "model.precision", + header: createHeaderCell("Precision", COLUMN_TOOLTIPS.PRECISION), + accessorFn: (row) => row.model.precision, + cell: ({ row }) => ( + + {row.original.model.precision || "-"} + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["model.precision"], + }, + { + accessorKey: "metadata.params_billions", + header: createHeaderCell("Parameters", COLUMN_TOOLTIPS.PARAMETERS), + cell: ({ row }) => ( + + + {row.original.metadata.params_billions} + B + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.params_billions"], + }, + { + accessorKey: "metadata.hub_license", + header: createHeaderCell("License", COLUMN_TOOLTIPS.LICENSE), + cell: ({ row }) => ( + + + {row.original.metadata.hub_license || "-"} + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.hub_license"], + }, + { + accessorKey: "metadata.hub_hearts", + header: createHeaderCell( + "Hub ❤️", + "Number of likes received on the 
Hugging Face Hub" + ), + cell: ({ row }) => ( + + {row.original.metadata.hub_hearts} + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.hub_hearts"], + }, + { + accessorKey: "metadata.upload_date", + header: createHeaderCell( + "Upload Date", + "Date when the model was uploaded to the Hugging Face Hub" + ), + cell: ({ row }) => ( + + + {row.original.metadata.upload_date || "-"} + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.upload_date"], + }, + { + accessorKey: "metadata.submission_date", + header: createHeaderCell( + "Submission Date", + "Date when the model was submitted to the leaderboard" + ), + cell: ({ row }) => ( + + + {row.original.metadata.submission_date || "-"} + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.submission_date"], + }, + { + accessorKey: "metadata.generation", + header: createHeaderCell( + "Generation", + "The generation or version number of the model" + ), + cell: ({ row }) => ( + + {row.original.metadata.generation} + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.generation"], + }, + { + accessorKey: "metadata.base_model", + header: createHeaderCell( + "Base Model", + "The original model this model was derived from" + ), + cell: ({ row }) => ( + + + {row.original.metadata.base_model || "-"} + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.base_model"], + }, + { + accessorKey: "metadata.co2_cost", + header: createHeaderCell("CO₂ Cost", COLUMN_TOOLTIPS.CO2_COST), + cell: ({ row }) => ( + + + {row.original.metadata.co2_cost?.toFixed(2) || "0"} + kg + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.co2_cost"], + }, + { + accessorKey: "model.has_chat_template", + header: createHeaderCell( + "Chat Template", + "Whether this model has a chat template defined" + ), + cell: ({ row }) => ( + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["model.has_chat_template"], + }, + { + accessorKey: "features.is_not_available_on_hub", + header: createHeaderCell( + "Hub Availability", + "Whether the model is available on the Hugging Face Hub" + ), + cell: ({ row }) => ( + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "features.is_not_available_on_hub" + ], + }, + { + accessorKey: "features.is_official_provider", + header: createHeaderCell( + "Official Providers", + "Models that are officially provided and maintained by their original creators or organizations" + ), + cell: ({ row }) => ( + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "features.is_official_provider" + ], + enableSorting: true, + }, + { + accessorKey: "features.is_moe", + header: createHeaderCell( + "Mixture of Experts", + "Whether this model uses a Mixture of Experts architecture" + ), + cell: ({ row }) => , + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["features.is_moe"], + }, + { + accessorKey: "features.is_flagged", + header: createHeaderCell( + "Flag Status", + "Whether this model has been flagged for any issues" + ), + cell: ({ row }) => ( + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["features.is_flagged"], + }, + ]; + + // Utiliser directement columnVisibility + const finalColumns = [ + ...baseColumns, + ...evaluationColumns.filter((col) => columnVisibility[col.accessorKey]), + ...optionalColumns + .filter((col) => columnVisibility[col.accessorKey]) + .sort((a, b) => { + // Définir l'ordre personnalisé des colonnes + const order = { + "model.architecture": 1, + "model.precision": 2, + "metadata.params_billions": 3, + "metadata.hub_license": 4, + "metadata.co2_cost": 5, + "metadata.hub_hearts": 6, + 
"metadata.upload_date": 7, + "metadata.submission_date": 8, + "metadata.generation": 9, + "metadata.base_model": 10, + "model.has_chat_template": 11, + "features.is_not_available_on_hub": 12, + "features.is_official_provider": 13, + "features.is_moe": 14, + "features.is_flagged": 15, + }; + return order[a.accessorKey] - order[b.accessorKey]; + }), + ]; + + return finalColumns; +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/searchUtils.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/searchUtils.js new file mode 100644 index 0000000000000000000000000000000000000000..091796b7a7a3721b4d7f790f0fda75ca151a838d --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/searchUtils.js @@ -0,0 +1,92 @@ +// Utility function to detect if a string looks like a regex +export const looksLikeRegex = (str) => { + const regexSpecialChars = /[\\^$.*+?()[\]{}|]/; + return regexSpecialChars.test(str); +}; + +// Function to map search fields to correct paths +const getFieldPath = (field) => { + const fieldMappings = { + precision: "model.precision", + architecture: "model.architecture", + license: "metadata.hub_license", + type: "model.type", + }; + return fieldMappings[field] || field; +}; + +// Function to extract special searches and normal text +export const parseSearchQuery = (query) => { + const specialSearches = []; + let remainingText = query; + + // Look for all @field:value patterns + const prefixRegex = /@\w+:/g; + const matches = query.match(prefixRegex) || []; + + matches.forEach((prefix) => { + const regex = new RegExp(`${prefix}([^\\s@]+)`, "g"); + remainingText = remainingText.replace(regex, (match, value) => { + const field = prefix.slice(1, -1); + specialSearches.push({ + field: getFieldPath(field), + displayField: field, + value, + }); + return ""; + }); + }); + + return { + specialSearches, + textSearch: remainingText.trim(), + }; +}; + +// Function to extract simple text search +export const extractTextSearch = (searchValue) => { + return searchValue + .split(";") + .map((query) => { + const { textSearch } = parseSearchQuery(query); + return textSearch; + }) + .filter(Boolean) + .join(";"); +}; + +// Utility function to access nested object properties +export const getValueByPath = (obj, path) => { + return path.split(".").reduce((acc, part) => acc?.[part], obj); +}; + +// Function to generate natural language description of the search +export const generateSearchDescription = (searchValue) => { + if (!searchValue) return null; + + const searchGroups = searchValue + .split(";") + .map((group) => group.trim()) + .filter(Boolean); + + return searchGroups.map((group, index) => { + const { specialSearches, textSearch } = parseSearchQuery(group); + + let parts = []; + if (textSearch) { + parts.push(textSearch); + } + + if (specialSearches.length > 0) { + const specialParts = specialSearches.map( + ({ displayField, value }) => `@${displayField}:${value}` + ); + parts = parts.concat(specialParts); + } + + return { + text: parts.join(" "), + index, + }; + }); +}; diff --git a/frontend/src/pages/QuotePage/QuotePage.js b/frontend/src/pages/QuotePage/QuotePage.js new file mode 100644 index 0000000000000000000000000000000000000000..a8909e6b83290ac9578a63f4a7baf1f25df5e175 --- /dev/null +++ b/frontend/src/pages/QuotePage/QuotePage.js @@ -0,0 +1,278 @@ +import React from "react"; +import { + Box, + Typography, + Paper, + IconButton, + Tooltip, + Alert, + Link, +} from "@mui/material"; +import ContentCopyIcon from 
"@mui/icons-material/ContentCopy"; +import PageHeader from "../../components/shared/PageHeader"; + +const citations = [ + { + title: "Open LLM Leaderboard v2", + authors: + "Clémentine Fourrier, Nathan Habib, Alina Lozovskaya, Konrad Szafer, Thomas Wolf", + citation: `@misc{open-llm-leaderboard-v2, + author = {Clémentine Fourrier and Nathan Habib and Alina Lozovskaya and Konrad Szafer and Thomas Wolf}, + title = {Open LLM Leaderboard v2}, + year = {2024}, + publisher = {Hugging Face}, + howpublished = "\\url{https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard}", +}`, + type: "main", + }, + { + title: "Evaluation Framework", + authors: "Leo Gao et al.", + citation: `@software{eval-harness, + author = {Gao, Leo and Tow, Jonathan and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and McDonell, Kyle and Muennighoff, Niklas and Phang, Jason and Reynolds, Laria and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy}, + title = {A framework for few-shot language model evaluation}, + month = sep, + year = 2021, + publisher = {Zenodo}, + version = {v0.0.1}, + doi = {10.5281/zenodo.5371628}, + url = {https://doi.org/10.5281/zenodo.5371628}, +}`, + url: "https://doi.org/10.5281/zenodo.5371628", + }, +]; + +const priorWork = [ + { + title: "Open LLM Leaderboard v1", + authors: + "Edward Beeching, Clémentine Fourrier, Nathan Habib, Sheon Han, Nathan Lambert, Nazneen Rajani, Omar Sanseviero, Lewis Tunstall, Thomas Wolf", + citation: `@misc{open-llm-leaderboard-v1, + author = {Edward Beeching and Clémentine Fourrier and Nathan Habib and Sheon Han and Nathan Lambert and Nazneen Rajani and Omar Sanseviero and Lewis Tunstall and Thomas Wolf}, + title = {Open LLM Leaderboard (2023-2024)}, + year = {2023}, + publisher = {Hugging Face}, + howpublished = "\\url{https://huggingface.co/spaces/open-llm-leaderboard-old/open_llm_leaderboard}" +}`, + type: "main", + }, +]; + +const benchmarks = [ + { + title: "IFEval: Instruction-Following Evaluation", + authors: "Zhou et al.", + citation: `@misc{zhou2023instructionfollowingevaluationlargelanguage, + title={Instruction-Following Evaluation for Large Language Models}, + author={Jeffrey Zhou and Tianjian Lu and Swaroop Mishra and Siddhartha Brahma and Sujoy Basu and Yi Luan and Denny Zhou and Le Hou}, + year={2023}, + eprint={2311.07911}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2311.07911}, +}`, + url: "https://arxiv.org/abs/2311.07911", + }, + { + title: "BBH: Big-Bench Hard", + authors: "Suzgun et al.", + citation: `@misc{suzgun2022challengingbigbenchtaskschainofthought, + title={Challenging BIG-Bench Tasks and Whether Chain-of-Thought Can Solve Them}, + author={Mirac Suzgun and Nathan Scales and Nathanael Schärli and Sebastian Gehrmann and Yi Tay and Hyung Won Chung and Aakanksha Chowdhery and Quoc V. Le and Ed H. 
Chi and Denny Zhou and Jason Wei}, + year={2022}, + eprint={2210.09261}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2210.09261}, +}`, + url: "https://arxiv.org/abs/2210.09261", + }, + { + title: "MATH: Mathematics Aptitude Test of Heuristics - Level 5", + authors: "Hendrycks et al.", + citation: `@misc{hendrycks2021measuringmathematicalproblemsolving, + title={Measuring Mathematical Problem Solving With the MATH Dataset}, + author={Dan Hendrycks and Collin Burns and Saurav Kadavath and Akul Arora and Steven Basart and Eric Tang and Dawn Song and Jacob Steinhardt}, + year={2021}, + eprint={2103.03874}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2103.03874}, +}`, + url: "https://arxiv.org/abs/2103.03874", + }, + { + title: "GPQA: Graduate-Level Google-Proof Q&A", + authors: "Rein et al.", + citation: `@misc{rein2023gpqagraduatelevelgoogleproofqa, + title={GPQA: A Graduate-Level Google-Proof Q&A Benchmark}, + author={David Rein and Betty Li Hou and Asa Cooper Stickland and Jackson Petty and Richard Yuanzhe Pang and Julien Dirani and Julian Michael and Samuel R. Bowman}, + year={2023}, + eprint={2311.12022}, + archivePrefix={arXiv}, + primaryClass={cs.AI}, + url={https://arxiv.org/abs/2311.12022}, +}`, + url: "https://arxiv.org/abs/2311.12022", + }, + { + title: "MuSR: Multistep Soft Reasoning", + authors: "Sprague et al.", + citation: `@misc{sprague2024musrtestinglimitschainofthought, + title={MuSR: Testing the Limits of Chain-of-thought with Multistep Soft Reasoning}, + author={Zayne Sprague and Xi Ye and Kaj Bostrom and Swarat Chaudhuri and Greg Durrett}, + year={2024}, + eprint={2310.16049}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2310.16049}, +}`, + url: "https://arxiv.org/abs/2310.16049", + }, + { + title: "MMLU-Pro: Massive Multitask Language Understanding Professional", + authors: "Wang et al.", + citation: `@misc{wang2024mmluprorobustchallengingmultitask, + title={MMLU-Pro: A More Robust and Challenging Multi-Task Language Understanding Benchmark}, + author={Yubo Wang and Xueguang Ma and Ge Zhang and Yuansheng Ni and Abhranil Chandra and Shiguang Guo and Weiming Ren and Aaran Arulraj and Xuan He and Ziyan Jiang and Tianle Li and Max Ku and Kai Wang and Alex Zhuang and Rongqi Fan and Xiang Yue and Wenhu Chen}, + year={2024}, + eprint={2406.01574}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2406.01574}, +}`, + url: "https://arxiv.org/abs/2406.01574", + }, +]; + +const CitationBlock = ({ citation, title, authors, url, type }) => { + const handleCopy = () => { + navigator.clipboard.writeText(citation); + }; + + return ( + + + + {title} + + + {authors} + + {url && ( + + View paper → + + )} + + + + + + + + + {citation} + + + + ); +}; + +function QuotePage() { + return ( + + + + + + The citations below include both the leaderboard itself and the + individual benchmarks used in our evaluation suite. 
+ + + + + + Leaderboard + + + {citations.map((citation, index) => ( + + ))} + + + + + + Benchmarks + + + {benchmarks.map((benchmark, index) => ( + + ))} + + + + + + Prior Work + + + {priorWork.map((citation, index) => ( + + ))} + + + + ); +} + +export default QuotePage; diff --git a/frontend/src/pages/VoteModelPage/VoteModelPage.js b/frontend/src/pages/VoteModelPage/VoteModelPage.js new file mode 100644 index 0000000000000000000000000000000000000000..cbb0d14c55e8b0521bdea1c22f2af5b4f1e5667c --- /dev/null +++ b/frontend/src/pages/VoteModelPage/VoteModelPage.js @@ -0,0 +1,896 @@ +import React, { useState, useEffect } from "react"; +import { + Box, + Typography, + Paper, + Button, + Alert, + List, + ListItem, + CircularProgress, + Chip, + Divider, + IconButton, + Stack, + Link, + useTheme, + useMediaQuery, +} from "@mui/material"; +import AccessTimeIcon from "@mui/icons-material/AccessTime"; +import PersonIcon from "@mui/icons-material/Person"; +import OpenInNewIcon from "@mui/icons-material/OpenInNew"; +import HowToVoteIcon from "@mui/icons-material/HowToVote"; +import { useAuth } from "../../hooks/useAuth"; +import PageHeader from "../../components/shared/PageHeader"; +import AuthContainer from "../../components/shared/AuthContainer"; +import { alpha } from "@mui/material/styles"; +import CheckIcon from "@mui/icons-material/Check"; + +const NoModelsToVote = () => ( + + + + No Models to Vote + + + There are currently no models waiting for votes. +
    + Check back later! +
    +
+);
+
+const LOCAL_STORAGE_KEY = "pending_votes";
+
+function VoteModelPage() {
+  const { isAuthenticated, user, loading: authLoading } = useAuth();
+  const [pendingModels, setPendingModels] = useState([]);
+  const [loadingModels, setLoadingModels] = useState(true);
+  const [error, setError] = useState(null);
+  const [userVotes, setUserVotes] = useState(new Set());
+  const [loadingVotes, setLoadingVotes] = useState({});
+  const [localVotes, setLocalVotes] = useState(new Set());
+  const theme = useTheme();
+  const isMobile = useMediaQuery(theme.breakpoints.down("sm"));
+
+  // Create a unique identifier for a model
+  const getModelUniqueId = (model) => {
+    return `${model.name}_${model.precision}_${model.revision}`;
+  };
+
+  const formatWaitTime = (submissionTime) => {
+    if (!submissionTime) return "N/A";
+
+    const now = new Date();
+    const submitted = new Date(submissionTime);
+    const diffInHours = Math.floor((now - submitted) / (1000 * 60 * 60));
+
+    // Less than 24 hours: show in hours
+    if (diffInHours < 24) {
+      return `${diffInHours}h`;
+    }
+
+    // Less than 7 days: show in days
+    const diffInDays = Math.floor(diffInHours / 24);
+    if (diffInDays < 7) {
+      return `${diffInDays}d`;
+    }
+
+    // More than 7 days: show in weeks
+    const diffInWeeks = Math.floor(diffInDays / 7);
+    return `${diffInWeeks}w`;
+  };
+
+  const getConfigVotes = (votesData, model) => {
+    // Build the model's unique identifier
+    const modelUniqueId = getModelUniqueId(model);
+
+    // Count the votes reported by the server
+    let serverVotes = 0;
+    for (const config of Object.values(votesData.votes_by_config)) {
+      if (
+        config.precision === model.precision &&
+        config.revision === model.revision
+      ) {
+        serverVotes = config.count;
+        break;
+      }
+    }
+
+    // Add any pending vote stored in localStorage
+    const pendingVote = localVotes.has(modelUniqueId) ? 1 : 0;
+
+    return serverVotes + pendingVote;
+  };
+
+  const sortModels = (models) => {
+    // Sort by vote count (descending) first, then favor the user's own submissions
+    return [...models].sort((a, b) => {
+      // Compare vote counts first
+      if (b.votes !== a.votes) {
+        return b.votes - a.votes;
+      }
+
+      // If the user is logged in, prioritize their own models
+      if (user) {
+        const aIsUserModel = a.submitter === user.username;
+        const bIsUserModel = b.submitter === user.username;
+
+        if (aIsUserModel && !bIsUserModel) return -1;
+        if (!aIsUserModel && bIsUserModel) return 1;
+      }
+
+      // On a tie, sort by submission date, most recent first
+      return new Date(b.submission_time) - new Date(a.submission_time);
+    });
+  };
+
+  // Keep the pending-votes list in localStorage in sync
+  const updateLocalVotes = (modelUniqueId, action = "add") => {
+    const storedVotes = JSON.parse(
+      localStorage.getItem(LOCAL_STORAGE_KEY) || "[]"
+    );
+    if (action === "add") {
+      if (!storedVotes.includes(modelUniqueId)) {
+        storedVotes.push(modelUniqueId);
+      }
+    } else {
+      const index = storedVotes.indexOf(modelUniqueId);
+      if (index > -1) {
+        storedVotes.splice(index, 1);
+      }
+    }
+    localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(storedVotes));
+    setLocalVotes(new Set(storedVotes));
+  };
+
+  useEffect(() => {
+    const fetchData = async () => {
+      try {
+        // Don't show the loading state if we already have data
+        if (pendingModels.length === 0) {
+          setLoadingModels(true);
+        }
+        setError(null);
+
+        // Load pending votes from localStorage first
+        const storedVotes = JSON.parse(
+          localStorage.getItem(LOCAL_STORAGE_KEY) || "[]"
+        );
+        const localVotesSet = new Set(storedVotes);
+
+        // Fire all requests in parallel
+        const [pendingModelsResponse, userVotesResponse] = await Promise.all([
+          fetch("/api/models/pending"),
+          isAuthenticated && user
+            ? fetch(`/api/votes/user/${user.username}`)
+            : Promise.resolve(null),
+        ]);
+
+        if (!pendingModelsResponse.ok) {
+          throw new Error("Failed to fetch pending models");
+        }
+
+        const modelsData = await pendingModelsResponse.json();
+        const votedModels = new Set();
+
+        // Process the user's votes if logged in
+        if (userVotesResponse && userVotesResponse.ok) {
+          const votesData = await userVotesResponse.json();
+          const userVotesList = Array.isArray(votesData) ? votesData : [];
+
+          userVotesList.forEach((vote) => {
+            const uniqueId = `${vote.model}_${vote.precision || "unknown"}_${
+              vote.revision || "main"
+            }`;
+            votedModels.add(uniqueId);
+            if (localVotesSet.has(uniqueId)) {
+              localVotesSet.delete(uniqueId);
+              updateLocalVotes(uniqueId, "remove");
+            }
+          });
+        }
+
+        // Prepare and run all per-model vote requests at once
+        const modelVotesResponses = await Promise.all(
+          modelsData.map((model) => {
+            const [provider, modelName] = model.name.split("/");
+            return fetch(`/api/votes/model/${provider}/${modelName}`)
+              .then((response) =>
+                response.ok
+                  ? response.json()
+  // Submit an upvote for a model, recording it optimistically in localStorage
+  const handleVote = async (model) => {
+    if (!isAuthenticated) return;
+
+    const modelUniqueId = getModelUniqueId(model);
+
+    try {
+      setError(null);
+      setLoadingVotes((prev) => ({ ...prev, [modelUniqueId]: true }));
+
+      // Add to localStorage immediately
+      updateLocalVotes(modelUniqueId, "add");
+
+      // Encode the model name for use in the URL
+      const encodedModelName = encodeURIComponent(model.name);
+
+      const response = await fetch(
+        `/api/votes/${encodedModelName}?vote_type=up&user_id=${user.username}`,
+        {
+          method: "POST",
+          headers: {
+            "Content-Type": "application/json",
+          },
+          body: JSON.stringify({
+            precision: model.precision,
+            revision: model.revision,
+          }),
+        }
+      );
+
+      if (!response.ok) {
+        // If the request fails, roll the optimistic vote back
+        updateLocalVotes(modelUniqueId, "remove");
+        throw new Error("Failed to submit vote");
+      }
+
+      // Refresh the votes for this model, bypassing the cache
+      const [provider, modelName] = model.name.split("/");
+      const timestamp = Date.now();
+      const votesResponse = await fetch(
+        `/api/votes/model/${provider}/${modelName}?nocache=${timestamp}`
+      );
+
+      if (!votesResponse.ok) {
+        throw new Error("Failed to fetch updated votes");
+      }
+
+      const votesData = await votesResponse.json();
+
+      // Update the model and re-sort the list
+      setPendingModels((models) => {
+        const updatedModels = models.map((m) =>
+          getModelUniqueId(m) === getModelUniqueId(model)
+            ? {
+                ...m,
+                votes: getConfigVotes(votesData, m),
+                votes_by_config: votesData.votes_by_config || {},
+                hasVoted: true,
+              }
+            : m
+        );
+        return sortModels(updatedModels);
+      });
+
+      // Record the vote under the model's unique ID
+      setUserVotes((prev) => new Set([...prev, getModelUniqueId(model)]));
+    } catch (err) {
+      console.error("Error voting:", err);
+      setError(err.message);
+    } finally {
+      // Clear the loading state for this model
+      setLoadingVotes((prev) => ({
+        ...prev,
+        [modelUniqueId]: false,
+      }));
+    }
+  };
+
+  // A model counts as voted if the server has recorded the vote or if it is
+  // still pending in localStorage
+  const isVoted = (model) => {
+    const modelUniqueId = getModelUniqueId(model);
+    return userVotes.has(modelUniqueId) || localVotes.has(modelUniqueId);
+  };
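+  // For reference, a vote submitted by handleVote() above goes out as shown
+  // below. The URL and body mirror the fetch call; the concrete values are
+  // made-up examples, not taken from this repository:
+  //
+  //   POST /api/votes/org%2Fmodel-name?vote_type=up&user_id=some-user
+  //   Content-Type: application/json
+  //
+  //   { "precision": "float16", "revision": "main" }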
+  if (authLoading || (loadingModels && pendingModels.length === 0)) {
+    return (
+      <Box
+        sx={{
+          display: "flex",
+          justifyContent: "center",
+          alignItems: "center",
+          minHeight: "60vh",
+        }}
+      >
+        <CircularProgress />
+      </Box>
+    );
+  }
+
+  return (
+    <Box sx={{ width: "100%", maxWidth: 1200, mx: "auto", px: 4, py: 4 }}>
+      <Typography variant="body1" color="text.secondary" sx={{ mb: 3 }}>
+        Help us prioritize which models to evaluate next
+      </Typography>
+
+      {error && (
+        <Alert severity="error" sx={{ mb: 2 }}>
+          {error}
+        </Alert>
+      )}
+
+      {/* Auth Status (currently disabled)
+      {isAuthenticated ? (
+        <Stack direction="row" spacing={1} alignItems="center" sx={{ mb: 3 }}>
+          <Typography variant="body2">
+            Connected as <strong>{user?.username}</strong>
+          </Typography>
+        </Stack>
+      ) : (
+        <Stack spacing={1} sx={{ mb: 3 }}>
+          <Button variant="contained">Login to Vote</Button>
+          <Typography variant="body2" color="text.secondary">
+            You need to be logged in with your Hugging Face account to vote
+            for models
+          </Typography>
+        </Stack>
+      )}
+      */}
+
+      {/* Models List */}
+      <Paper variant="outlined" sx={{ overflow: "hidden" }}>
+        {/* Header - Always visible */}
+        <Box
+          sx={{
+            px: 3,
+            py: 2,
+            borderBottom: "1px solid",
+            borderColor: (theme) =>
+              theme.palette.mode === "dark"
+                ? alpha(theme.palette.divider, 0.1)
+                : "grey.200",
+            bgcolor: (theme) =>
+              theme.palette.mode === "dark"
+                ? alpha(theme.palette.background.paper, 0.5)
+                : "grey.50",
+          }}
+        >
+          <Typography variant="h6">Models Pending Evaluation</Typography>
+        </Box>
+
+        {/* Table Header */}
+        <Box
+          sx={{
+            display: "flex",
+            alignItems: "center",
+            px: 3,
+            py: 1.5,
+            borderBottom: "1px solid",
+            borderColor: "divider",
+          }}
+        >
+          <Typography variant="subtitle2" sx={{ flex: 1 }}>
+            Model
+          </Typography>
+          <Typography
+            variant="subtitle2"
+            sx={{ width: 100, textAlign: "center" }}
+          >
+            Votes
+          </Typography>
+          <Typography
+            variant="subtitle2"
+            sx={{ width: 120, textAlign: "center" }}
+          >
+            Priority
+          </Typography>
+        </Box>
+
+        {/* Content */}
+        {loadingModels ? (
+          <Box sx={{ display: "flex", justifyContent: "center", py: 6 }}>
+            <CircularProgress size={24} />
+          </Box>
+        ) : pendingModels.length === 0 ? (
+          <Box sx={{ py: 6, textAlign: "center" }}>
+            <Typography color="text.secondary">
+              No models are pending evaluation
+            </Typography>
+          </Box>
+        ) : (
+          <Box>
+            {pendingModels.map((model, index) => {
+              const isTopThree = index < 3;
+              return (
+                <Box key={getModelUniqueId(model)}>
+                  {index > 0 && <Divider />}
+                  <Box
+                    sx={{
+                      display: "flex",
+                      flexDirection: isMobile ? "column" : "row",
+                      alignItems: isMobile ? "flex-start" : "center",
+                      gap: 2,
+                      px: 3,
+                      py: 2,
+                    }}
+                  >
+                    {/* Left side - Model info */}
+                    <Box sx={{ flex: 1, minWidth: 0 }}>
+                      {/* Model name and link */}
+                      <Link
+                        href={`https://huggingface.co/${model.name}`}
+                        target="_blank"
+                        rel="noopener noreferrer"
+                        underline="hover"
+                        sx={{ fontWeight: 600, wordBreak: "break-all" }}
+                      >
+                        {model.name}
+                      </Link>
+
+                      {/* Metadata row */}
+                      <Stack direction="row" spacing={2} sx={{ mt: 0.5 }}>
+                        <Typography variant="body2" color="text.secondary">
+                          {model.wait_time}
+                        </Typography>
+                        <Typography variant="body2" color="text.secondary">
+                          {model.submitter}
+                        </Typography>
+                      </Stack>
+                    </Box>
+
+                    {/* Vote Column */}
+                    <Box sx={{ width: 100, textAlign: "center" }}>
+                      <Button
+                        size="small"
+                        variant={isVoted(model) ? "contained" : "outlined"}
+                        disabled={
+                          !isAuthenticated ||
+                          isVoted(model) ||
+                          Boolean(loadingVotes[getModelUniqueId(model)])
+                        }
+                        onClick={() => handleVote(model)}
+                      >
+                        {loadingVotes[getModelUniqueId(model)] ? (
+                          <CircularProgress size={16} />
+                        ) : isVoted(model) ? (
+                          "Voted"
+                        ) : (
+                          "Vote"
+                        )}
+                      </Button>
+                      <Typography variant="h6" component="div">
+                        {model.votes > 999 ? "999+" : model.votes}
+                      </Typography>
+                      <Typography variant="caption" color="text.secondary">
+                        votes
+                      </Typography>
+                    </Box>
+
+                    {/* Priority Column */}
+                    <Box sx={{ width: 120, textAlign: "center" }}>
+                      <Chip
+                        label={
+                          <Stack
+                            direction="row"
+                            spacing={0.75}
+                            alignItems="center"
+                          >
+                            {isTopThree && (
+                              <Typography
+                                variant="caption"
+                                fontWeight={700}
+                                color="primary"
+                              >
+                                HIGH
+                              </Typography>
+                            )}
+                            <Typography variant="body2">
+                              #{index + 1}
+                            </Typography>
+                          </Stack>
+                        }
+                        size="medium"
+                        variant={isTopThree ? "filled" : "outlined"}
+                        sx={{
+                          height: 36,
+                          minWidth: "100px",
+                          bgcolor: isTopThree
+                            ? (theme) =>
+                                alpha(theme.palette.primary.main, 0.1)
+                            : "transparent",
+                          borderColor: isTopThree
+                            ? "primary.main"
+                            : "grey.300",
+                          borderWidth: 2,
+                          "& .MuiChip-label": {
+                            px: 2,
+                            fontSize: "0.95rem",
+                          },
+                        }}
+                      />
+                    </Box>
+                  </Box>
+                </Box>
+              );
+            })}
+          </Box>
+        )}
+      </Paper>
+    </Box>
+  );
+}
+
+export default VoteModelPage;