Spaces:
Sleeping
Sleeping
import json
import os
from threading import Thread
from typing import Iterator, List, Optional, Tuple

import gradio as gr
import requests
import spaces  # keep above torch: ZeroGPU wants `spaces` imported before CUDA is initialised
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Markdown header rendered at the top of the Gradio page.
DESCRIPTION = (
    "# Zero GPU Model Comparison Arena\n"
    "Select two different models from the dropdowns and see how they perform on the same input.\n"
)

# Generation limits. The prompt-length cap can be overridden through the
# MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_MAX_NEW_TOKENS = 256
DEFAULT_MAX_NEW_TOKENS = 128
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# First CUDA device when available, otherwise the CPU.
# NOTE(review): not referenced elsewhere in the visible code; the models are
# placed with device_map="auto" and inputs follow model.device.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face Hub ids offered in both model dropdowns.
MODEL_OPTIONS = [
    "sarvamai/OpenHathi-7B-Hi-v0.1-Base",
    "TokenBender/Navarna_v0_1_OpenHermes_Hindi",
]
# Load every model in MODEL_OPTIONS once at start-up, quantised to 8-bit,
# together with its tokenizer. Both dicts are keyed by the Hub model id.
models = {}
tokenizers = {}
for model_id in MODEL_OPTIONS:
    tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
    models[model_id] = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        # Passing `load_in_8bit=True` straight to from_pretrained is deprecated
        # in current transformers releases; a BitsAndBytesConfig is the
        # supported way to request the same 8-bit quantisation.
        quantization_config=transformers.BitsAndBytesConfig(load_in_8bit=True),
    )
    # Inference only: switch off training-time behaviour such as dropout.
    models[model_id].eval()
    # Some tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizers[model_id].pad_token_id is None:
        tokenizers[model_id].pad_token_id = tokenizers[model_id].eos_token_id
# Function to log comparisons
def log_comparison(
    model1_name: str,
    model2_name: str,
    question: str,
    answer1: str,
    answer2: str,
    winner: Optional[str] = None,
):
    """POST one comparison record (and, optionally, a vote) to the log server.

    Logging is best-effort: a ``requests.RequestException`` or a non-2xx
    response is reported with ``print`` and never raised, so the UI keeps
    working when the server is unreachable.

    Args:
        model1_name: Hub id of the first model.
        model2_name: Hub id of the second model.
        question: The input both models answered.
        answer1: Reply produced by the first model.
        answer2: Reply produced by the second model.
        winner: Name of the preferred model, or ``None`` when the record is
            logged right after generation, before any vote.

    The endpoint defaults to the original server and can be overridden with
    the ``COMPARISON_LOG_URL`` environment variable.
    """
    # NOTE(review): prompts and answers are sent over plain HTTP to a fixed
    # third-party IP; users are not told their inputs leave the app.
    log_url = os.getenv("COMPARISON_LOG_URL", "http://144.24.151.32:5000/log")
    log_data = {
        "question": question,
        "model1": {"name": model1_name, "answer": answer1},
        "model2": {"name": model2_name, "answer": answer2},
        "winner": winner,
    }
    try:
        response = requests.post(log_url, json=log_data, timeout=5)
    except requests.RequestException as e:
        print(f"Error sending log to server: {e}")
        return
    # Any 2xx (e.g. 201 Created, 204 No Content) means the record was accepted.
    if 200 <= response.status_code < 300:
        print("Successfully logged to server")
    else:
        print(f"Failed to log to server. Status code: {response.status_code}")
