# Last commit 0735f93 by Rúben Almeida: "Add exception handling for incompatible models"
import zipfile
from abc import ABC
from contextlib import asynccontextmanager
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Optional, Union

from awq import AutoAWQForCausalLM
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse, RedirectResponse
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
### FastAPI Initialization
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook handed to FastAPI.

    Nothing is set up before the yield and nothing is torn down after it;
    the yield is simply where the app serves requests.
    """
    yield None


app = FastAPI(
    lifespan=lifespan,
    title="Huggingface Safetensor Model Converter to AWQ",
    version="0.1.0",
)
### -------
### DTO Definitions
class QuantizationConfig(ABC, BaseModel):
    # Shared base type for the per-backend quantization settings declared
    # further down; it contributes no fields of its own.
    ...


class ConvertRequest(ABC, BaseModel):
    # Request fields common to every conversion endpoint. Only
    # hf_model_name is required; the other three default to None.
    hf_model_name: str
    hf_tokenizer_name: Union[str, None] = Field(
        default=None,
        description="Hugging Face tokenizer name. Defaults to hf_model_name",
    )
    hf_token: Union[str, None] = Field(
        default=None,
        description="Hugging Face token for private models",
    )
    hf_push_repo: Union[str, None] = Field(
        default=None,
        description="Hugging Face repo to push the converted model. If not provided, the model will be downloaded only.",
    )
### -------
### Quantization Configurations
class AWQQuantizationConfig(QuantizationConfig):
    # AutoAWQ settings. The /convert_awq endpoint turns this model into a
    # dict with model_dump(), so the field names below become the keys of
    # the quant_config passed to quantize().
    zero_point: Union[bool, None] = Field(default=True, description="Use zero point quantization")
    q_group_size: Union[int, None] = Field(default=128, description="Quantization group size")
    w_bit: Union[int, None] = Field(default=4, description="Weight bit")
    version: Union[str, None] = Field(default="GEMM", description="Quantization version")


class GPTQQuantizationConfig(QuantizationConfig):
    # Placeholder: no GPTQ settings are defined yet.
    ...


class GGUFQuantizationConfig(QuantizationConfig):
    # Placeholder: no GGUF settings are defined yet.
    ...
class AWQConvertionRequest(ConvertRequest):
    # Request body carrying AWQ settings; when quantization_config is left
    # out, a fresh AWQQuantizationConfig with its declared defaults is used.
    quantization_config: Union[AWQQuantizationConfig, None] = Field(
        default_factory=AWQQuantizationConfig,
        description="AWQ quantization configuration",
    )


class GPTQConvertionRequest(ConvertRequest):
    # Request body carrying (currently empty) GPTQ settings.
    quantization_config: Union[GPTQQuantizationConfig, None] = Field(
        default_factory=GPTQQuantizationConfig,
        description="GPTQ quantization configuration",
    )


class GGUFConvertionRequest(ConvertRequest):
    # Request body carrying (currently empty) GGUF settings.
    quantization_config: Union[GGUFQuantizationConfig, None] = Field(
        default_factory=GGUFQuantizationConfig,
        description="GGUF quantization configuration",
    )
### -------
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    # Hidden from the schema: sends visitors of the bare root URL to the
    # interactive API documentation page served at /docs.
    docs_url = '/docs'
    return RedirectResponse(url=docs_url)
### FastAPI Endpoints
def _zip_directory(source_dir: Path, zip_path: str) -> None:
    """Write every file below ``source_dir`` into a new zip archive at ``zip_path``.

    Member names are stored relative to ``source_dir`` so the archive unpacks
    into a plain model folder instead of the temporary staging path.
    """
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for file_path in sorted(source_dir.rglob("*")):
            if file_path.is_file():
                zipf.write(file_path, arcname=file_path.relative_to(source_dir).as_posix())


def _archive_filename(model_name: str) -> str:
    """Return a download filename for ``model_name`` (``org/model`` -> ``org_model.zip``)."""
    # Hub ids contain "/", which cannot appear inside a single filename.
    return model_name.replace("/", "_") + ".zip"


@app.post("/convert_awq", response_model=None)
def convert(request: AWQConvertionRequest)->Union[FileResponse, dict]:
    """Quantize a Hugging Face causal LM with AutoAWQ.

    If ``hf_push_repo`` is set, the quantized model and tokenizer are saved
    under that path and a JSON status is returned. Otherwise both are packed
    into a zip archive that is streamed back and deleted once sent.

    Raises:
        HTTPException(400): AutoAWQ raised ``TypeError`` while loading or
            quantizing the model (treated as an unsupported model).
        HTTPException(500): saving or archiving the quantized model failed.
    """
    # Forward the token only when one was supplied, so requests for public
    # models call the loaders with exactly the same arguments as before.
    # NOTE(review): AutoAWQ is expected to pass extra kwargs on to the
    # transformers loaders; confirm it also reaches its own Hub download step.
    auth_kwargs = {"token": request.hf_token} if request.hf_token else {}

    # NOTE(review): trust_remote_code=True executes code shipped with the
    # client-chosen repo; consider an allow-list before exposing this publicly.
    tokenizer = AutoTokenizer.from_pretrained(
        request.hf_tokenizer_name or request.hf_model_name,
        trust_remote_code=True,
        **auth_kwargs,
    )

    # quantization_config and each of its fields are Optional, so a client
    # can send explicit nulls; fall back to the declared defaults for those
    # instead of crashing on None or passing None through to quantize().
    overrides = (
        request.quantization_config.model_dump(exclude_none=True)
        if request.quantization_config is not None
        else {}
    )
    quant_kwargs = {**AWQQuantizationConfig().model_dump(), **overrides}

    try:
        # Loading is inside the try as well: AutoAWQ is believed to reject
        # unsupported architectures with a TypeError already at load time,
        # before quantize() runs (TODO confirm for the pinned AutoAWQ version).
        model = AutoAWQForCausalLM.from_pretrained(request.hf_model_name, **auth_kwargs)
        model.quantize(tokenizer, quant_config=quant_kwargs)
    except TypeError as e:
        raise HTTPException(status_code=400, detail=f"Is this model supported by AWQ Quantization? Check:https://github.com/mit-han-lab/llm-awq?tab=readme-ov-file {e}") from e

    if request.hf_push_repo:
        # NOTE(review): both calls receive a path, so the files land in a local
        # directory named after hf_push_repo; nothing here uploads to the Hub.
        model.save_quantized(request.hf_push_repo)
        tokenizer.save_pretrained(request.hf_push_repo)
        return {
            "status": "ok",
            "message": f"Model saved to {request.hf_push_repo}",
        }

    # save_quantized()/save_pretrained() take a directory path, not an open
    # ZipFile, so stage the files in a temporary directory and zip that.
    # The zip's handle is closed before it is rewritten (Windows-safe).
    with NamedTemporaryFile(suffix=".zip", delete=False) as temp_zip:
        zip_file_path = temp_zip.name
    try:
        with TemporaryDirectory() as export_dir:
            model.save_quantized(export_dir)
            tokenizer.save_pretrained(export_dir)
            _zip_directory(Path(export_dir), zip_file_path)
    except Exception as e:
        # Request boundary: don't leave a half-written archive on disk, and
        # report the failure with the original cause chained.
        Path(zip_file_path).unlink(missing_ok=True)
        raise HTTPException(status_code=500, detail="Failed to convert model") from e

    # Remove the archive after the response body has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(Path(zip_file_path).unlink, missing_ok=True)
    return FileResponse(
        zip_file_path,
        media_type='application/zip',
        filename=_archive_filename(request.hf_model_name),
        background=cleanup,
    )
@app.post("/convert_gpt_q", response_model=None)
def convert_gpt_q(request: GPTQConvertionRequest)->Union[FileResponse, dict]:
    """Convert a Hugging Face model to GPTQ — not implemented yet.

    The body is declared as GPTQConvertionRequest (a ConvertRequest subclass,
    so existing payloads still validate), mirroring /convert_awq.

    Raises:
        HTTPException(501): always, until GPTQ conversion is implemented.
    """
    raise HTTPException(status_code=501, detail="Not implemented yet")
@app.post("/convert_gguf", response_model=None)
def convert_gguf(request: GGUFConvertionRequest)->Union[FileResponse, dict]:
    """Convert a Hugging Face model to GGUF — not implemented yet.

    The body is declared as GGUFConvertionRequest (a ConvertRequest subclass,
    so existing payloads still validate), mirroring /convert_awq.

    Raises:
        HTTPException(501): always, until GGUF conversion is implemented.
    """
    raise HTTPException(status_code=501, detail="Not implemented yet")
@app.get("/health")
def read_root():
    # Health check: always answers with the same constant payload.
    health = {"status": "ok"}
    return health