import asyncio
import functools
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List, Optional

import chromadb
import dateutil.parser
import httpx
import polars as pl
import torch
from cashews import cache
from chromadb.utils import embedding_functions
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from transformers import AutoTokenizer

# Configuration constants
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
BATCH_SIZE = 2000  # rows per ChromaDB upsert call during setup
CACHE_TTL = "24h"  # cache lifetime for search / similarity responses
TRENDING_CACHE_TTL = "1h"  # 1 hour cache for trending data

# Pick the fastest available torch backend for the embedding model.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# NOTE(review): loaded at import time but not referenced elsewhere in this module.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(review): huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER when it is first
# imported (already triggered by the `transformers` import above), so this
# assignment probably has no effect on huggingface_hub downloads — confirm
# before relying on it.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# macOS is treated as a local run that keeps data under ./data; otherwise data
# lives under the absolute /data path (presumably a mounted volume in deployment).
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure the in-memory response cache
cache.setup("mem://", size_limit="8gb")

# Persistent ChromaDB client shared by setup and all request handlers
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    The database is populated by the module-level ``setup_database()`` call at
    import time, so it is not repeated here (it previously ran twice). On
    shutdown the cache backend is closed.
    """
    yield
    await cache.close()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    # Starlette compares `allow_origins` entries literally (only "*" is special),
    # so wildcard subdomains such as "https://*.hf.space" never matched. A regex
    # (matched with fullmatch) expresses "any subdomain of hf.space or
    # huggingface.co over https".
    allow_origin_regex=r"https://[A-Za-z0-9.-]+\.(hf\.space|huggingface\.co)",
    # For local frontend testing, add allow_origins=["http://localhost:5500"]
    # (TODO remove before prod).
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@functools.lru_cache(maxsize=1)
def get_embedding_function():
    """Return the SentenceTransformer embedding function shared by both collections.

    Memoised so request handlers reuse one instance instead of constructing a
    new one (and re-logging the device) on every call.
    """
    logger.info(f"Using device: {DEVICE}")
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL, device=DEVICE
    )


def _sync_dataset_collection(dataset_collection) -> None:
    """Upsert dataset summaries whose last_modified is newer than anything stored.

    Reads the summaries parquet from the Hub, drops excluded dataset IDs, and
    upserts the remaining (new) rows in batches of ``BATCH_SIZE``.
    """
    df = pl.scan_parquet(
        "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
    )
    df = df.filter(
        pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
    )
    # course model that's not useful for retrieving
    df = df.filter(
        pl.col("datasetId")
        .str.contains_any(["gemma-2-2B-it-thinking-function_calling-V0"])
        .not_()
    )

    # Find the most recent last_modified date already in the collection so only
    # newer rows are re-embedded.
    latest_update = None
    if dataset_collection.count() > 0:
        metadata = dataset_collection.get(include=["metadatas"]).get("metadatas")
        logger.info(f"Found {len(metadata)} existing records in collection")
        last_modifieds = [
            dateutil.parser.parse(m.get("last_modified")) for m in metadata
        ]
        latest_update = max(last_modifieds)
        logger.info(f"Most recent record in DB from: {latest_update}")
        logger.info(f"Oldest record in DB from: {min(last_modifieds)}")

    df = df.select(["datasetId", "summary", "likes", "downloads", "last_modified"])

    # Log some stats about the incoming data
    sample_dates = df.select("last_modified").limit(5).collect()
    logger.info(f"Sample of incoming dates: {sample_dates}")
    total_incoming = df.select(pl.len()).collect().item()
    logger.info(f"Total incoming records: {total_incoming}")

    if latest_update:
        logger.info(f"Filtering records newer than {latest_update}")
        df = df.filter(pl.col("last_modified") > latest_update)
        filtered_count = df.select(pl.len()).collect().item()
        logger.info(f"Found {filtered_count} records to update after filtering")

    df = df.collect()
    total_rows = len(df)
    if total_rows > 0:
        logger.info(f"Updating dataset collection with {total_rows} new records")
        logger.info(
            f"Date range of updates: {df['last_modified'].min()} to {df['last_modified'].max()}"
        )
        for start in range(0, total_rows, BATCH_SIZE):
            # slice() clamps the length at the end of the frame
            batch_df = df.slice(start, BATCH_SIZE)
            batch_size = len(batch_df)
            logger.info(
                f"Processing batch {start // BATCH_SIZE + 1}: {batch_size} records "
                f"({batch_df['last_modified'].min()} to {batch_df['last_modified'].max()})"
            )
            dataset_collection.upsert(
                ids=batch_df["datasetId"].to_list(),
                documents=batch_df["summary"].to_list(),
                metadatas=[
                    {
                        "likes": int(likes),
                        "downloads": int(downloads),
                        "last_modified": str(last_modified),
                    }
                    for likes, downloads, last_modified in zip(
                        batch_df["likes"].to_list(),
                        batch_df["downloads"].to_list(),
                        batch_df["last_modified"].to_list(),
                    )
                ],
            )
            logger.info(f"Processed {start + batch_size:,} / {total_rows:,} records")

    logger.info(f"Database initialized with {dataset_collection.count():,} total rows")


def _sync_model_collection(model_collection) -> None:
    """Upsert all model summaries when the collection holds fewer rows than the source.

    Missing ``param_count`` values are stored as 0, which the search filters
    treat as "unknown".
    """
    model_df = pl.scan_parquet(
        "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
    )
    model_row_count = model_df.select(pl.len()).collect().item()
    logger.info(f"Row count of new model data: {model_row_count}")
    if model_collection.count() >= model_row_count:
        return

    model_df = model_df.select(
        ["modelId", "summary", "likes", "downloads", "last_modified", "param_count"]
    ).collect()
    total_rows = len(model_df)
    for start in range(0, total_rows, BATCH_SIZE):
        batch_df = model_df.slice(start, BATCH_SIZE)
        model_collection.upsert(
            ids=batch_df["modelId"].to_list(),
            documents=batch_df["summary"].to_list(),
            metadatas=[
                {
                    "likes": int(likes),
                    "downloads": int(downloads),
                    "last_modified": str(last_modified),
                    "param_count": int(param_count) if param_count is not None else 0,
                }
                for likes, downloads, last_modified, param_count in zip(
                    batch_df["likes"].to_list(),
                    batch_df["downloads"].to_list(),
                    batch_df["last_modified"].to_list(),
                    batch_df["param_count"].to_list(),
                )
            ],
        )
        logger.info(f"Processed {start + len(batch_df):,} / {total_rows:,} model rows")
    logger.info(f"Model database initialized with {model_collection.count():,} rows")


def setup_database():
    """Create the Chroma collections if needed and incrementally populate them.

    Best-effort: any failure is logged with its traceback and swallowed so the
    API can still start and serve whatever is already stored.
    """
    try:
        embedding_function = get_embedding_function()
        dataset_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )
        model_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="model_cards",
            metadata={"hnsw:space": "cosine"},
        )
        _sync_dataset_collection(dataset_collection)
        _sync_model_collection(model_collection)
    except Exception as e:
        logger.exception(f"Setup error: {e}")


# Run setup on startup (once, at import time)
setup_database()


class QueryResult(BaseModel):
    """A dataset hit. `similarity` is cosine similarity (higher is closer)."""

    dataset_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    results: List[QueryResult]


class ModelQueryResult(BaseModel):
    """A model hit. param_count == 0 means the parameter count is unknown."""

    model_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int
    param_count: Optional[int] = None


class ModelQueryResponse(BaseModel):
    results: List[ModelQueryResult]


def _combine_where(conditions):
    """Turn a list of Chroma metadata filters into a `where` argument.

    Returns None for no filters, the bare filter for one (Chroma's `$and`
    requires at least two operands), and `{"$and": [...]}` otherwise.
    """
    if not conditions:
        return None
    if len(conditions) == 1:
        return conditions[0]
    return {"$and": conditions}


def _popularity_conditions(min_likes, min_downloads):
    """Build likes/downloads lower-bound filters, skipping no-op zero bounds."""
    conditions = []
    if min_likes > 0:
        conditions.append({"likes": {"$gte": min_likes}})
    if min_downloads > 0:
        conditions.append({"downloads": {"$gte": min_downloads}})
    return conditions


def _model_where_clause(min_likes, min_downloads, min_param_count, max_param_count):
    """Build the `where` filter for model queries.

    When any parameter-count bound is requested, models with param_count == 0
    (unknown) are excluded.
    """
    conditions = _popularity_conditions(min_likes, min_downloads)
    if min_param_count > 0 or max_param_count is not None:
        conditions.append({"param_count": {"$gt": 0}})
        if min_param_count > 0:
            conditions.append({"param_count": {"$gte": min_param_count}})
        if max_param_count is not None:
            conditions.append({"param_count": {"$lte": max_param_count}})
    return _combine_where(conditions)


@app.get("/")
async def redirect_to_docs():
    """Redirect the root URL to the interactive API docs."""
    return RedirectResponse(url="/docs")


@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Semantic search over dataset summaries with optional popularity filters."""
    try:
        collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            # Over-fetch when re-sorting so the top-k by the other key is meaningful
            n_results=k * 4 if sort_by != "similarity" else k,
            where=_combine_where(_popularity_conditions(min_likes, min_downloads)),
        )
        query_results = await process_search_results(results, "dataset", k, sort_by)
        return QueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed") from e


