hsienchen commited on
Commit
867a443
·
verified ·
1 Parent(s): 2240977

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -0
app.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import uuid
4
+ from typing import List, Tuple, Optional, Dict, Union
5
+
6
+ import google.generativeai as genai
7
+ import gradio as gr
8
+ from PIL import Image
9
+
10
+ print("google-generativeai:", genai.__version__)
11
+
12
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
13
+
14
+
15
+ AVATAR_IMAGES = (
16
+ None,
17
+ "https://media.roboflow.com/spaces/gemini-icon.png"
18
+ )
19
+
20
+ IMAGE_CACHE_DIRECTORY = "/tmp"
21
+ IMAGE_WIDTH = 512
22
+ CHAT_HISTORY = List[Tuple[Optional[Union[Tuple[str], str]], Optional[str]]]
23
+
24
+
25
+ def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
26
+ if not stop_sequences:
27
+ return None
28
+ return [sequence.strip() for sequence in stop_sequences.split(",")]
29
+
30
+
31
+ def preprocess_image(image: Image.Image) -> Optional[Image.Image]:
32
+ image_height = int(image.height * IMAGE_WIDTH / image.width)
33
+ return image.resize((IMAGE_WIDTH, image_height))
34
+
35
+
36
+ def cache_pil_image(image: Image.Image) -> str:
37
+ image_filename = f"{uuid.uuid4()}.jpeg"
38
+ os.makedirs(IMAGE_CACHE_DIRECTORY, exist_ok=True)
39
+ image_path = os.path.join(IMAGE_CACHE_DIRECTORY, image_filename)
40
+ image.save(image_path, "JPEG")
41
+ return image_path
42
+
43
+
44
+ def preprocess_chat_history(
45
+ history: CHAT_HISTORY
46
+ ) -> List[Dict[str, Union[str, List[str]]]]:
47
+ messages = []
48
+ for user_message, model_message in history:
49
+ if isinstance(user_message, tuple):
50
+ pass
51
+ elif user_message is not None:
52
+ messages.append({'role': 'user', 'parts': [user_message]})
53
+ if model_message is not None:
54
+ messages.append({'role': 'model', 'parts': [model_message]})
55
+ return messages
56
+
57
+
58
+ def upload(files: Optional[List[str]], chatbot: CHAT_HISTORY) -> CHAT_HISTORY:
59
+ for file in files:
60
+ image = Image.open(file).convert('RGB')
61
+ image = preprocess_image(image)
62
+ image_path = cache_pil_image(image)
63
+ chatbot.append(((image_path,), None))
64
+ return chatbot
65
+
66
+
67
+ def user(text_prompt: str, chatbot: CHAT_HISTORY):
68
+ if text_prompt:
69
+ chatbot.append((text_prompt, None))
70
+ return "", chatbot
71
+
72
+
73
+ def bot(
74
+ google_key: str,
75
+ files: Optional[List[str]],
76
+ temperature: float,
77
+ max_output_tokens: int,
78
+ stop_sequences: str,
79
+ top_k: int,
80
+ top_p: float,
81
+ chatbot: CHAT_HISTORY
82
+ ):
83
+ if len(chatbot) == 0:
84
+ return chatbot
85
+
86
+ google_key = google_key if google_key else GOOGLE_API_KEY
87
+ if not google_key:
88
+ raise ValueError(
89
+ "GOOGLE_API_KEY is not set. "
90
+ "Please follow the instructions in the README to set it up.")
91
+
92
+ genai.configure(api_key=google_key)
93
+ generation_config = genai.types.GenerationConfig(
94
+ temperature=temperature,
95
+ max_output_tokens=max_output_tokens,
96
+ stop_sequences=preprocess_stop_sequences(stop_sequences=stop_sequences),
97
+ top_k=top_k,
98
+ top_p=top_p)
99
+
100
+ if files:
101
+ text_prompt = [chatbot[-1][0]] \
102
+ if chatbot[-1][0] and isinstance(chatbot[-1][0], str) \
103
+ else []
104
+ image_prompt = [Image.open(file).convert('RGB') for file in files]
105
+ model = genai.GenerativeModel('gemini-pro-vision')
106
+ response = model.generate_content(
107
+ text_prompt + image_prompt,
108
+ stream=True,
109
+ generation_config=generation_config)
110
+ else:
111
+ messages = preprocess_chat_history(chatbot)
112
+ model = genai.GenerativeModel('gemini-pro')
113
+ response = model.generate_content(
114
+ messages,
115
+ stream=True,
116
+ generation_config=generation_config)
117
+
118
+ # streaming effect
119
+ chatbot[-1][1] = ""
120
+ for chunk in response:
121
+ for i in range(0, len(chunk.text), 10):
122
+ section = chunk.text[i:i + 10]
123
+ chatbot[-1][1] += section
124
+ time.sleep(0.01)
125
+ yield chatbot
126
+
127
+
128
+ google_key_component = gr.Textbox(
129
+ label="GOOGLE API KEY",
130
+ value="",
131
+ type="password",
132
+ placeholder="...",
133
+ info="You have to provide your own GOOGLE_API_KEY for this app to function properly",
134
+ visible=GOOGLE_API_KEY is None
135
+ )
136
+ chatbot_component = gr.Chatbot(
137
+ label='Gemini',
138
+ bubble_full_width=False,
139
+ avatar_images=AVATAR_IMAGES,
140
+ scale=2,
141
+ height=400
142
+ )
143
+ text_prompt_component = gr.Textbox(
144
+ placeholder="Hi there! [press Enter]", show_label=False, autofocus=True, scale=8
145
+ )
146
+ upload_button_component = gr.UploadButton(
147
+ label="Upload Images", file_count="multiple", file_types=["image"], scale=1
148
+ )
149
+ run_button_component = gr.Button(value="Run", variant="primary", scale=1)
150
+ temperature_component = gr.Slider(
151
+ minimum=0,
152
+ maximum=1.0,
153
+ value=0.4,
154
+ step=0.05,
155
+ label="Temperature",
156
+ info=(
157
+ "Temperature controls the degree of randomness in token selection. Lower "
158
+ "temperatures are good for prompts that expect a true or correct response, "
159
+ "while higher temperatures can lead to more diverse or unexpected results. "
160
+ ))
161
+ max_output_tokens_component = gr.Slider(
162
+ minimum=1,
163
+ maximum=2048,
164
+ value=1024,
165
+ step=1,
166
+ label="Token limit",
167
+ info=(
168
+ "Token limit determines the maximum amount of text output from one prompt. A "
169
+ "token is approximately four characters. The default value is 2048."
170
+ ))
171
+ stop_sequences_component = gr.Textbox(
172
+ label="Add stop sequence",
173
+ value="",
174
+ type="text",
175
+ placeholder="STOP, END",
176
+ info=(
177
+ "A stop sequence is a series of characters (including spaces) that stops "
178
+ "response generation if the model encounters it. The sequence is not included "
179
+ "as part of the response. You can add up to five stop sequences."
180
+ ))
181
+ top_k_component = gr.Slider(
182
+ minimum=1,
183
+ maximum=40,
184
+ value=32,
185
+ step=1,
186
+ label="Top-K",
187
+ info=(
188
+ "Top-k changes how the model selects tokens for output. A top-k of 1 means the "
189
+ "selected token is the most probable among all tokens in the model’s "
190
+ "vocabulary (also called greedy decoding), while a top-k of 3 means that the "
191
+ "next token is selected from among the 3 most probable tokens (using "
192
+ "temperature)."
193
+ ))
194
+ top_p_component = gr.Slider(
195
+ minimum=0,
196
+ maximum=1,
197
+ value=1,
198
+ step=0.01,
199
+ label="Top-P",
200
+ info=(
201
+ "Top-p changes how the model selects tokens for output. Tokens are selected "
202
+ "from most probable to least until the sum of their probabilities equals the "
203
+ "top-p value. For example, if tokens A, B, and C have a probability of .3, .2, "
204
+ "and .1 and the top-p value is .5, then the model will select either A or B as "
205
+ "the next token (using temperature). "
206
+ ))
207
+
208
+ user_inputs = [
209
+ text_prompt_component,
210
+ chatbot_component
211
+ ]
212
+
213
+ bot_inputs = [
214
+ google_key_component,
215
+ upload_button_component,
216
+ temperature_component,
217
+ max_output_tokens_component,
218
+ stop_sequences_component,
219
+ top_k_component,
220
+ top_p_component,
221
+ chatbot_component
222
+ ]
223
+
224
# Header markup for the page. TITLE, SUBTITLE and DUPLICATE were referenced
# below but never defined anywhere in this file, so the app crashed with a
# NameError on startup; define them here.
TITLE = """<h1 align="center">Gemini Playground 💬</h1>"""
SUBTITLE = """<h2 align="center">Chat with Gemini Pro and Gemini Pro Vision</h2>"""
DUPLICATE = """
<div align="center">
    Duplicate this Space with your own GOOGLE_API_KEY to skip the queue.
</div>
"""

