Spaces:
Sleeping
Sleeping
File size: 2,235 Bytes
471608b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import uvicorn
import json
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, Form
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List, Dict
from fastapi.middleware.cors import CORSMiddleware
# Model loading: runs once at import time; the resulting `model` object is the
# module-level handle used to generate text.
model = AutoModelForCausalLM.from_pretrained(
    "vigostral-7b-chat.Q6_K.gguf",
    model_type="llama",
    threads=3,
)
# FastAPI application with CORS opened to every origin, method and header.
# NOTE(review): combining a wildcard origin with allow_credentials=True is
# discouraged by the FastAPI CORS documentation — confirm whether credentialed
# cross-origin requests are actually required before tightening this.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def apply_chat_template(conversation: List) -> str:
    """Render a chat history as a Llama-2 style ``[INST]`` prompt string.

    Roles are matched case-insensitively. ``system`` and ``assistant`` turns
    get dedicated handling; every other role is treated as a user turn.

    A system turn opens an instruction block that the next user turn
    completes, so a full exchange is laid out as::

        <s>[INST] <<SYS>>
        {system}
        <</SYS>>

        {user} [/INST] {assistant} </s>[INST] {user} [/INST]

    Args:
        conversation: Iterable of turns, each exposing ``role`` and
            ``content`` string attributes.

    Returns:
        The concatenated prompt, or ``""`` for an empty conversation.
    """
    pieces = []
    # True while a system turn has opened "[INST]" that nothing has closed yet.
    instruction_open = False
    for turn in conversation:
        role = turn.role.upper()
        content = turn.content
        if role == "SYSTEM":
            pieces.append("<s>[INST] <<SYS>>\n" + content + "\n<</SYS>>\n\n")
            instruction_open = True
        elif role == "ASSISTANT":
            if instruction_open:
                # No user turn followed the system prompt; close the block so
                # the [INST] / [/INST] tags stay balanced.
                pieces.append("[/INST]")
                instruction_open = False
            # The preceding user turn already emitted "[/INST]"; repeating it
            # here would produce a duplicated closing tag.
            pieces.append(" " + content + " </s>")
        elif instruction_open:
            # Continue the block opened by the system turn instead of nesting
            # a second "[INST]" inside it.
            pieces.append(content + " [/INST]")
            instruction_open = False
        else:
            pieces.append("[INST] " + content + " [/INST]")
    return "".join(pieces)
# Pydantic request schemas
class Message(BaseModel):
    """One chat turn: the speaker's ``role`` and the text ``content``."""

    role: str
    content: str
class Validation(BaseModel):
    """Request body accepted by the ``/chat`` endpoint.

    No field declares a default, so every one of them must be present in the
    incoming JSON.
    """

    # Conversation turns, rendered into the prompt in list order.
    messages: List[Message]
    # Model name sent by the client.
    model: str
    # Generation settings sent by the client.
    temperature: float
    presence_penalty: float
    top_p: float
    frequency_penalty: float
    # Streaming flag sent by the client.
    stream: bool
@app.post("/chat")
async def stream(item: Validation):
    """Generate a completion for the posted conversation and stream it back.

    The conversation is rendered to a single prompt with
    ``apply_chat_template`` and passed to the module-level ``model``. Each
    generated text fragment is wrapped in a ``chat.completion.chunk`` JSON
    object and written to the response as soon as it is produced.

    Args:
        item: Validated request body; ``messages``, ``temperature`` and
            ``top_p`` are used for generation.

    Returns:
        A ``StreamingResponse`` whose body is the sequence of JSON chunks.
    """
    prompt = apply_chat_template(item.messages)

    def stream_json():
        # ctransformers' LLM.__call__ is documented with temperature, top_p,
        # top_k, repetition_penalty, ... but no presence_penalty or
        # frequency_penalty keyword; forwarding them raises TypeError on the
        # first iteration, after the response headers have been sent. The
        # request still accepts both fields so existing clients keep working.
        for text in model(
            prompt,
            temperature=item.temperature,
            top_p=item.top_p,
            stream=True,
        ):
            # NOTE(review): chunks are emitted back-to-back with no delimiter,
            # so the body as a whole is not a single JSON document — confirm
            # the client splits them before changing the wire format.
            yield json.dumps({
                "object": "chat.completion.chunk",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "content": text
                        }
                    }
                ]
            })

    return StreamingResponse(stream_json(), media_type="application/json")
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
|