import json
import os
import logging
from datetime import datetime, timezone
from typing import Dict, Tuple, Optional, Any

from src.display.formatting import styled_error, styled_message, styled_warning
from src.envs import API
from src.submission.check_validity import (
    already_submitted_models,
    check_model_card,
    get_model_size,
    is_model_on_hub,
    check_safetensors_format,
)
from src.config import (
    API_TOKEN,
    QUEUE_REPO,
    EVAL_REQUESTS_PATH,
    ALLOWED_WEIGHT_TYPES,
    DEFAULT_REVISION,
    LOG_LEVEL,
    EVALUATION_WAIT_TIME,
)

REQUESTED_MODELS: Optional[Dict[str, Any]] = None
USERS_TO_SUBMISSION_DATES: Optional[Dict[str, Any]] = None

logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)


def validate_input(model_type: Optional[str], weight_type: str) -> Optional[str]:
    """Validate input parameters."""
    if not model_type:
        return styled_error("Please select a model type.")
    if weight_type not in ALLOWED_WEIGHT_TYPES:
        return styled_error(f"Invalid weight type. Must be one of: {', '.join(ALLOWED_WEIGHT_TYPES)}")
    if weight_type not in ("Safetensors", "GGUF"):
        return styled_error(
            "Only Safetensors format is accepted for new submissions (or GGUF for quantized models). "
            "Please convert your model using:\n"
            "```python\n"
            "from transformers import AutoModelForCausalLM\n"
            "from safetensors.torch import save_file\n\n"
            "model = AutoModelForCausalLM.from_pretrained('your-model')\n"
            "state_dict = model.state_dict()\n"
            "save_file(state_dict, 'model.safetensors')\n"
            "```"
        )
    return None
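
# For example (illustrative; the exact contents of ALLOWED_WEIGHT_TYPES come from src.config):
#   validate_input("pretrained", "Safetensors")  -> None (accepted)
#   validate_input(None, "Safetensors")          -> styled error asking for a model type
#   validate_input("pretrained", "pytorch")      -> styled error (weight type not accepted)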


def check_model_existence(model: str, revision: str, token: str) -> Optional[str]:
    """Check if the model exists on the Hub."""
    if revision == "":
        revision = DEFAULT_REVISION
    model_on_hub, error, _ = is_model_on_hub(model_name=model, revision=revision, token=token, test_tokenizer=True)
    if not model_on_hub:
        return styled_error(f'Model "{model}" {error}')
    return None


def get_model_information(model: str, revision: str, weight_type: str) -> Tuple[Optional[Any], Optional[str]]:
    """Get model information from the Hub and perform the necessary checks."""
    if weight_type != "GGUF":
        safetensors_ok, error_msg = check_safetensors_format(model, revision, API_TOKEN)
        if not safetensors_ok:
            return None, styled_error(error_msg)

    try:
        model_info = API.model_info(repo_id=model, revision=revision)
    except Exception as e:
        logger.error(f"Failed to get model info: {e}")
        return None, styled_error("Could not get your model information. Please make sure the repository exists and is filled out properly.")

    # The license is required later when building the eval entry, so fail early if it is missing.
    try:
        _ = model_info.cardData["license"]
    except Exception as e:
        logger.error(f"Failed to get license info: {e}")
        return None, styled_error("Please select a license for your model.")

    modelcard_ok, error_msg = check_model_card(model)
    if not modelcard_ok:
        return None, styled_error(error_msg)

    return model_info, None


def create_eval_entry(
    model: str,
    base_model: str,
    revision: str,
    precision: str,
    weight_type: str,
    model_type: str,
    model_info: Any,
    model_size: float,
) -> Dict[str, Any]:
    """Create the evaluation entry dictionary."""
    current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    return {
        "model": model,
        "base_model": base_model,
        "revision": revision,
        "precision": precision,
        "weight_type": weight_type,
        "status": "PENDING",
        "submitted_time": current_time,
        "model_type": model_type,
        "likes": model_info.likes,
        "params": model_size,
        "license": model_info.cardData["license"],
        "private": False,
    }
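
# For reference, an entry for a hypothetical submission "acme/demo-model" would
# serialize to JSON roughly as follows (illustrative values only):
#
#   {
#       "model": "acme/demo-model",
#       "base_model": "",
#       "revision": "main",
#       "precision": "float16",
#       "weight_type": "Safetensors",
#       "status": "PENDING",
#       "submitted_time": "2024-01-01T00:00:00Z",
#       "model_type": "pretrained",
#       "likes": 0,
#       "params": 7.24,
#       "license": "apache-2.0",
#       "private": false
#   }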


def add_new_eval(
    model: str,
    base_model: str,
    revision: str,
    precision: str,
    weight_type: str,
    model_type: str,
) -> str:
    """
    Add a new model evaluation request to the queue.

    Args:
        model (str): The name of the model to evaluate.
        base_model (str): The name of the base model (for delta or adapter weights).
        revision (str): The revision of the model to evaluate.
        precision (str): The precision of the model weights.
        weight_type (str): The format of the model weights.
        model_type (str): The type of the model.

    Returns:
        str: A message indicating the result of the evaluation request.
    """
    global REQUESTED_MODELS
    global USERS_TO_SUBMISSION_DATES
    global EVAL_REQUESTS_PATH

    # Check and normalize EVAL_REQUESTS_PATH at the beginning
    if not EVAL_REQUESTS_PATH or EVAL_REQUESTS_PATH == "YOUR_EVAL_REQUESTS_PATH_HERE":
        return styled_error("EVAL_REQUESTS_PATH is not properly configured. Please check your configuration.")

    # Ensure EVAL_REQUESTS_PATH ends with 'eval-queue'
    if not EVAL_REQUESTS_PATH.endswith("eval-queue"):
        EVAL_REQUESTS_PATH = os.path.join(EVAL_REQUESTS_PATH, "eval-queue")

    # Input validation
    if not all([model, revision, precision, weight_type, model_type]):
        return styled_error("All fields except base_model are required.")

    if not REQUESTED_MODELS:
        REQUESTED_MODELS, USERS_TO_SUBMISSION_DATES = already_submitted_models(EVAL_REQUESTS_PATH)

    # "org/model" -> ("org", "model"); split only on the first "/" so names with
    # extra slashes do not raise an unpacking error.
    user_name, model_path = model.split("/", 1) if "/" in model else ("", model)

    # Keep only the first token of the precision label (e.g. "float16 (...)" -> "float16").
    precision = precision.split(" ")[0]
    error = validate_input(model_type, weight_type)
    if error:
        return error

    error = check_model_existence(model, revision, API_TOKEN)
    if error:
        return error

    model_info, error = get_model_information(model, revision, weight_type)
    if error:
        return error

    # Reject duplicate submissions before doing any further work.
    if f"{model}_{revision}_{precision}" in REQUESTED_MODELS:
        return styled_warning("This model has already been submitted.")

    model_size = get_model_size(model_info=model_info, precision=precision)
    eval_entry = create_eval_entry(model, base_model, revision, precision, weight_type, model_type, model_info, model_size)
logger.info("Creating eval file")
OUT_DIR = os.path.join(EVAL_REQUESTS_PATH, user_name)
os.makedirs(OUT_DIR, exist_ok=True)
out_path = os.path.join(OUT_DIR, f"{model_path}_eval_request_False_{precision}_{weight_type}.json")
try:
with open(out_path, "w") as f:
json.dump(eval_entry, f)
except IOError as e:
logger.error(f"Failed to write eval file: {e}")
return styled_error(f"Failed to create eval file: {e}")
logger.info("Uploading eval file")
try:
# Get the relative path from EVAL_REQUESTS_PATH
rel_path = os.path.relpath(out_path, EVAL_REQUESTS_PATH)
API.upload_file(
path_or_fileobj=out_path,
path_in_repo=rel_path,
repo_id=QUEUE_REPO,
repo_type="dataset",
commit_message=f"Add {model} to eval queue",
)
except Exception as e:
logger.error(f"Failed to upload eval file: {e}")
return styled_error(f"Failed to upload eval file: {e}")
# Remove the local file
try:
os.remove(out_path)
except OSError as e:
logger.warning(f"Failed to remove local eval file: {e}")
    return styled_message(
        "Your request has been submitted to the evaluation queue!\n"
        "The model will be evaluated for:\n"
        "1. Safetensors compliance\n"
        "2. Security awareness using the stacklok/insecure-code dataset\n"
        f"Please allow up to {EVALUATION_WAIT_TIME} minutes for the model to appear in the PENDING list."
    )
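

# Minimal local smoke test, kept as a sketch: the repo id below is hypothetical,
# and a real call writes to the Hub queue, so only run this against a test
# configuration.
#
#   if __name__ == "__main__":
#       print(add_new_eval(
#           model="your-org/your-model",   # hypothetical repo id
#           base_model="",
#           revision="main",
#           precision="float16",
#           weight_type="Safetensors",
#           model_type="pretrained",
#       ))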