import os
import zipfile
from abc import ABC
from typing import Optional, Union
from awq import AutoAWQForCausalLM
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from tempfile import NamedTemporaryFile, TemporaryDirectory
from fastapi.responses import RedirectResponse, FileResponse
### FastAPI Initialization
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup/shutdown hook; nothing to initialize or clean up yet.
    yield

app = FastAPI(title="Hugging Face Safetensors Model Converter to AWQ", version="0.1.0", lifespan=lifespan)
### -------
### DTO Definitions
class QuantizationConfig(ABC, BaseModel):
    pass

class ConvertRequest(ABC, BaseModel):
    hf_model_name: str
    hf_tokenizer_name: Optional[str] = Field(None, description="Hugging Face tokenizer name. Defaults to hf_model_name")
    hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
    hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model to. If not provided, the model is returned as a zip download instead.")
### -------
### Quantization Configurations
class AWQQuantizationConfig(QuantizationConfig):
    zero_point: Optional[bool] = Field(True, description="Use zero-point quantization")
    q_group_size: Optional[int] = Field(128, description="Quantization group size")
    w_bit: Optional[int] = Field(4, description="Weight bit width")
    version: Optional[str] = Field("GEMM", description="AWQ kernel version")
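
# With the defaults above, AWQQuantizationConfig().model_dump() produces the
# dict handed to model.quantize() in the convert endpoint below:
#   {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}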

class GPTQQuantizationConfig(QuantizationConfig):
    pass

class GGUFQuantizationConfig(QuantizationConfig):
    pass

class AWQConversionRequest(ConvertRequest):
    quantization_config: Optional[AWQQuantizationConfig] = Field(
        default_factory=lambda: AWQQuantizationConfig(),
        description="AWQ quantization configuration"
    )

class GPTQConversionRequest(ConvertRequest):
    quantization_config: Optional[GPTQQuantizationConfig] = Field(
        default_factory=lambda: GPTQQuantizationConfig(),
        description="GPTQ quantization configuration"
    )

class GGUFConversionRequest(ConvertRequest):
    quantization_config: Optional[GGUFQuantizationConfig] = Field(
        default_factory=lambda: GGUFQuantizationConfig(),
        description="GGUF quantization configuration"
    )
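
# Example JSON body for the AWQ conversion endpoint (model name is illustrative):
#   {
#     "hf_model_name": "facebook/opt-125m",
#     "quantization_config": {"zero_point": true, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
#   }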
### -------
### FastAPI Endpoints
@app.get("/", include_in_schema=False)  # route path is an assumption
def redirect_to_docs():
    return RedirectResponse(url='/docs')

@app.post("/convert/awq")  # route path is an assumption
def convert(request: AWQConversionRequest) -> Union[FileResponse, dict]:
    model = AutoAWQForCausalLM.from_pretrained(request.hf_model_name)
    tokenizer = AutoTokenizer.from_pretrained(request.hf_tokenizer_name or request.hf_model_name, trust_remote_code=True)
    try:
        model.quantize(tokenizer, quant_config=request.quantization_config.model_dump())
    except TypeError as e:
        raise HTTPException(status_code=400, detail=f"Is this model supported by AWQ quantization? Check: https://github.com/mit-han-lab/llm-awq?tab=readme-ov-file ({e})")
    if request.hf_push_repo:
        model.save_quantized(request.hf_push_repo)
        tokenizer.save_pretrained(request.hf_push_repo)
        return {
            "status": "ok",
            "message": f"Model saved to {request.hf_push_repo}",
        }
    # Return a zip archive with the converted model.
    # save_quantized/save_pretrained expect a directory path, not a ZipFile,
    # so save to a temporary directory first and then zip its contents.
    with TemporaryDirectory() as save_dir:
        model.save_quantized(save_dir)
        tokenizer.save_pretrained(save_dir)
        with NamedTemporaryFile(suffix=".zip", delete=False) as temp_zip:
            zip_file_path = temp_zip.name
        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
            for root, _dirs, files in os.walk(save_dir):
                for name in files:
                    file_path = os.path.join(root, name)
                    zipf.write(file_path, arcname=os.path.relpath(file_path, save_dir))
    return FileResponse(
        zip_file_path,
        media_type='application/zip',
        filename=f"{request.hf_model_name.replace('/', '_')}.zip"
    )
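
# A minimal usage sketch (endpoint path and model name are illustrative):
#   curl -X POST http://localhost:8000/convert/awq \
#     -H "Content-Type: application/json" \
#     -d '{"hf_model_name": "facebook/opt-125m"}' \
#     --output model.zip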

@app.post("/convert/gptq")  # route path is an assumption
def convert_gpt_q(request: GPTQConversionRequest) -> Union[FileResponse, dict]:
    raise HTTPException(status_code=501, detail="Not implemented yet")

@app.post("/convert/gguf")  # route path is an assumption
def convert_gguf(request: GGUFConversionRequest) -> Union[FileResponse, dict]:
    raise HTTPException(status_code=501, detail="Not implemented yet")

@app.get("/health")  # route path is an assumption ("/" is taken by the docs redirect)
def read_root():
    return {"status": "ok"}
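
# A minimal sketch for running the app locally (assumes uvicorn is installed;
# 7860 is the conventional Hugging Face Spaces port):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)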