# Hugging Face Spaces page banner ("Spaces: Sleeping") captured by the page
# scrape — not part of the program; kept as a comment so the file parses.
import json
from typing import Dict, List

import uvicorn
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# --- Model loading ---
# Quantized (Q6_K GGUF) Vigostral 7B chat model served through ctransformers,
# loaded once at import time.
model = AutoModelForCausalLM.from_pretrained(
    "vigostral-7b-chat.Q6_K.gguf",
    model_type="llama",
    threads=3,  # CPU threads used for generation
)
# --- FastAPI application ---
app = FastAPI()

# Fully permissive CORS: any origin, any method, any header, with credentials —
# suitable for a demo endpoint called from arbitrary browser clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def apply_chat_template(conversation: List) -> str:
    """Render a chat history into a Llama-2-style prompt string.

    Each turn's role is compared case-insensitively: "system" opens a
    ``<s>[INST] <<SYS>> ...`` block, "assistant" closes the pending
    instruction with its reply and ``</s>``, and any other role is treated
    as a user turn.

    Bugs fixed vs. the original: it emitted a second "[INST]" for the user
    turn following a system message (producing "<</SYS>>[INST] ..."), and a
    stray "\\n[/INST]" before each assistant reply even though the preceding
    user turn had already closed with " [/INST]" — both malformed per the
    Llama-2 chat template.
    """
    formatted_conversation = ""
    # True while a system <<SYS>> block has opened an [INST] that the next
    # user turn must complete instead of opening a fresh one.
    open_inst = False
    for turn in conversation:
        role = turn.role.upper()
        content = turn.content
        if role == "SYSTEM":
            formatted_conversation += "<s>[INST] <<SYS>>\n" + content + "\n<</SYS>>\n\n"
            open_inst = True
        elif role == "ASSISTANT":
            # The preceding user turn already closed with " [/INST]".
            formatted_conversation += " " + content + " </s>"
        elif open_inst:
            # First user turn after a system block: finish the open [INST].
            formatted_conversation += content + " [/INST]"
            open_inst = False
        else:
            formatted_conversation += "[INST] " + content + " [/INST]"
    return formatted_conversation
#Pydantic object
class Message(BaseModel):
    """A single chat turn in the request payload."""

    # "system" / "assistant" / anything else (treated as a user turn);
    # compared case-insensitively by apply_chat_template.
    role: str
    # The turn's text content.
    content: str
class Validation(BaseModel):
    """OpenAI-style chat-completion request body.

    All fields are required (no defaults), so clients must send every one.
    """

    # Conversation history, oldest first.
    messages: List[Message]
    # Model name from the client; the handler ignores it and always uses the
    # locally loaded GGUF model.
    model: str
    temperature: float
    presence_penalty: float
    top_p: float
    frequency_penalty: float
    # Ignored by the handler: the response always streams.
    stream: bool
# NOTE(review): the original coroutine had no route decorator, so it was never
# registered with `app` and was unreachable over HTTP. The payload mirrors the
# OpenAI chat-completions API, so it is exposed at the OpenAI-compatible path —
# confirm against the intended client.
@app.post("/v1/chat/completions")
async def stream(item: Validation):
    """Stream the model's reply as OpenAI-style ``chat.completion.chunk`` events.

    The request's `model` and `stream` fields are accepted but ignored: the
    locally loaded GGUF model is always used and the response always streams.
    """
    prompt = apply_chat_template(item.messages)

    def stream_json():
        # NOTE(review): ctransformers' generate API may not accept
        # presence_penalty / frequency_penalty kwargs — verify against the
        # installed ctransformers version.
        for text in model(
            prompt,
            temperature=item.temperature,
            top_p=item.top_p,
            presence_penalty=item.presence_penalty,
            frequency_penalty=item.frequency_penalty,
            stream=True,
        ):
            # Bug fixed: chunks were previously concatenated back-to-back
            # ("}{"), which no client can parse. Frame each chunk as a
            # server-sent event, as the OpenAI streaming API does.
            yield "data: " + json.dumps({
                "object": "chat.completion.chunk",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "content": text
                        }
                    }
                ]
            }) + "\n\n"
        # Standard OpenAI end-of-stream sentinel.
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream_json(), media_type="text/event-stream")
if __name__ == "__main__":
    # Serve the app on all interfaces with uvicorn when run as a script.
    uvicorn.run(app, host="0.0.0.0", port=8000)