# vigostral-chat / main.py
import uvicorn
import json
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, Form
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List
from fastapi.middleware.cors import CORSMiddleware
# Model loading: fetch the quantized GGUF weights from the Hugging Face Hub
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Vigostral-7B-Chat-GGUF",
    model_file="vigostral-7b-chat.Q4_K_M.gguf",
    model_type="mistral",
    threads=3,  # CPU threads used for inference
)
# FastAPI application with permissive CORS so browser clients can call the API
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def apply_chat_template(conversation: List) -> str:
    """Render the messages into the Llama-2 style chat prompt Vigostral expects."""
    formatted_conversation = ""
    system = ""
    for turn in conversation:
        role = turn.role.upper()
        content = turn.content
        if role == "SYSTEM":
            # The system prompt is folded into the next user instruction,
            # wrapped in <<SYS>> tags, rather than emitted as its own turn.
            system = "<<SYS>>\n" + content + "\n<</SYS>>\n\n"
        elif role == "ASSISTANT":
            formatted_conversation += " " + content + " </s>"
        else:
            formatted_conversation += "<s>[INST] " + system + content + " [/INST]"
            system = ""
    return formatted_conversation
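# Illustrative example (sample strings are mine, not from the repo): for the turns
#   system: "Tu es un assistant."  user: "Bonjour !"  assistant: "Salut."  user: "Ça va ?"
# the function above renders (newlines shown as \n):
#   <s>[INST] <<SYS>>\nTu es un assistant.\n<</SYS>>\n\nBonjour ! [/INST] Salut. </s><s>[INST] Ça va ? [/INST]
# which is the Llama-2 chat convention the Vigostral model cards describe.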
# Pydantic request objects (OpenAI-style chat completion payload)
class Message(BaseModel):
    role: str
    content: str


class Validation(BaseModel):
    messages: List[Message]
    model: str
    temperature: float
    presence_penalty: float
    top_p: float
    frequency_penalty: float
    stream: bool
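# Example /chat payload (illustrative values; every field of Validation is
# required, though only temperature, top_p and frequency_penalty reach the model):
#   {
#     "messages": [{"role": "system", "content": "Tu es un assistant."},
#                  {"role": "user", "content": "Bonjour !"}],
#     "model": "vigostral-7b-chat",
#     "temperature": 0.8,
#     "presence_penalty": 0.0,
#     "top_p": 0.95,
#     "frequency_penalty": 1.1,
#     "stream": true
#   }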
@app.post("/chat")
async def stream(item: Validation):
    prompt = apply_chat_template(item.messages)

    def stream_json():
        for text in model(
            prompt,
            temperature=item.temperature,  # ctransformers default: 0.8
            top_p=item.top_p,  # ctransformers default: 0.95
            # ctransformers has no frequency_penalty parameter; the request
            # field is mapped onto repetition_penalty (default 1.1) as an
            # approximation.
            repetition_penalty=item.frequency_penalty,
            stream=True,
        ):
            # Emit newline-delimited JSON chunks shaped like OpenAI
            # "chat.completion.chunk" objects so clients can split on "\n".
            yield json.dumps({
                "object": "chat.completion.chunk",
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": text},
                    }
                ],
            }) + "\n"

    return StreamingResponse(stream_json(), media_type="application/x-ndjson")
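# Sketch of a streaming client (assumes the `requests` package and a server
# on localhost:8000; payload fields as in the example above):
#
#   import json, requests
#   payload = {"messages": [{"role": "user", "content": "Bonjour !"}],
#              "model": "vigostral-7b-chat", "temperature": 0.8,
#              "presence_penalty": 0.0, "top_p": 0.95,
#              "frequency_penalty": 1.1, "stream": True}
#   with requests.post("http://localhost:8000/chat", json=payload, stream=True) as r:
#       for line in r.iter_lines():
#           if line:
#               print(json.loads(line)["choices"][0]["delta"]["content"], end="")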
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)