Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, WebSocket | |
from fastapi.responses import HTMLResponse | |
from fastapi import Form, Depends, HTTPException, status | |
from transformers import pipeline, set_seed, AutoConfig, AutoTokenizer, AutoModelForCausalLM | |
import torch | |
import os | |
import time | |
import re | |
import json | |
app = FastAPI()

# Minimal WebSocket chat test page, returned by get().
# The socket URL is built from the page's own location (ws:// for http,
# wss:// for https) instead of a hard-coded host, so the same page works
# when served locally and when served from the deployed Space.
html = """
<!DOCTYPE html>
<html>
    <head>
        <title>Chat</title>
    </head>
    <body>
        <h1>WebSocket Chat</h1>
        <form action="" onsubmit="sendMessage(event)">
            <input type="text" id="messageText" autocomplete="off"/>
            <button>Send</button>
        </form>
        <ul id='messages'>
        </ul>
        <script>
            var wsScheme = window.location.protocol === "https:" ? "wss://" : "ws://";
            var ws = new WebSocket(wsScheme + window.location.host + "/api/ws");
            ws.onmessage = function(event) {
                var messages = document.getElementById('messages')
                var message = document.createElement('li')
                var content = document.createTextNode(event.data)
                message.appendChild(content)
                messages.appendChild(message)
            };
            function sendMessage(event) {
                var input = document.getElementById("messageText")
                ws.send(input.value)
                input.value = ''
                event.preventDefault()
            }
        </script>
    </body>
</html>
"""
async def get():
    """Serve the WebSocket chat test page stored in the module-level ``html`` string."""
    page = HTMLResponse(content=html)
    return page
async def env():
    """Return an HTML page listing the process environment variables.

    Debugging aid. Variables whose names look like credentials (containing
    TOKEN, SECRET, PASSWORD, KEY or CREDENTIAL, case-insensitive) have their
    values masked, so e.g. ``HF_AUTH_TOKEN`` is not published. Names and values
    are HTML-escaped before being embedded in the page.
    """
    # Local import: the module-level name ``html`` is the page template string.
    from html import escape

    secret_markers = ("TOKEN", "SECRET", "PASSWORD", "KEY", "CREDENTIAL")
    rows = []
    for name, value in os.environ.items():
        if any(marker in name.upper() for marker in secret_markers):
            value = "********"
        rows.append(f"{escape(name)}: {escape(value)}<br>")
    return HTMLResponse("<h3>Environment Variables</h3>" + "".join(rows))
async def websocket_endpoint(websocket: WebSocket):
    """Echo each received text message back to the client until it disconnects.

    The page in ``html`` connects to ``/api/ws``. NOTE(review): the route
    registration for this handler is not visible here; confirm it is mounted
    at that path.
    """
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message text was: {data}")
    except WebSocketDisconnect:
        # The client closed the connection; end the handler normally.
        return
async def indochat(**kwargs):
    """Generate a reply with the "indochat-tiny" model.

    Keyword arguments are forwarded to ``text_generate`` by name. Any generation
    parameter that is not supplied falls back to the default declared in
    ``text_generate``'s signature. ``model_name`` is always "indochat-tiny".

    NOTE(review): if this is registered as a FastAPI route, a bare ``**kwargs``
    parameter will not receive form fields; confirm how it is called.
    """
    import inspect

    # text_generate's defaults are ``Form(...)`` objects. When it is called
    # directly (not via FastAPI request parsing), those objects would be passed
    # through as-is, so unwrap each one to its ``.default`` value first.
    params = {}
    for name, param in inspect.signature(text_generate).parameters.items():
        default = param.default
        params[name] = getattr(default, "default", default)
    params.update(kwargs)
    params["model_name"] = "indochat-tiny"
    return await text_generate(**params)
def _decoding_settings(decoding_method: str, num_beams: int, penalty_alpha: float):
    """Translate a decoding-method label into ``(do_sample, num_beams, penalty_alpha)``.

    * "Beam Search": no sampling, the caller's ``num_beams``, ``penalty_alpha`` forced to 0.
    * "Sampling": sampling with a single beam, ``penalty_alpha`` forced to 0.
    * any other label: no sampling, a single beam, the caller's ``penalty_alpha``
      kept (presumably meant as contrastive search; confirm against the client's labels).
    """
    if decoding_method == "Beam Search":
        return False, num_beams, 0
    if decoding_method == "Sampling":
        return True, 1, 0
    return False, 1, penalty_alpha


def _extract_reply(continuation: str) -> str:
    """Return the assistant's reply from the newly generated text only.

    Leading whitespace is dropped. The text is cut just before the first
    "User:" marker, i.e. a user turn the model started inventing. If there is
    no marker, the whole continuation is kept.
    """
    reply = continuation.lstrip()
    cut = reply.find("User:")
    return reply if cut == -1 else reply[:cut]


