Zhanming committed on
Commit
7c8e1f1
·
1 Parent(s): d930631

Implement Gradio interface and update dependencies

Browse files
Files changed (4) hide show
  1. .idea/misc.xml +3 -0
  2. Dockerfile +5 -10
  3. app.py +253 -5
  4. requirements.txt +7 -2
.idea/misc.xml CHANGED
@@ -1,5 +1,8 @@
1
  <?xml version="1.0" encoding="UTF-8"?>
2
  <project version="4">
 
 
 
3
  <component name="ProjectRootManager" version="2" languageLevel="JDK_21" default="true" project-jdk-name="corretto-21" project-jdk-type="JavaSDK">
4
  <output url="file://$PROJECT_DIR$/out" />
5
  </component>
 
1
  <?xml version="1.0" encoding="UTF-8"?>
2
  <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.12 (python-312)" />
5
+ </component>
6
  <component name="ProjectRootManager" version="2" languageLevel="JDK_21" default="true" project-jdk-name="corretto-21" project-jdk-type="JavaSDK">
7
  <output url="file://$PROJECT_DIR$/out" />
8
  </component>
Dockerfile CHANGED
@@ -1,13 +1,8 @@
1
  FROM python:3.9
2
-
3
  RUN useradd -m -u 1000 user
 
 
 
 
4
  USER user
5
- ENV PATH="/home/user/.local/bin:${PATH}"
6
-
7
- WORKDIR /app
8
-
9
- COPY --chown=user ./requirements.txt requirements.txt
10
- RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
-
12
- COPY --chown=user . /app
13
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
FROM python:3.9

# Non-root user required by Hugging Face Spaces (uid 1000).
RUN useradd -m -u 1000 user
WORKDIR /code
# Install dependencies first so Docker layer caching survives code-only changes.
COPY ./requirements.txt /code/requirements.txt
RUN pip install --upgrade pip
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
USER user
COPY --link --chown=1000 ./ /code
# The previous image defined a CMD; without one the container has no start
# command. app.py launches Gradio itself on 0.0.0.0 (default port 7860).
CMD ["python", "app.py"]
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,6 +1,254 @@
1
- from fastapi import FastAPI
2
- app = FastAPI()
3
 
4
- @app.get("/")
5
- def greet_json():
6
- return {"message": "Hello World"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterator
2
+ import gradio as gr
3
 
4
+ from transformers.utils import logging
5
+ from model import get_input_token_length, run
6
+
7
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Generation / UI configuration constants.
DEFAULT_SYSTEM_PROMPT = """"""  # empty by default; editable in the Advanced options panel
MAX_MAX_NEW_TOKENS = 2048       # hard upper bound enforced in generate()
DEFAULT_MAX_NEW_TOKENS = 1024   # initial value of the "Max new tokens" slider
MAX_INPUT_TOKEN_LENGTH = 4000   # prompt-length limit checked by check_input_token_length()

DESCRIPTION = """"""  # markdown rendered at the top of the page (currently empty)

LICENSE = """"""      # markdown rendered at the bottom of the page (currently empty)

logger.info("Starting")
20
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Clear the input textbox while stashing the submitted message.

    Returns a pair of (new textbox value, saved message) consumed by the
    textbox and the saved_input state component respectively.
    """
    saved = message
    return ('', saved)
22
+
23
+
24
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Append the user's message to the chat history with an empty reply slot.

    Mutates *history* in place and returns it so the chatbot component updates.
    """
    logger.info("display_input=%s", message)
    entry = (message, '')
    history.append(entry)
    return history
29
+
30
+
31
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the most recent (message, response) pair from the history.

    Returns the shortened history and the removed user message (empty string
    when the history was already empty or the message was falsy).
    """
    if not history:
        return history, ''
    last_message, _unused_response = history.pop()
    return history, last_message or ''
38
+
39
+
40
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Stream model responses for *message*, yielding updated chat histories.

    history_with_input already contains (message, '') appended by
    display_input; it is stripped off and re-added with the partial response
    so the chatbot shows text as it streams.

    Raises:
        ValueError: if max_new_tokens exceeds MAX_MAX_NEW_TOKENS.
    """
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        # Was a bare `raise ValueError` — include the offending value.
        raise ValueError(
            f'max_new_tokens must be <= {MAX_MAX_NEW_TOKENS}, got {max_new_tokens}')

    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
    try:
        first_response = next(generator)
        yield history + [(message, first_response)]
    except StopIteration:
        # Model produced nothing; still surface the user's message.
        yield history + [(message, '')]
    for response in generator:
        yield history + [(message, response)]
62
+
63
+
64
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    """Run an example message to completion and return ('', final history).

    The original kept the loop variable after the loop, which raises
    NameError when generate() yields nothing; initialize the history first.
    """
    history: list[tuple[str, str]] = []
    for history in generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50):
        pass  # drain the stream; only the final history matters
    return '', history
69
+
70
+
71
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Abort the event chain if the accumulated prompt is too long.

    Raises:
        gr.Error: when the tokenized prompt built from message, chat history
            and system prompt exceeds MAX_INPUT_TOKEN_LENGTH; Gradio shows it
            to the user and `.success()` stops generate() from running.
    """
    input_token_length = get_input_token_length(message, chat_history, system_prompt)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        # Replaced the debug "Inside IF condition" log with the actual numbers.
        logger.info("input too long: %d > %d", input_token_length, MAX_INPUT_TOKEN_LENGTH)
        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
80
+
81
+
82
# Build the Gradio UI and wire the event chains.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')

    # Chat area: transcript, message input, and submit button.
    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type a message...',
                scale=10,
            )
            submit_button = gr.Button('Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the last submitted message so chained events can still read it
    # after clear_and_save_textbox empties the textbox.
    saved_input = gr.State()

    # Sampling controls, collapsed by default.
    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=6)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        )

    gr.Markdown(LICENSE)

    # Enter in the textbox: save+clear input -> echo it into the chat ->
    # validate prompt length -> (only on success) stream the generation.
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Submit button: same chain as textbox.submit above.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Retry: drop the last exchange, re-echo the message, regenerate.
    # NOTE(review): this chain skips check_input_token_length — presumably
    # because the input already passed once; confirm that is intentional.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Undo: drop the last exchange and put the message back in the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: reset both the transcript and the saved message.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

# Bind to 0.0.0.0 so the app is reachable from outside the container.
demo.queue(max_size=20).launch(share=False, server_name="0.0.0.0")
requirements.txt CHANGED
@@ -1,2 +1,7 @@
1
- fastapi
2
- uvicorn[standard]
 
 
 
 
 
 
1
+ gradio==3.37.0
2
+ protobuf==3.20.3
3
+ scipy==1.11.1
4
+ torch==2.0.1
5
+ sentencepiece==0.1.99
6
+ transformers==4.31.0
7
+ ctransformers==0.2.27