misbah1955 committed

Commit d67c964 · 1 parent: b6d052e

Upload 4 files

Files changed (4)
  1. app.py +247 -0
  2. model.py +57 -0
  3. requirements.txt +9 -0
  4. style.css +16 -0
app.py ADDED
@@ -0,0 +1,247 @@
import os
from typing import Iterator

import gradio as gr

from model import run

# Environment variables are strings; any non-empty HF_PUBLIC enables a
# public share link.
HF_PUBLIC = bool(os.environ.get("HF_PUBLIC", ""))

DEFAULT_SYSTEM_PROMPT = "You are Mistral. You are a polite AI assistant based on the Mistral-7B model from Mistral AI; you give only truthful information and can communicate equally well in different languages."
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = 4000

DESCRIPTION = """
# [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
"""


def clear_and_save_textbox(message: str) -> tuple[str, str]:
    return '', message


def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    history.append((message, ''))
    return history


def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    try:
        message, _ = history.pop()
    except IndexError:
        message = ''
    return history, message or ''


def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(f'max_new_tokens must not exceed {MAX_MAX_NEW_TOKENS}')

    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
    try:
        first_response = next(generator)
        yield history + [(message, first_response)]
    except StopIteration:
        yield history + [(message, '')]
    for response in generator:
        yield history + [(message, response)]


def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
    x = []
    for x in generator:
        pass
    return '', x


def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    # A rough proxy rather than a true token count: characters in the
    # message plus the number of past turns.
    input_token_length = len(message) + len(chat_history)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')


with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')

    with gr.Group():
        chatbot = gr.Chatbot(label='Playground')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Hi, Mistral!',
                scale=10,
            )
            submit_button = gr.Button('Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('🗑️ Clear', variant='secondary')

    saved_input = gr.State()

    with gr.Accordion(label='⚙️ Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=5,
                                   interactive=False)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.1,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=10,
        )

    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue(max_size=32).launch(share=HF_PUBLIC, show_api=False)
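
Note: the streaming UI above hinges on one contract: model.run yields cumulative output strings (each yield is the whole reply so far), and generate re-emits the full chat history with that partial reply appended. A minimal sketch of the contract, where fake_run is a hypothetical stand-in for the real streamer:

    from typing import Iterator

    def fake_run(message: str) -> Iterator[str]:
        # Hypothetical stand-in for model.run: yields growing prefixes of
        # the reply, not per-token deltas.
        reply = 'Hello! How can I help?'
        for i in range(4, len(reply) + 4, 4):
            yield reply[:i]

    history: list[tuple[str, str]] = []
    for partial in fake_run('Hi, Mistral!'):
        # This is the shape the Chatbot component receives on each yield.
        print(history + [('Hi, Mistral!', partial)])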
model.py ADDED
@@ -0,0 +1,57 @@
import os
from typing import Iterator

from text_generation import Client

model_id = 'mistralai/Mistral-7B-Instruct-v0.1'

API_URL = "https://api-inference.huggingface.co/models/" + model_id
# A read-scoped Hugging Face token; without it the hosted Inference API
# will likely reject requests.
HF_TOKEN = os.environ.get("HF_READ_TOKEN", None)

client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)
EOS_STRING = "</s>"
EOT_STRING = "<EOT>"


def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    # Builds a Llama-2-style [INST] / <<SYS>> prompt from the chat history.
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped.
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.1,
        top_p: float = 0.9,
        top_k: int = 50) -> Iterator[str]:
    prompt = get_prompt(message, chat_history, system_prompt)

    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )
    stream = client.generate_stream(prompt, **generate_kwargs)
    output = ""
    # Yield the accumulated text so far; stop at an end-of-sequence marker.
    for response in stream:
        if any(end_token in response.token.text for end_token in (EOS_STRING, EOT_STRING)):
            return output
        output += response.token.text
        yield output
    return output
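
For reference, get_prompt assembles a Llama-2-style [INST] / <<SYS>> instruction prompt. A small illustration of the exact string it builds for a one-turn history (the conversation values are made up; assumes text_generation is installed so the module imports cleanly):

    from model import get_prompt

    prompt = get_prompt(
        message='How are you?',
        chat_history=[('Hello', 'Hi there!')],
        system_prompt='Be brief.',
    )
    # prompt == ('<s>[INST] <<SYS>>\nBe brief.\n<</SYS>>\n\n'
    #            'Hello [/INST] Hi there! </s><s>[INST] How are you? [/INST]')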
requirements.txt ADDED
@@ -0,0 +1,9 @@
accelerate
bitsandbytes
gradio
protobuf
scipy
sentencepiece
torch
text_generation
git+https://github.com/huggingface/transformers@main
style.css ADDED
@@ -0,0 +1,16 @@
h1 {
  text-align: center;
}

#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}

#component-0 {
  max-width: 900px;
  margin: auto;
  padding-top: 1.5rem;
}