Chris-lab / pages /arena.py
kz209
update
9dfac6e
raw
history blame
1.52 kB
#from utils.multiple_stream import create_interface
import random
import gradio as gr
import json
import logging
import gc
import torch
from utils.data import dataset
from utils.multiple_stream import stream_data
from pages.summarization_playground import get_model_batch_generation
def create_arena():
    """Build the prompt-comparison "arena" page as a Gradio Blocks app.

    Loads the prompts from ``prompt/prompt.json`` (path relative to the
    current working directory), picks one random entry of ``dataset`` and
    renders one text box per prompt. Clicking the submit button streams the
    output produced for every prompt into its text box.

    The datapoint is chosen and the model is obtained once, while the page
    is being built, so every click re-runs the same prompts on the same
    datapoint.

    Returns:
        The ``gr.Blocks`` instance. It is neither queued nor launched here;
        that is left to the caller.

    Raises:
        OSError: If ``prompt/prompt.json`` cannot be opened.
        json.JSONDecodeError: If the prompt file is not valid JSON.
    """
    # JSON is UTF-8 by specification; naming the encoding keeps non-ASCII
    # prompts from being decoded with the platform's default locale.
    with open("prompt/prompt.json", "r", encoding="utf-8") as prompt_file:
        prompts = json.load(prompt_file)

    with gr.Blocks() as demo:
        with gr.Group():
            sample = random.choice(dataset)
            datapoint = sample['section_text'] + '\n\nDialogue:\n' + sample['dialogue']

            submit_button = gr.Button("✨ Submit ✨")

            with gr.Row():
                # One output box per prompt, labelled starting from 1.
                columns = [
                    gr.Textbox(label=f"Prompt {number}", lines=10)
                    for number in range(1, len(prompts) + 1)
                ]

            # Every prompt is completed with the same datapoint wrapped in
            # braces, followed by the "summary:" cue.
            content_list = [
                prompt + '\n{' + datapoint + '}\n\nsummary:' for prompt in prompts
            ]
            model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")

            def start_streaming():
                """Yield one tuple of text-box updates per streamed step."""
                # NOTE(review): stream_data presumably yields, at each step,
                # a positionally indexable collection with one entry per
                # prompt -- confirm against utils.multiple_stream. Indexing
                # by position is kept because its exact type is not visible.
                for data in stream_data(content_list, model):
                    yield tuple(
                        gr.update(value=data[position])
                        for position in range(len(columns))
                    )

            submit_button.click(
                fn=start_streaming,
                inputs=[],
                outputs=columns,
                show_progress=False,
            )

    return demo
if __name__ == "__main__":
    # Running this file directly serves the arena page on its own.
    demo = create_arena()
    # Turn on request queuing before serving; the click handler registered
    # in create_arena is a generator function.
    demo.queue()
    demo.launch()