import os
import zipfile
from abc import ABC
from typing import Optional, Union
from awq import AutoAWQForCausalLM
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
from tempfile import NamedTemporaryFile, TemporaryDirectory
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse, FileResponse
### FastAPI Initialization
@asynccontextmanager
async def lifespan(app: FastAPI):
    yield

app = FastAPI(title="Huggingface Safetensor Model Converter to AWQ", version="0.1.0", lifespan=lifespan)
### -------
### DTO Definitions
class QuantizationConfig(ABC, BaseModel):
    pass

class ConvertRequest(ABC, BaseModel):
    hf_model_name: str
    hf_tokenizer_name: Optional[str] = Field(None, description="Hugging Face tokenizer name. Defaults to hf_model_name")
    hf_token: Optional[str] = Field(None, description="Hugging Face token for private models (accepted, but not yet wired into the download calls)")
    hf_push_repo: Optional[str] = Field(None, description="Path (or repo-style name) to save the converted model to. If not provided, the converted model is returned as a zip download.")
### -------
### Quantization Configurations
class AWQQuantizationConfig(QuantizationConfig):
    zero_point: Optional[bool] = Field(True, description="Use zero-point quantization")
    q_group_size: Optional[int] = Field(128, description="Quantization group size")
    w_bit: Optional[int] = Field(4, description="Weight bit width")
    version: Optional[str] = Field("GEMM", description="AWQ kernel version (e.g. GEMM or GEMV)")

class GPTQQuantizationConfig(QuantizationConfig):
    pass

class GGUFQuantizationConfig(QuantizationConfig):
    pass
class AWQConversionRequest(ConvertRequest):
    quantization_config: Optional[AWQQuantizationConfig] = Field(
        default_factory=AWQQuantizationConfig,
        description="AWQ quantization configuration"
    )

class GPTQConversionRequest(ConvertRequest):
    quantization_config: Optional[GPTQQuantizationConfig] = Field(
        default_factory=GPTQQuantizationConfig,
        description="GPTQ quantization configuration"
    )

class GGUFConversionRequest(ConvertRequest):
    quantization_config: Optional[GGUFQuantizationConfig] = Field(
        default_factory=GGUFQuantizationConfig,
        description="GGUF quantization configuration"
    )
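# For reference, the default AWQQuantizationConfig above serializes (via
# model_dump()) to the dict shape that AutoAWQ's quantize() expects:
#   {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}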
### -------
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    return RedirectResponse(url='/docs')
### FastAPI Endpoints
@app.post("/convert_awq", response_model=None)
def convert(request: AWQConversionRequest) -> Union[FileResponse, dict]:
    model = AutoAWQForCausalLM.from_pretrained(request.hf_model_name)
    tokenizer = AutoTokenizer.from_pretrained(request.hf_tokenizer_name or request.hf_model_name, trust_remote_code=True)
    try:
        model.quantize(tokenizer, quant_config=request.quantization_config.model_dump())
    except TypeError as e:
        raise HTTPException(status_code=400, detail=f"Is this model supported by AWQ quantization? Check: https://github.com/mit-han-lab/llm-awq?tab=readme-ov-file ({e})")
    if request.hf_push_repo:
        model.save_quantized(request.hf_push_repo)
        tokenizer.save_pretrained(request.hf_push_repo)
        return {
            "status": "ok",
            "message": f"Model saved to {request.hf_push_repo}",
        }
    # Otherwise, return a zip archive with the converted model. save_quantized and
    # save_pretrained expect a directory path, so write to a temporary directory
    # first and then zip its contents.
    with TemporaryDirectory() as save_dir:
        model.save_quantized(save_dir)
        tokenizer.save_pretrained(save_dir)
        with NamedTemporaryFile(suffix=".zip", delete=False) as temp_zip:
            zip_file_path = temp_zip.name
        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
            for root, _, files in os.walk(save_dir):
                for name in files:
                    file_path = os.path.join(root, name)
                    zipf.write(file_path, arcname=os.path.relpath(file_path, save_dir))
    return FileResponse(
        zip_file_path,
        media_type='application/zip',
        # Model names contain "/" (org/model), which is not valid in a filename.
        filename=f"{request.hf_model_name.replace('/', '_')}.zip"
    )
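# A minimal usage sketch (assumptions: the service is reachable on localhost:8000,
# and the model name is illustrative; any small causal LM from the Hub would do):
#
#   curl -X POST http://localhost:8000/convert_awq \
#        -H "Content-Type: application/json" \
#        -d '{"hf_model_name": "facebook/opt-125m",
#             "quantization_config": {"zero_point": true, "q_group_size": 128,
#                                     "w_bit": 4, "version": "GEMM"}}'
#
# With hf_push_repo unset, the response body is the quantized model as a zip archive.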
@app.post("/convert_gpt_q", response_model=None)
def convert_gpt_q(request: GPTQConversionRequest) -> Union[FileResponse, dict]:
    raise HTTPException(status_code=501, detail="Not implemented yet")

@app.post("/convert_gguf", response_model=None)
def convert_gguf(request: GGUFConversionRequest) -> Union[FileResponse, dict]:
    raise HTTPException(status_code=501, detail="Not implemented yet")
@app.get("/health")
def health():
    return {"status": "ok"}
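# To serve the app locally (assuming this module is saved as app.py; the port is
# arbitrary, 7860 is just the usual Hugging Face Spaces convention):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860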