"""FastAPI service exposing model-quantization endpoints.

Only the AWQ conversion endpoint is implemented; the GPTQ and GGUF endpoints
respond with HTTP 501.
"""

import os
import re
import zipfile
from contextlib import asynccontextmanager
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Union

from awq import AutoAWQForCausalLM
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, RedirectResponse
from starlette.background import BackgroundTask
from transformers import AutoTokenizer

from .dto import AWQConvertionRequest, GGUFConvertionRequest, GPTQConvertionRequest


### FastAPI Initialization
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook; performs no startup or shutdown work."""
    yield


app = FastAPI(
    title="Huggingface Safetensor Model Converter to AWQ",
    version="0.1.0",
    lifespan=lifespan,
)
### -------


@app.get("/", include_in_schema=False)
def redirect_to_docs():
    """Redirect the root URL to the interactive API docs."""
    return RedirectResponse(url='/docs')


def _download_name(model_name: str) -> str:
    """Return a filesystem-safe ``.zip`` file name derived from *model_name*.

    Model identifiers commonly look like ``org/model``; path separators and
    other unusual characters are collapsed to ``_`` so the name is valid in a
    ``Content-Disposition`` header and on the client's disk.
    """
    stem = re.sub(r"[^A-Za-z0-9._-]+", "_", model_name).strip("_")
    return f"{stem or 'model'}.zip"


def _zip_directory(source_dir: str, zip_path: str) -> None:
    """Write every file under *source_dir* into a new zip archive at *zip_path*.

    Entries are stored with paths relative to *source_dir*.
    """
    root = Path(source_dir)
    with zipfile.ZipFile(zip_path, 'w') as archive:
        for file_path in sorted(root.rglob("*")):
            if file_path.is_file():
                archive.write(
                    str(file_path),
                    arcname=file_path.relative_to(root).as_posix(),
                )


### FastAPI Endpoints
@app.post("/convert_awq", response_model=None)
def convert(request: AWQConvertionRequest) -> Union[FileResponse, dict]:
    """Quantize a model with AWQ and either save it or return it as a zip.

    Args:
        request: Conversion parameters. The fields read here are
            ``hf_model_name``, ``hf_tokenizer_name`` (falls back to the model
            name when falsy), ``quantization_config`` and ``hf_push_repo``.

    Returns:
        A status dict when ``request.hf_push_repo`` is set; otherwise a
        ``FileResponse`` streaming a zip archive of the quantized model and
        tokenizer files.

    Raises:
        HTTPException: 400 when loading the model raises ``TypeError``.
    """
    try:
        model = AutoAWQForCausalLM.from_pretrained(request.hf_model_name)
    except TypeError as e:
        raise HTTPException(
            status_code=400,
            detail=(
                "Is this model supported by AWQ Quantization? \n"
                "Check:https://github.com/mit-han-lab/llm-awq?tab=readme-ov-file "
                f"{e}"
            ),
        ) from e

    # SECURITY NOTE(review): trust_remote_code=True lets the referenced
    # repository run arbitrary Python on this server, and the repository name
    # comes straight from the request. Confirm this is acceptable for the
    # deployment (e.g. restrict to an allow-list) before exposing publicly.
    tokenizer = AutoTokenizer.from_pretrained(
        request.hf_tokenizer_name or request.hf_model_name,
        trust_remote_code=True,
    )
    model.quantize(tokenizer, quant_config=request.quantization_config.model_dump())

    if request.hf_push_repo:
        # NOTE(review): both calls receive hf_push_repo as a save location;
        # whether that uploads to the Hub or only writes locally is not
        # visible from here - confirm against the AutoAWQ/transformers docs.
        model.save_quantized(request.hf_push_repo)
        tokenizer.save_pretrained(request.hf_push_repo)
        return {
            "status": "ok",
            "message": f"Model saved to {request.hf_push_repo}",
        }

    # Return a zip file with the converted model.
    # delete=False because the archive must outlive this function: it is
    # streamed by FileResponse after we return and removed afterwards.
    with NamedTemporaryFile(suffix=".zip", delete=False) as temp_zip:
        zip_file_path = temp_zip.name

    try:
        # The save APIs write a directory of files, so export to a scratch
        # directory first and then pack that directory into the archive.
        with TemporaryDirectory() as export_dir:
            model.save_quantized(export_dir)
            tokenizer.save_pretrained(export_dir)
            _zip_directory(export_dir, zip_file_path)
    except Exception:
        # Do not leave an orphaned archive behind when the export fails.
        os.remove(zip_file_path)
        raise

    return FileResponse(
        zip_file_path,
        media_type='application/zip',
        filename=_download_name(request.hf_model_name),
        # Delete the temporary archive once the response has been sent.
        background=BackgroundTask(os.remove, zip_file_path),
    )


@app.post("/convert_gpt_q", response_model=None)
def convert_gpt_q(request: GPTQConvertionRequest) -> Union[FileResponse, dict]:
    """Placeholder for GPTQ conversion; always responds with HTTP 501."""
    raise HTTPException(status_code=501, detail="Not implemented yet")


@app.post("/convert_gguf", response_model=None)
def convert_gguf(request: GGUFConvertionRequest) -> Union[FileResponse, dict]:
    """Placeholder for GGUF conversion; always responds with HTTP 501."""
    raise HTTPException(status_code=501, detail="Not implemented yet")


@app.get("/health")
def read_root():
    """Liveness probe; returns a static OK payload."""
    return {"status": "ok"}