import zipfile
from abc import ABC
from pathlib import Path
from typing import Optional, Union
from awq import AutoAWQForCausalLM
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
from tempfile import NamedTemporaryFile, TemporaryDirectory
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse, FileResponse

### FastAPI Initialization
@asynccontextmanager
async def lifespan(app: FastAPI):
    # No startup/shutdown work needed yet; placeholder for future resource management.
    yield

app = FastAPI(
    title="Hugging Face Safetensors Model Converter to AWQ",
    version="0.1.0",
    lifespan=lifespan,
)
### -------

### DTO Definitions
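# ConvertRequest carries the fields shared by every conversion endpoint; the
# format-specific request classes below pair it with a quantization_config.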
class QuantizationConfig(ABC, BaseModel):
    pass

class ConvertRequest(ABC, BaseModel):
    hf_model_name: str
    hf_tokenizer_name: Optional[str] = Field(None, description="Hugging Face tokenizer name. Defaults to hf_model_name")
    hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
    hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model to. If not provided, the converted model is returned as a zip download.")
### -------

### Quantization Configurations
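# The defaults below mirror the quant_config dict commonly used with AutoAWQ:
# {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}.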
class AWQQuantizationConfig(QuantizationConfig):
    zero_point: Optional[bool] = Field(True, description="Use zero point quantization")
    q_group_size: Optional[int] = Field(128, description="Quantization group size")
    w_bit: Optional[int] = Field(4, description="Weight bit")
    version: Optional[str] = Field("GEMM", description="Quantization version")

class GPTQQuantizationConfig(QuantizationConfig):
    pass

class GGUFQuantizationConfig(QuantizationConfig):
    pass

class AWQConvertionRequest(ConvertRequest):
    quantization_config: Optional[AWQQuantizationConfig] = Field(
        default_factory=lambda: AWQQuantizationConfig(),
        description="AWQ quantization configuration"
    )

class GPTQConvertionRequest(ConvertRequest):
    quantization_config: Optional[GPTQQuantizationConfig] = Field(
        default_factory=lambda: GPTQQuantizationConfig(),
        description="GPTQ quantization configuration"
    )

class GGUFConvertionRequest(ConvertRequest):
    quantization_config: Optional[GGUFQuantizationConfig] = Field(
        default_factory=lambda: GGUFQuantizationConfig(),
        description="GGUF quantization configuration"
    )
### -------

@app.get("/", include_in_schema=False)
def redirect_to_docs():
    return RedirectResponse(url='/docs')

### FastAPI Endpoints
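# /convert_awq loads the model with AutoAWQ, quantizes it, and then either saves the
# result under hf_push_repo or streams it back to the caller as a zip archive.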
@app.post("/convert_awq", response_model=None)
def convert_awq(request: AWQConvertionRequest) -> Union[FileResponse, dict]:
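    # Note: quantization runs synchronously in the request handler, so large models
    # block the worker for the full duration of the conversion. The hf_token field is
    # accepted but not currently forwarded to from_pretrained.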
    model = AutoAWQForCausalLM.from_pretrained(request.hf_model_name)
    tokenizer = AutoTokenizer.from_pretrained(request.hf_tokenizer_name or request.hf_model_name, trust_remote_code=True)
    
    try:
        model.quantize(tokenizer, quant_config=request.quantization_config.model_dump())
    except TypeError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Is this model supported by AWQ quantization? Check: https://github.com/mit-han-lab/llm-awq?tab=readme-ov-file ({e})",
        )
    
    if request.hf_push_repo:
        # Note: save_quantized/save_pretrained write to a local directory named after
        # the repo; pushing to the Hugging Face Hub requires a separate upload step.
        model.save_quantized(request.hf_push_repo)
        tokenizer.save_pretrained(request.hf_push_repo)

        return {
            "status": "ok",
            "message": f"Model saved to {request.hf_push_repo}",
        }
    
    # Return a zip file with the converted model
    with TemporaryDirectory() as save_dir:
        # save_quantized / save_pretrained expect a directory path, not a ZipFile handle
        model.save_quantized(save_dir)
        tokenizer.save_pretrained(save_dir)

        # Reserve a named temp file for the archive, then pack the saved files into it
        with NamedTemporaryFile(suffix=".zip", delete=False) as temp_zip:
            zip_file_path = temp_zip.name
        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
            for file_path in Path(save_dir).rglob("*"):
                if file_path.is_file():
                    zipf.write(file_path, arcname=str(file_path.relative_to(save_dir)))

    return FileResponse(
        zip_file_path,
        media_type='application/zip',
        filename=f"{request.hf_model_name.replace('/', '_')}.zip"
    )

@app.post("/convert_gpt_q", response_model=None)
def convert_gpt_q(request: GPTQConvertionRequest) -> Union[FileResponse, dict]:
    raise HTTPException(status_code=501, detail="Not implemented yet")

@app.post("/convert_gguf", response_model=None)
def convert_gguf(request: GGUFConvertionRequest) -> Union[FileResponse, dict]:
    raise HTTPException(status_code=501, detail="Not implemented yet")

@app.get("/health")
def health_check():
    return {"status": "ok"}