|
from contextlib import asynccontextmanager
from typing import Optional

from awq import AutoAWQForCausalLM
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, RedirectResponse
from pydantic import BaseModel, Field
from transformers import AutoTokenizer


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup/shutdown hook; nothing to set up or tear down yet.
    yield


app = FastAPI(
    title="Huggingface Safetensor Model Converter to AWQ",
    version="0.1.0",
    lifespan=lifespan,
)
|
|
class QuantizationConfig(BaseModel):
    zero_point: Optional[bool] = Field(True, description="Use zero-point quantization")
    q_group_size: Optional[int] = Field(128, description="Quantization group size")
    w_bit: Optional[int] = Field(4, description="Weight bit width")
    version: Optional[str] = Field("GEMM", description="Quantization version")
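
# For reference, AutoAWQ consumes these settings as a plain dict, e.g.
# {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}.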
|
|
class ConvertRequest(BaseModel):
    hf_model_name: str
    hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
    hf_push_repo: Optional[str] = Field(
        None,
        description=(
            "Hugging Face repo to push the converted model to. "
            "If not provided, the model will be available for download only."
        ),
    )
    quantization_config: QuantizationConfig = Field(
        default_factory=QuantizationConfig,
        description="Quantization configuration",
    )
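
# Illustrative request body for POST /convert (the model name is just an example):
#
#   {
#     "hf_model_name": "facebook/opt-125m",
#     "quantization_config": {
#       "zero_point": true,
#       "q_group_size": 128,
#       "w_bit": 4,
#       "version": "GEMM"
#     }
#   }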
|
|
@app.get("/", include_in_schema=False) |
|
def redirect_to_docs(): |
|
return RedirectResponse(url='/docs') |
|
|
|
|
|
@app.get("/health") |
|
def read_root(): |
|
return {"status": "ok"} |
|
|
|
@app.post("/convert") |
|
def convert(request: ConvertRequest)->FileResponse: |
|
model = AutoAWQForCausalLM.from_pretrained(model_path) |
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) |
|
|
|
raise HTTPException(status_code=501, detail="Not Implemented yet") |
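

# A minimal sketch of the missing quantization step, assuming AutoAWQ's
# documented quantize()/save_quantized() API. The helper name and
# `output_dir` are placeholders; nothing calls this yet.
def _quantize_and_save(model, tokenizer, config: QuantizationConfig, output_dir: str) -> str:
    # AutoAWQ expects the settings as a plain dict (model_dump() is
    # pydantic v2; use .dict() on pydantic v1).
    quant_config = config.model_dump()
    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized(output_dir)
    tokenizer.save_pretrained(output_dir)
    return output_dir


# The endpoint could then archive `output_dir` (e.g. with shutil.make_archive)
# and return the archive via FileResponse, or push it to `hf_push_repo`.
# To run the API locally (module name assumed to be `main`):
#   uvicorn main:app --host 0.0.0.0 --port 8000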
|
|