async def text_generate(
        model_name: str = Form(default="", description="The model name"),
        text: str = Form(default="", description="The Prompt"),
        decoding_method: str = Form(default="Sampling", description="Decoding method"),
        min_length: int = Form(default=50, description="Minimal length of the generated text"),
        max_length: int = Form(default=250, description="Maximal length of the generated text"),
        num_beams: int = Form(default=5, description="Beams number"),
        top_k: int = Form(default=30, description="The number of highest probability vocabulary tokens to keep "
                                                  "for top-k-filtering"),
        top_p: float = Form(default=0.95, description="If set to float < 1, only the most probable tokens with "
                                                      "probabilities that add up to top_p or higher are kept "
                                                      "for generation"),
        temperature: float = Form(default=0.5, description="The Temperature of the softmax distribution"),
        penalty_alpha: float = Form(default=0.5, description="Penalty alpha"),
        repetition_penalty: float = Form(default=1.2, description="Repetition penalty"),
        seed: int = Form(default=-1, description="Random Seed"),
        max_time: float = Form(default=60.0, description="Maximal time in seconds to generate the text")
):
    """Generate an assistant reply to ``text`` with the loaded model ``model_name``.

    The prompt is formatted as ``"User: {text}\\nAssistant: "``. A negative
    ``seed`` leaves the RNG unseeded. A ``repetition_penalty`` of exactly 0.0
    means "derive it from the temperature".

    Returns:
        ``{"generated_text": str, "processing_time": float}``, where the time is
        wall-clock seconds covering generation and decoding.

    Raises:
        HTTPException: 404 if ``model_name`` is not a key of ``text_generator``.

    NOTE(review): ``generate`` runs synchronously inside this async handler and
    blocks the event loop (including the WebSocket echo) while it runs.
    """
    if model_name not in text_generator:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
                            detail=f"Unknown model: {model_name}")
    tokenizer = text_generator[model_name]["tokenizer"]
    model = text_generator[model_name]["model"]

    if seed >= 0:
        set_seed(seed)
    do_sample, num_beams, penalty_alpha = _decoding_settings(decoding_method, num_beams, penalty_alpha)
    if repetition_penalty == 0.0:
        # Interpolate from 1.5 at temperature 0 down to 1.05 at temperature 1,
        # never going below 0.8 for temperatures above 1.
        min_penalty = 1.05
        max_penalty = 1.5
        repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)

    prompt = f"User: {text}\nAssistant: "
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
    model.eval()
    print("Generating text...")
    print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "
          f"temperature: {temperature}, repetition_penalty: {repetition_penalty}, penalty_alpha: {penalty_alpha}")
    time_start = time.time()
    sample_outputs = model.generate(input_ids,
                                    penalty_alpha=penalty_alpha,
                                    do_sample=do_sample,
                                    num_beams=num_beams,
                                    min_length=min_length,
                                    max_length=max_length,
                                    top_k=top_k,
                                    top_p=top_p,
                                    temperature=temperature,
                                    repetition_penalty=repetition_penalty,
                                    num_return_sequences=1,
                                    max_time=max_time
                                    )
    result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # The output sequence starts with the prompt tokens; decode only what comes
    # after them. Slicing the decoded string by len(prompt) breaks whenever
    # decoding does not reproduce the prompt character-for-character.
    prompt_length = input_ids.shape[-1]
    continuation = tokenizer.decode(sample_outputs[0][prompt_length:], skip_special_tokens=True)
    time_end = time.time()
    time_diff = time_end - time_start
    print(f"result:\n{result}")
    generated_text = _extract_reply(continuation)
    return {"generated_text": generated_text, "processing_time": time_diff}
def get_text_generator(model_name: str, device: str = "cpu"):
    """Load a causal-LM and its tokenizer, and move the model to ``device``.

    Args:
        model_name: Identifier or path passed to ``from_pretrained``.
        device: Target device string passed to ``model.to``, e.g. "cpu" or "cuda".

    Returns:
        ``(model, tokenizer)``.
    """
    # The HF_AUTH_TOKEN environment variable, or False when it is unset.
    hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
    # Report only whether a token is present; the value itself is a credential
    # and must not end up in logs.
    print(f"hf_auth_token set: {bool(hf_auth_token)}")
    print(f"Loading model with device: {device}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
    # Generation pads with the EOS token.
    model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
                                                 use_auth_token=hf_auth_token)
    model.to(device)
    print("Model loaded")
    return model, tokenizer
def get_config(path: str = "config.json"):
    """Load and return the application's JSON configuration.

    Args:
        path: JSON file to read. Defaults to "config.json", resolved against
            the current working directory.

    Returns:
        The parsed JSON document.
    """
    # ``with`` closes the handle deterministically; JSON text is UTF-8.
    with open(path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)
# Load every model listed under "text-generator" in config.json once, at import
# time. Each entry maps a short alias (the key text_generate looks up) to the
# identifier handed to get_text_generator.
config = get_config()
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
text_generator = {}
for model_name, model_source in config["text-generator"].items():
    model, tokenizer = get_text_generator(model_name=model_source, device=device)
    text_generator[model_name] = {"model": model, "tokenizer": tokenizer}