@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Find datasets whose stored embedding is closest to `dataset_id`'s.

    Raises a 404 HTTPException when `dataset_id` is not in the collection.
    """
    try:
        collection = client.get_collection("dataset_cards")
        source = collection.get(ids=[dataset_id], include=["embeddings"])
        if not source["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )
        results = collection.query(
            query_embeddings=[source["embeddings"][0]],
            # k + 1 because the query dataset itself is usually the top hit
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where=_combine_where(_popularity_conditions(min_likes, min_downloads)),
        )
        query_results = await process_search_results(
            results, "dataset", k, sort_by, dataset_id
        )
        return QueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed") from e


@app.get("/search/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def search_models(
    query: str,
    k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
    sort_by: str = Query(
        default="similarity",
        enum=["similarity", "likes", "downloads", "trending"],
        description="Sort method for results",
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Search for models based on a text query with optional filtering.

    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    try:
        collection = client.get_collection(
            name="model_cards", embedding_function=get_embedding_function()
        )
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where=_model_where_clause(
                min_likes, min_downloads, min_param_count, max_param_count
            ),
        )
        query_results = await process_search_results(results, "model", k, sort_by)
        return ModelQueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Model search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model search failed") from e


@app.get("/similarity/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_models(
    model_id: str,
    k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
    sort_by: str = Query(
        default="similarity",
        enum=["similarity", "likes", "downloads", "trending"],
        description="Sort method for results",
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Find similar models to a specified model with optional filtering.

    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    try:
        collection = client.get_collection("model_cards")
        source = collection.get(ids=[model_id], include=["embeddings"])
        if not source["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Model ID '{model_id}' not found"
            )
        results = collection.query(
            query_embeddings=[source["embeddings"][0]],
            # k + 1 because the query model itself is usually the top hit
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where=_model_where_clause(
                min_likes, min_downloads, min_param_count, max_param_count
            ),
        )
        query_results = await process_search_results(
            results, "model", k, sort_by, model_id
        )
        return ModelQueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Model similarity search error: {str(e)}")
        raise HTTPException(
            status_code=500, detail="Model similarity search failed"
        ) from e


@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_score(item_id: str, item_type: str) -> float:
    """Fetch the Hub trendingScore for a model or dataset.

    Best-effort: returns 0 (and logs) on any HTTP or parsing error.
    """
    try:
        endpoint = "models" if item_type == "model" else "datasets"
        async with httpx.AsyncClient() as http_client:
            response = await http_client.get(
                f"https://huggingface.co/api/{endpoint}/{item_id}?expand=trendingScore"
            )
            response.raise_for_status()
            return response.json().get("trendingScore", 0)
    except Exception as e:
        logger.error(
            f"Error fetching trending score for {item_type} {item_id}: {str(e)}"
        )
        return 0


async def process_search_results(results, id_field, k, sort_by, exclude_id=None):
    """Convert a Chroma query result into at most `k` result models.

    Args:
        results: Chroma `query` output for a single query text/embedding.
        id_field: "dataset" or "model"; selects the result class and id key.
        k: maximum number of results to return.
        sort_by: "similarity" (keep Chroma order), "likes", "downloads" or
            "trending" (Hub trendingScore, fetched concurrently).
        exclude_id: an id to drop, e.g. the query item in similarity search.
    """
    result_cls = QueryResult if id_field == "dataset" else ModelQueryResult
    id_key = f"{id_field}_id"

    query_results = []
    for item_id, distance, document, metadata in zip(
        results["ids"][0],
        results["distances"][0],
        results["documents"][0],
        results["metadatas"][0],
    ):
        if exclude_id and item_id == exclude_id:
            continue
        record = {
            id_key: item_id,
            # The collections use cosine space, where Chroma reports
            # distance = 1 - cosine similarity; convert back so higher is closer.
            "similarity": 1.0 - float(distance),
            "summary": document,
            "likes": metadata["likes"],
            "downloads": metadata["downloads"],
        }
        if id_field == "model" and "param_count" in metadata:
            record["param_count"] = metadata["param_count"]
        query_results.append(result_cls(**record))

    if sort_by == "trending":
        item_type = "model" if id_field == "model" else "dataset"
        scores = await asyncio.gather(
            *(get_trending_score(getattr(r, id_key), item_type) for r in query_results)
        )
        trending_scores = {
            getattr(r, id_key): score for r, score in zip(query_results, scores)
        }
        query_results.sort(
            key=lambda r: trending_scores.get(getattr(r, id_key), 0), reverse=True
        )
    elif sort_by != "similarity":
        query_results.sort(key=lambda r: getattr(r, sort_by), reverse=True)

    # Callers over-fetch (4x when re-sorting, k + 1 when excluding the query
    # item), so always trim to the requested size.
    return query_results[:k]


async def fetch_trending_models():
    """Fetch the model listing from the Hugging Face Hub API.

    Callers sort the returned list by `trendingScore` themselves.
    """
    async with httpx.AsyncClient() as http_client:
        response = await http_client.get("https://huggingface.co/api/models")
        response.raise_for_status()
        return response.json()


@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_models_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
    min_param_count: int = 0,
    max_param_count: Optional[int] = None,
) -> List[ModelQueryResult]:
    """Fetch trending models and combine them with summaries from the database.

    Models without a stored summary are skipped. When any parameter bound is
    set, models with param_count == 0 (unknown) are excluded.

    Raises:
        HTTPException: 500 on any failure (the error is logged).
    """
    try:
        trending_models = await fetch_trending_models()

        trending_models = [
            model
            for model in trending_models
            if model.get("likes", 0) >= min_likes
            and model.get("downloads", 0) >= min_downloads
        ]
        trending_models = sorted(
            trending_models, key=lambda m: m.get("trendingScore", 0), reverse=True
        )
        # Keep a 3x buffer so models missing from the DB or removed by the
        # parameter filters still leave enough candidates to fill `limit`.
        trending_models = trending_models[: limit * 3]

        model_ids = [model["modelId"] for model in trending_models]
        if not model_ids:
            return []

        collection = client.get_collection("model_cards")
        summaries = collection.get(ids=model_ids, include=["documents", "metadatas"])
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
        id_to_metadata = dict(zip(summaries["ids"], summaries["metadatas"]))

        logger.debug(
            f"Filter params - min_param_count: {min_param_count}, max_param_count: {max_param_count}"
        )
        using_param_filters = min_param_count > 0 or max_param_count is not None

        results = []
        for model in trending_models:
            model_id = model["modelId"]
            if model_id not in id_to_summary:
                continue
            metadata = id_to_metadata.get(model_id) or {}
            param_count = metadata.get("param_count", 0)
            logger.debug(f"Model: {model_id}, param_count: {param_count}")

            if using_param_filters:
                if param_count == 0:
                    logger.debug(
                        f"Filtering out {model_id} - has param_count=0 but parameter filtering is active"
                    )
                    continue
                if min_param_count > 0 and param_count < min_param_count:
                    logger.debug(
                        f"Filtering out {model_id} - param_count {param_count} < min_param_count {min_param_count}"
                    )
                    continue
                if max_param_count is not None and param_count > max_param_count:
                    logger.debug(
                        f"Filtering out {model_id} - param_count {param_count} > max_param_count {max_param_count}"
                    )
                    continue

            results.append(
                ModelQueryResult(
                    model_id=model_id,
                    similarity=1.0,  # Not applicable for trending
                    summary=id_to_summary[model_id],
                    likes=model.get("likes", 0),
                    downloads=model.get("downloads", 0),
                    param_count=param_count,
                )
            )

        logger.debug(f"After filtering: {len(results)} models remain")
        return results[:limit]
    except Exception as e:
        logger.error(f"Error fetching trending models: {str(e)}")
        raise HTTPException(
            status_code=500, detail="Failed to fetch trending models"
        ) from e


@app.get("/trending/models", response_model=ModelQueryResponse)
async def get_trending_models(
    limit: int = Query(
        default=10, ge=1, le=100, description="Number of results to return"
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any parameter filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Get trending models with their summaries and optional filtering.

    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    logger.info(
        f"Request for trending models with params: limit={limit}, min_likes={min_likes}, min_downloads={min_downloads}, min_param_count={min_param_count}, max_param_count={max_param_count}"
    )
    results = await get_trending_models_with_summaries(
        limit=limit,
        min_likes=min_likes,
        min_downloads=min_downloads,
        min_param_count=min_param_count,
        max_param_count=max_param_count,
    )
    logger.info(f"Returning {len(results)} trending model results")
    return ModelQueryResponse(results=results)


async def fetch_trending_datasets():
    """Fetch the dataset listing from the Hugging Face Hub API.

    Callers sort the returned list by `trendingScore` themselves.
    """
    async with httpx.AsyncClient() as http_client:
        response = await http_client.get("https://huggingface.co/api/datasets")
        response.raise_for_status()
        return response.json()


@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_datasets_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
) -> List[QueryResult]:
    """Fetch trending datasets and combine them with summaries from the database.

    Datasets without a stored summary are skipped; a 3x candidate buffer is
    kept so those skips do not leave the response short of `limit`.

    Raises:
        HTTPException: 500 on any failure (the error is logged).
    """
    try:
        trending_datasets = await fetch_trending_datasets()

        trending_datasets = [
            dataset
            for dataset in trending_datasets
            if dataset.get("likes", 0) >= min_likes
            and dataset.get("downloads", 0) >= min_downloads
        ]
        # Sort by trending score, keeping a buffer for datasets missing from the DB
        trending_datasets = sorted(
            trending_datasets, key=lambda d: d.get("trendingScore", 0), reverse=True
        )[: limit * 3]

        dataset_ids = [dataset["id"] for dataset in trending_datasets]
        if not dataset_ids:
            return []

        collection = client.get_collection("dataset_cards")
        summaries = collection.get(ids=dataset_ids, include=["documents"])
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))

        results = [
            QueryResult(
                dataset_id=dataset["id"],
                similarity=1.0,  # Not applicable for trending
                summary=id_to_summary[dataset["id"]],
                likes=dataset.get("likes", 0),
                downloads=dataset.get("downloads", 0),
            )
            for dataset in trending_datasets
            if dataset["id"] in id_to_summary
        ]
        return results[:limit]
    except Exception as e:
        logger.error(f"Error fetching trending datasets: {str(e)}")
        raise HTTPException(
            status_code=500, detail="Failed to fetch trending datasets"
        ) from e


@app.get("/trending/datasets", response_model=QueryResponse)
async def get_trending_datasets(
    limit: int = Query(default=10, ge=1, le=100),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Get trending datasets with their summaries"""
    results = await get_trending_datasets_with_summaries(
        limit=limit, min_likes=min_likes, min_downloads=min_downloads
    )
    return QueryResponse(results=results)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)