# from https://huggingface.co/spaces/iiced/mixtral-46.7b-fastapi/blob/main/main.py
# example of use:
# curl -X POST \
#   -H "Content-Type: application/json" \
#   -d '{
#     "prompt": "What is the capital of France?",
#     "history": [],
#     "system_prompt": "You are a very powerful AI assistant."
#   }' \
#   https://phk0-bai.hf.space/generate/
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import uvicorn

app = FastAPI()
class Item(BaseModel):
    prompt: str
    history: list
    system_prompt: str
    temperature: float = 0.0
    max_new_tokens: int = 900
    top_p: float = 0.15
    repetition_penalty: float = 1.0
def format_prompt(system, message, history):
    # build a chat message list: system prompt first, then alternating
    # user/assistant turns from history, then the new user message.
    # note: `list += dict` extends the list with the dict's *keys*, so the
    # original `prompt += {...}` calls were a bug; append the dicts instead.
    prompt = [{"role": "system", "content": system}]
    for user_prompt, bot_response in history:
        prompt.append({"role": "user", "content": user_prompt})
        prompt.append({"role": "assistant", "content": bot_response})
    prompt.append({"role": "user", "content": message})
    return prompt
def generate(item: Item):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = "ibm-granite/granite-34b-code-instruct-8k"
    # note: loading the tokenizer and model on every request is slow;
    # consider hoisting these two calls to module level so they run once.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # drop device_map if running on CPU
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
    model.eval()
    # build the chat messages and render them with the model's chat template
    chat = format_prompt(item.system_prompt, item.prompt, item.history)
    chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # tokenize the text
    input_tokens = tokenizer(chat, return_tensors="pt")
    # transfer tokenized inputs to the device
    for i in input_tokens:
        input_tokens[i] = input_tokens[i].to(device)
    # generate output tokens, honoring the request's max_new_tokens
    # (temperature, top_p, and repetition_penalty are accepted by Item
    # but not wired into generation here)
    output = model.generate(**input_tokens, max_new_tokens=item.max_new_tokens)
    output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return output_text
@app.post("/generate/")
async def generate_text(item: Item):
    return {"response": generate(item)}
@app.post("/")
async def generate_text_root(item: Item):
    return {"response": "try entry point: /generate/"}