# Function to prepare input
def prepare_input(model_id: str, message: str, chat_history: List[Tuple[str, str]]):
    """Tokenize the conversation so far plus ``message`` as ONE sequence.

    Earlier turns (user text and model reply) and the new message are joined
    with newlines into a single prompt. Tokenizing a list of strings would
    instead create a batch with one row per turn. ``TextIteratorStreamer``
    only accepts batch size 1, so every follow-up turn used to fail.

    The prompt is not truncated here. Right-side truncation would drop the
    newest message, whereas ``generate`` trims over-long inputs from the left.

    Args:
        model_id: Key into the module-level ``tokenizers`` dict.
        message: The new user input.
        chat_history: Earlier ``(user, reply)`` pairs. ``None`` entries (for
            example a turn without a reply) are treated as empty text.

    Returns:
        The tokenizer's encoding with PyTorch tensors (batch size 1).
    """
    tokenizer = tokenizers[model_id]
    segments = []
    for user_text, reply_text in chat_history or []:
        segments.append(user_text or "")
        segments.append(reply_text or "")
    segments.append(message or "")
    prompt = "\n".join(segments)
    return tokenizer(prompt, return_tensors="pt")
# Function to generate responses from models
def generate(
    model_id: str,
    message: str,
    chat_history: List[Tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.4,
    top_p: float = 0.95,
) -> Iterator[str]:
    """Stream a sampled completion from the model registered under ``model_id``.

    ``model.generate`` runs in a background thread, and text is read back
    through a ``TextIteratorStreamer``.

    Args:
        model_id: Key into the module-level ``models`` / ``tokenizers`` dicts.
        message: The new user input.
        chat_history: Previous turns, forwarded to ``prepare_input``.
        max_new_tokens: Upper bound on generated tokens. It is coerced to
            ``int`` in case the UI slider delivers a float.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The cumulative text generated so far. Each value contains all earlier
        output, so the last value is the complete answer.

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised
        here once the stream has been drained.
    """
    model = models[model_id]
    tokenizer = tokenizers[model_id]

    inputs = prepare_input(model_id, message, chat_history)
    input_ids = inputs.input_ids
    attention_mask = inputs.get("attention_mask")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the newest message survives trimming.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        if attention_mask is not None:
            attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    if attention_mask is not None:
        # Pass the mask explicitly: pad_token_id is set to eos_token_id below,
        # so transformers cannot reliably infer it from the ids.
        attention_mask = attention_mask.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    errors: List[BaseException] = []

    def _run_generation() -> None:
        # Worker-thread body. On failure, record the exception and close the
        # stream so the loop below ends at once instead of hitting the
        # streamer timeout with an unrelated queue.Empty error.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised in the consuming thread below
            errors.append(exc)
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    worker.join()
    if errors:
        raise errors[0]
# Function to compare two models
def _final_text(stream: Iterator[str]) -> str:
    """Drain a ``generate`` stream and return its last value ("" if it yields nothing)."""
    # ``generate`` yields cumulative text, so only the last value is the full
    # answer; joining every yield would repeat each prefix ("a", "ab", ...).
    text = ""
    for text in stream:
        pass
    return text


# Requests a GPU for the duration of the call on ZeroGPU hardware; the
# `spaces` package documents the decorator as a no-op elsewhere.
@spaces.GPU
def compare_models(
    model1_name: str,
    model2_name: str,
    message: str,
    chat_history1: List[Tuple[str, str]],
    chat_history2: List[Tuple[str, str]],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[Tuple[str, str]], List[Tuple[str, str]]]:
    """Answer ``message`` with both selected models and log the pair.

    The input histories are copied, not mutated. The four returned values are
    ``(history1, history2, history1, history2)``, matching the event's four
    outputs, which list each chatbot twice.

    If both dropdowns name the same model, or the input is empty, a Gradio
    warning is shown and the histories are returned unchanged. Previously the
    in-chat error message was overwritten by the duplicated outputs and never
    displayed.
    """
    history1 = list(chat_history1 or [])
    history2 = list(chat_history2 or [])

    if model1_name == model2_name:
        gr.Warning("Error: Please select two different models.")
        return history1, history2, history1, history2
    if not (message or "").strip():
        gr.Warning("Please enter some input text.")
        return history1, history2, history1, history2

    output1 = _final_text(generate(model1_name, message, history1, max_new_tokens, temperature, top_p))
    output2 = _final_text(generate(model2_name, message, history2, max_new_tokens, temperature, top_p))

    history1.append((message, output1))
    history2.append((message, output2))
    log_comparison(model1_name, model2_name, message, output1, output2)
    return history1, history2, history1, history2
# Function to log the voting result
def _latest_reply(answer):
    """Return the newest model reply when *answer* is a chat history.

    The vote buttons pass the Chatbot values (lists of ``(user, reply)``
    pairs) rather than plain answers. A string, or anything not shaped like
    a history, is returned unchanged, and an empty history gives "".
    """
    if isinstance(answer, (list, tuple)):
        if not answer:
            return ""
        last_turn = answer[-1]
        if isinstance(last_turn, (list, tuple)) and len(last_turn) == 2:
            return last_turn[1] or ""
    return answer


def vote_better(model1_name, model2_name, question, answer1, answer2, choice):
    """Record the user's preference between the two models' latest answers.

    Args:
        model1_name: Hub id selected in the first dropdown.
        model2_name: Hub id selected in the second dropdown.
        question: Text of the input box.
        answer1: First chatbot's history, or a plain answer string.
        answer2: Second chatbot's history, or a plain answer string.
        choice: ``"Model 1"`` picks the first model; any other value picks
            the second.

    Returns:
        A status message for the "Voting Result" box. If neither side has an
        answer yet, nothing is logged and the user is asked to compare first.
    """
    answer1 = _latest_reply(answer1)
    answer2 = _latest_reply(answer2)
    if not answer1 and not answer2:
        return "Please run a comparison before voting."
    winner = model1_name if choice == "Model 1" else model2_name
    log_comparison(model1_name, model2_name, question, answer1, answer2, winner)
    return f"You voted that {winner} performs better. This has been logged."
# Gradio UI: two model panes side by side, a shared input box and sampling
# controls, a compare button, and two vote buttons whose result is logged.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            model1_dropdown = gr.Dropdown(
                choices=MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Model 1"
            )
            chatbot1 = gr.Chatbot(label="Model 1 Output")
        with gr.Column():
            model2_dropdown = gr.Dropdown(
                choices=MODEL_OPTIONS, value=MODEL_OPTIONS[1], label="Model 2"
            )
            chatbot2 = gr.Chatbot(label="Model 2 Output")

    text_input = gr.Textbox(lines=3, label="Input Text")

    with gr.Row():
        max_new_tokens = gr.Slider(
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            value=DEFAULT_MAX_NEW_TOKENS,
            label="Max new tokens",
        )
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature")
        top_p = gr.Slider(minimum=0.05, maximum=1.0, value=0.95, label="Top-p (nucleus sampling)")

    compare_btn = gr.Button("Compare Models")

    with gr.Row():
        better1_btn = gr.Button("Model 1 is Better")
        better2_btn = gr.Button("Model 2 is Better")

    vote_output = gr.Textbox(label="Voting Result")

    # compare_models returns four values; they are written to the two
    # chatbots twice, in this order.
    compare_btn.click(
        fn=compare_models,
        inputs=[
            model1_dropdown,
            model2_dropdown,
            text_input,
            chatbot1,
            chatbot2,
            max_new_tokens,
            temperature,
            top_p,
        ],
        outputs=[chatbot1, chatbot2, chatbot1, chatbot2],
    )

    # Each vote button sends the shared context plus a hidden textbox holding
    # a fixed "Model 1" / "Model 2" label that tells vote_better who won.
    shared_vote_inputs = [model1_dropdown, model2_dropdown, text_input, chatbot1, chatbot2]
    for vote_button, choice_label in ((better1_btn, "Model 1"), (better2_btn, "Model 2")):
        vote_button.click(
            fn=vote_better,
            inputs=shared_vote_inputs + [gr.Textbox(value=choice_label, visible=False)],
            outputs=[vote_output],
        )
# Script entry point: enable the request queue (max_size=3) and start serving.
# NOTE(review): share=True only creates a public link for local runs; a hosted
# Space is already publicly reachable — confirm it is still wanted.
if __name__ == "__main__":
    queued_app = demo.queue(max_size=3)
    queued_app.launch(share=True)