File size: 1,766 Bytes
c0be431
 
eb967be
 
c0be431
 
 
 
 
 
 
6bfe382
c0be431
 
 
 
 
6bfe382
c0be431
 
 
 
 
 
6bfe382
c0be431
 
 
 
 
 
6bfe382
 
c0be431
 
 
6bfe382
c0be431
 
 
 
 
 
 
6bfe382
eb967be
 
c0be431
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from dotenv import load_dotenv
from generators import *
import gradio as gr

from utils import async_zip_stream

load_dotenv()


async def handle(system_input: str, user_input: str):
    """Stream five models' answers side by side, then add the bloom answer.

    The five streaming generators are advanced together through
    ``async_zip_stream``.  Each chunk is converted with ``str`` and appended
    to its model's running text, and the accumulated texts are yielded after
    every step.  When the streams are exhausted, one final update is yielded
    whose last slot holds the result of ``generate_bloom``.

    Args:
        system_input: System prompt forwarded to every generator.
        user_input: User prompt forwarded to every generator.

    Yields:
        A list of six values: the accumulated outputs of ``generate_gpt2``,
        ``generate_mistral_7bvo1``, ``generate_llama2``, ``generate_llama3``
        and ``generate_mistral_7bvo3`` (in that order), followed by the bloom
        slot, which is an empty string until the final update.
    """
    print(system_input, user_input)
    streams = (
        generate_gpt2(system_input, user_input),
        generate_mistral_7bvo1(system_input, user_input),
        generate_llama2(system_input, user_input),
        generate_llama3(system_input, user_input),
        generate_mistral_7bvo3(system_input, user_input),
    )
    buffers = [""] * len(streams)
    async for outputs in async_zip_stream(*streams):
        buffers = [text + str(outputs[i]) for i, text in enumerate(buffers)]
        # Exactly one placeholder for the bloom slot, so intermediate updates
        # have the same six-value shape as the final one.
        yield buffers + [""]
    # NOTE(review): generate_bloom is called synchronously here; if it is a
    # coroutine function its result would need to be awaited -- confirm.
    yield buffers + [generate_bloom(system_input, user_input)]


def _readonly_box(label):
    """Create a four-line, non-interactive textbox with the given label."""
    return gr.Textbox(label=label, lines=4, interactive=False)


# Layout, top to bottom: system prompt, two rows of three model output boxes,
# user prompt, and the button that triggers generation.
with gr.Blocks() as demo:
    system_input = gr.Textbox(
        label='System Input',
        value='You are AI assistant',
        lines=2,
    )

    with gr.Row():
        gpt = _readonly_box('gpt-2')
        mistral = _readonly_box('mistral-v01')
        mistral_new = _readonly_box('mistral-v03')

    with gr.Row():
        llama2 = _readonly_box('llama-2')
        llama3 = _readonly_box('llama-3')
        bloom = _readonly_box('bloom [GPU]')

    user_input = gr.Textbox(label='User Input', lines=2)
    gen_button = gr.Button('Generate')

    # The output list follows the order in which `handle` yields its values:
    # gpt2, mistral v01, llama2, llama3, mistral v03, then bloom.
    prompt_fields = [system_input, user_input]
    result_fields = [gpt, mistral, llama2, llama3, mistral_new, bloom]
    gen_button.click(fn=handle, inputs=prompt_fields, outputs=result_fields)

demo.launch()