with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.HTML(SUBTITLE)
    gr.HTML(DUPLICATE)
    with gr.Column():
        google_key_component.render()
        chatbot_component.render()
        with gr.Row():
            text_prompt_component.render()
            upload_button_component.render()
            run_button_component.render()
        with gr.Accordion("Parameters", open=False):
            temperature_component.render()
            max_output_tokens_component.render()
            stop_sequences_component.render()
            with gr.Accordion("Advanced", open=False):
                top_k_component.render()
                top_p_component.render()

    # Run button: record the user turn first (fast, unqueued), then stream the
    # model reply via the queued `bot` generator.
    run_button_component.click(
        fn=user,
        inputs=user_inputs,
        outputs=[text_prompt_component, chatbot_component],
        queue=False
    ).then(
        fn=bot, inputs=bot_inputs, outputs=[chatbot_component],
    )

    # Pressing Enter in the prompt box mirrors the Run button behaviour.
    text_prompt_component.submit(
        fn=user,
        inputs=user_inputs,
        outputs=[text_prompt_component, chatbot_component],
        queue=False
    ).then(
        fn=bot, inputs=bot_inputs, outputs=[chatbot_component],
    )

    # Uploaded images are cached and appended to the history immediately.
    upload_button_component.upload(
        fn=upload,
        inputs=[upload_button_component, chatbot_component],
        outputs=[chatbot_component],
        queue=False
    )

# Queue is required for streaming (generator) handlers like `bot`.
demo.queue(max_size=99).launch(debug=False, show_error=True)