Rúben Almeida
committed on
Commit
·
5e9b3af
1
Parent(s):
288d5ce
Add Default redirect to swagger ui
Browse files
- Dockerfile +1 -1
- main.py +28 -7
Dockerfile
CHANGED
@@ -24,4 +24,4 @@ COPY . .
|
|
24 |
|
25 |
EXPOSE 7860
|
26 |
|
27 |
-
ENTRYPOINT [
|
|
|
24 |
|
25 |
EXPOSE 7860
|
26 |
|
27 |
+
ENTRYPOINT ["fastapi", "run", "main.py", "--host=0.0.0.0", "--port=7860"]
|
main.py
CHANGED
@@ -1,26 +1,47 @@
|
|
1 |
from typing import Optional
|
|
|
2 |
from pydantic import BaseModel, Field
|
3 |
-
from
|
4 |
from contextlib import asynccontextmanager
|
5 |
from starlette.responses import FileResponse
|
|
|
6 |
|
7 |
-
|
8 |
-
hf_model_name: str
|
9 |
-
hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
|
10 |
-
hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model")
|
11 |
-
|
12 |
@asynccontextmanager
|
13 |
async def lifespan(app:FastAPI):
|
14 |
yield
|
15 |
|
16 |
app = FastAPI(title="Huggingface Safetensor Model Converter to AWQ", version="0.1.0", lifespan=lifespan)
|
|
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
@app.get("/health")
|
19 |
def read_root():
|
20 |
return {"status": "ok"}
|
21 |
|
22 |
@app.post("/convert")
|
23 |
def convert(request: ConvertRequest)->FileResponse:
|
24 |
-
|
|
|
25 |
|
|
|
26 |
#return FileResponse(file_location, media_type='application/octet-stream',filename=file_name)
|
|
|
1 |
from contextlib import asynccontextmanager
from typing import Optional

from awq import AutoAWQForCausalLM
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, Field
from starlette.responses import FileResponse
from transformers import AutoTokenizer
|
8 |
|
9 |
+
### FastAPI Initialization
|
|
|
|
|
|
|
|
|
10 |
@asynccontextmanager
|
11 |
async def lifespan(app:FastAPI):
|
12 |
yield
|
13 |
|
14 |
app = FastAPI(title="Huggingface Safetensor Model Converter to AWQ", version="0.1.0", lifespan=lifespan)
|
15 |
+
### -------
|
16 |
|
17 |
+
### DTO Definitions
|
18 |
+
class QuantizationConfig(BaseModel):
|
19 |
+
zero_point: Optional[bool] = Field(True, description="Use zero point quantization")
|
20 |
+
q_group_size: Optional[int] = Field(128, description="Quantization group size")
|
21 |
+
w_bit: Optional[int] = Field(4, description="Weight bit")
|
22 |
+
version: Optional[str] = Field("GEMM", description="Quantization version")
|
23 |
+
|
24 |
+
class ConvertRequest(BaseModel):
|
25 |
+
hf_model_name: str
|
26 |
+
hf_token: Optional[str] = Field(None, description="Hugging Face token for private models")
|
27 |
+
hf_push_repo: Optional[str] = Field(None, description="Hugging Face repo to push the converted model. If not provided, the model will be downloaded only.")
|
28 |
+
quantization_config: QuantizationConfig = Field(QuantizationConfig(), description="Quantization configuration")
|
29 |
+
### -------
|
30 |
+
|
31 |
+
|
32 |
+
@app.get("/", include_in_schema=False)
|
33 |
+
def redirect_to_docs():
|
34 |
+
return RedirectResponse(url='/docs')
|
35 |
+
|
36 |
+
### FastAPI Endpoints
|
37 |
@app.get("/health")
|
38 |
def read_root():
|
39 |
return {"status": "ok"}
|
40 |
|
41 |
@app.post("/convert")
|
42 |
def convert(request: ConvertRequest)->FileResponse:
|
43 |
+
model = AutoAWQForCausalLM.from_pretrained(model_path)
|
44 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
45 |
|
46 |
+
raise HTTPException(status_code=501, detail="Not Implemented yet")
|
47 |
#return FileResponse(file_location, media_type='application/octet-stream',filename=file_name)
|