dioarafl commited on
Commit
68b7f30
·
verified ·
1 Parent(s): 092e3c8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import gradio as gr
4
+ import edge_tts
5
+ import asyncio
6
+ import time
7
+ import tempfile
8
+ from huggingface_hub import InferenceClient
9
+
10
+ class JarvisModels:
11
+ def __init__(self):
12
+ self.client1 = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
13
+ self.client2 = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
14
+ self.client3 = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
15
+
16
+ async def generate_model1(self, prompt):
17
+ generate_kwargs = dict(
18
+ temperature=0.6,
19
+ max_new_tokens=256,
20
+ top_p=0.95,
21
+ repetition_penalty=1,
22
+ do_sample=True,
23
+ seed=42,
24
+ )
25
+ formatted_prompt = system_instructions1 + prompt + "[JARVIS]"
26
+ stream = self.client1.text_generation(
27
+ formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
28
+ output = ""
29
+ for response in stream:
30
+ output += response.token.text
31
+
32
+ communicate = edge_tts.Communicate(output)
33
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
34
+ tmp_path = tmp_file.name
35
+ await communicate.save(tmp_path)
36
+ yield tmp_path
37
+
38
+ async def generate_model2(self, prompt):
39
+ generate_kwargs = dict(
40
+ temperature=0.6,
41
+ max_new_tokens=512,
42
+ top_p=0.95,
43
+ repetition_penalty=1,
44
+ do_sample=True,
45
+ )
46
+ formatted_prompt = system_instructions2 + prompt + "[ASSISTANT]"
47
+ stream = self.client2.text_generation(
48
+ formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
49
+ output = ""
50
+ for response in stream:
51
+ output += response.token.text
52
+
53
+ communicate = edge_tts.Communicate(output)
54
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
55
+ tmp_path = tmp_file.name
56
+ await communicate.save(tmp_path)
57
+ yield tmp_path
58
+
59
+ async def generate_model3(self, prompt):
60
+ generate_kwargs = dict(
61
+ temperature=0.6,
62
+ max_new_tokens=2048,
63
+ top_p=0.95,
64
+ repetition_penalty=1,
65
+ do_sample=True,
66
+ )
67
+ formatted_prompt = system_instructions3 + prompt + "[ASSISTANT]"
68
+ stream = self.client3.text_generation(
69
+ formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
70
+ output = ""
71
+ for response in stream:
72
+ output += response.token.text
73
+
74
+ communicate = edge_tts.Communicate(output)
75
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
76
+ tmp_path = tmp_file.name
77
+ await communicate.save(tmp_path)
78
+ yield tmp_path
79
+
80
+ class JarvisApp:
81
+ def __init__(self):
82
+ self.models = JarvisModels()
83
+
84
+ def launch_app(self):
85
+ with gr.Blocks(css="style.css") as demo:
86
+ gr.Markdown(DESCRIPTION)
87
+ with gr.Row():
88
+ user_input = gr.Textbox(label="Prompt", value="What is Wikipedia")
89
+ input_text = gr.Textbox(label="Input Text", elem_id="important")
90
+ output_audio = gr.Audio(label="JARVIS", type="filepath",
91
+ interactive=False,
92
+ autoplay=True,
93
+ elem_classes="audio")
94
+ with gr.Row():
95
+ translate_btn = gr.Button("Response")
96
+ translate_btn.click(fn=self.models.generate_model1, inputs=user_input,
97
+ outputs=output_audio, api_name="translate")
98
+
99
+ gr.Markdown(MORE)
100
+
101
+ if __name__ == "__main__":
102
+ demo.queue(max_size=200).launch()
103
+
104
+ if __name__ == "__main__":
105
+ app = JarvisApp()
106
+ app.launch_app()
107
+