"""HTTP service that quantizes Hugging Face causal-LM checkpoints with AWQ.

Endpoints:
    GET  /              -> redirects to the interactive docs
    POST /convert_awq   -> quantize a model with AutoAWQ
    POST /convert_gpt_q -> not implemented (501)
    POST /convert_gguf  -> not implemented (501)
    GET  /health        -> liveness probe
"""

import os
import zipfile
from abc import ABC
from contextlib import asynccontextmanager
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Optional, Union

from awq import AutoAWQForCausalLM
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse, RedirectResponse
from pydantic import BaseModel, Field
from transformers import AutoTokenizer


### FastAPI Initialization
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook; currently performs no startup/shutdown work."""
    yield


app = FastAPI(
    title="Huggingface Safetensor Model Converter to AWQ",
    version="0.1.0",
    lifespan=lifespan,
)
### -------


### DTO Definitions
class QuantizationConfig(ABC, BaseModel):
    """Common base for the per-backend quantization settings."""


class ConvertRequest(ABC, BaseModel):
    """Fields shared by every conversion request body."""

    hf_model_name: str
    hf_tokenizer_name: Optional[str] = Field(
        None,
        description="Hugging Face tokenizer name. Defaults to hf_model_name",
    )
    # NOTE(review): no endpoint in this file forwards hf_token to the model or
    # tokenizer loaders, so private repositories will not load with it yet.
    hf_token: Optional[str] = Field(
        None,
        description="Hugging Face token for private models",
    )
    hf_push_repo: Optional[str] = Field(
        None,
        description=(
            "Hugging Face repo to push the converted model. "
            "If not provided, the model will be downloaded only."
        ),
    )
### -------


### Quantization Configurations
class AWQQuantizationConfig(QuantizationConfig):
    """AWQ settings; dumped to a dict and passed as ``quant_config`` to quantize()."""

    zero_point: Optional[bool] = Field(True, description="Use zero point quantization")
    q_group_size: Optional[int] = Field(128, description="Quantization group size")
    w_bit: Optional[int] = Field(4, description="Weight bit")
    version: Optional[str] = Field("GEMM", description="Quantization version")


class GPTQQuantizationConfig(QuantizationConfig):
    """Placeholder for GPTQ settings (no fields defined yet)."""


class GGUFQuantizationConfig(QuantizationConfig):
    """Placeholder for GGUF settings (no fields defined yet)."""


class AWQConvertionRequest(ConvertRequest):
    """Request body for ``/convert_awq``."""

    quantization_config: Optional[AWQQuantizationConfig] = Field(
        default_factory=lambda: AWQQuantizationConfig(),
        description="AWQ quantization configuration",
    )


class GPTQConvertionRequest(ConvertRequest):
    """Request body intended for GPTQ conversion (not used by any endpoint yet)."""

    quantization_config: Optional[GPTQQuantizationConfig] = Field(
        default_factory=lambda: GPTQQuantizationConfig(),
        description="GPTQ quantization configuration",
    )


class GGUFConvertionRequest(ConvertRequest):
    """Request body intended for GGUF conversion (not used by any endpoint yet)."""

    quantization_config: Optional[GGUFQuantizationConfig] = Field(
        default_factory=lambda: GGUFQuantizationConfig(),
        description="GGUF quantization configuration",
    )
### -------


@app.get("/", include_in_schema=False)
def redirect_to_docs():
    """Send visitors of the root URL to the interactive API docs."""
    return RedirectResponse(url='/docs')


def _zip_directory(source_dir: str, zip_path: str) -> None:
    """Write every file under *source_dir* into a new zip archive at *zip_path*."""
    root = Path(source_dir)
    with zipfile.ZipFile(zip_path, 'w') as archive:
        for path in sorted(root.rglob("*")):
            if path.is_file():
                # Store paths relative to the export root so the archive does
                # not leak the server's temporary directory layout.
                archive.write(path, arcname=path.relative_to(root).as_posix())


### FastAPI Endpoints
@app.post("/convert_awq", response_model=None)
def convert(request: AWQConvertionRequest) -> Union[FileResponse, dict]:
    """Quantize a Hugging Face model with AWQ.

    When ``hf_push_repo`` is set, the quantized model and tokenizer are saved
    to that location and a JSON status is returned. Otherwise they are packed
    into a zip archive that is returned as a file download.

    Raises:
        HTTPException: 400 when ``quantize`` raises ``TypeError``.
    """
    model = AutoAWQForCausalLM.from_pretrained(request.hf_model_name)
    # SECURITY NOTE: trust_remote_code=True lets the referenced repository run
    # its own Python code on this host, and the repository name is chosen by
    # the client. Only expose this service to trusted callers.
    tokenizer = AutoTokenizer.from_pretrained(
        request.hf_tokenizer_name or request.hf_model_name,
        trust_remote_code=True,
    )

    # The field is Optional, so an explicit JSON null reaches us as None;
    # fall back to the defaults instead of failing on None.model_dump().
    quantization_config = request.quantization_config
    if quantization_config is None:
        quantization_config = AWQQuantizationConfig()

    try:
        model.quantize(tokenizer, quant_config=quantization_config.model_dump())
    except TypeError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Is this model supported by AWQ Quantization? Check:https://github.com/mit-han-lab/llm-awq?tab=readme-ov-file {e}",
        ) from e

    if request.hf_push_repo:
        # NOTE(review): both calls receive the repo string as the save
        # destination; no explicit Hub upload is performed here. Confirm
        # whether an actual push is intended.
        model.save_quantized(request.hf_push_repo)
        tokenizer.save_pretrained(request.hf_push_repo)
        return {
            "status": "ok",
            "message": f"Model saved to {request.hf_push_repo}",
        }

    # Return a zip file with the converted model.
    # Reserve the archive path first; delete=False because the response is
    # streamed from this file after the function has returned.
    with NamedTemporaryFile(suffix=".zip", delete=False) as temp_zip:
        zip_file_path = temp_zip.name

    try:
        # save_quantized / save_pretrained write into a directory, so export
        # to a scratch directory and archive its contents afterwards.
        with TemporaryDirectory() as export_dir:
            model.save_quantized(export_dir)
            tokenizer.save_pretrained(export_dir)
            _zip_directory(export_dir, zip_file_path)
    except Exception:
        # Do not leave an orphaned archive behind when the export fails.
        os.remove(zip_file_path)
        raise

    # Remove the archive once the response has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, zip_file_path)
    return FileResponse(
        zip_file_path,
        media_type='application/zip',
        filename=f"{request.hf_model_name}.zip",
        background=cleanup,
    )


@app.post("/convert_gpt_q", response_model=None)
def convert_gpt_q(request: ConvertRequest) -> Union[FileResponse, dict]:
    """GPTQ conversion endpoint; always answers 501 for now."""
    raise HTTPException(status_code=501, detail="Not implemented yet")


@app.post("/convert_gguf", response_model=None)
def convert_gguf(request: ConvertRequest) -> Union[FileResponse, dict]:
    """GGUF conversion endpoint; always answers 501 for now."""
    raise HTTPException(status_code=501, detail="Not implemented yet")


@app.get("/health")
def read_root():
    """Liveness probe."""
    return {"status": "ok"}