multimodalart (HF Staff) committed
Commit 3aa2ce4 · verified · 1 Parent(s): e5da62c

Update app.py

Files changed (1): app.py (+149, -141)
app.py CHANGED
@@ -1,24 +1,19 @@
 import gradio as gr
 from gradio_client import Client, handle_file
-from google import genai
-from google.genai import types
+from google import genai
 import os
 from typing import Optional, List
 from huggingface_hub import whoami
 from PIL import Image
 from io import BytesIO
 import tempfile
-import time
+import ffmpeg
 
 # --- Google Gemini API Configuration ---
 GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
 if not GOOGLE_API_KEY:
     raise ValueError("GOOGLE_API_KEY environment variable not set.")
-
-client = genai.Client(
-    api_key=os.environ.get("GOOGLE_API_KEY"),
-)
-
+client = genai.Client(api_key=os.environ.get("GOOGLE_API_KEY"))
 GEMINI_MODEL_NAME = 'gemini-2.5-flash-image-preview'
 
 def verify_pro_status(token: Optional[gr.OAuthToken]) -> bool:
@@ -27,12 +22,7 @@ def verify_pro_status(token: Optional[gr.OAuthToken]) -> bool:
         return False
     try:
         user_info = whoami(token=token.token)
-        if user_info.get("isPro", False):
-            return True
-        orgs = user_info.get("orgs", [])
-        if any(org.get("isEnterprise", False) for org in orgs):
-            return True
-        return False
+        return user_info.get("isPro", False) or any(org.get("isEnterprise", False) for org in user_info.get("orgs", []))
     except Exception as e:
         print(f"Could not verify user's PRO/Enterprise status: {e}")
         return False
@@ -40,104 +30,133 @@ def verify_pro_status(token: Optional[gr.OAuthToken]) -> bool:
 def _extract_image_data_from_response(response) -> Optional[bytes]:
     """Helper to extract image data from the model's response."""
     if hasattr(response, 'candidates') and response.candidates:
-        for candidate in response.candidates:
-            if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts') and candidate.content.parts:
-                for part in candidate.content.parts:
-                    if hasattr(part, 'inline_data') and hasattr(part.inline_data, 'data'):
-                        return part.inline_data.data
+        for part in response.candidates[0].content.parts:
+            if hasattr(part, 'inline_data') and hasattr(part.inline_data, 'data'):
+                return part.inline_data.data
     return None
 
-def unified_image_generator(
-    prompt: str,
-    images: Optional[List[str]] = None,
-    oauth_token: Optional[gr.OAuthToken] = None
-) -> tuple:
+def _get_framerate(video_path: str) -> float:
+    """Instantly gets the framerate of a video using ffprobe."""
+    probe = ffmpeg.probe(video_path)
+    video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
+    if video_stream is None:
+        raise ValueError("Could not find video stream in the file.")
+    return eval(video_stream['avg_frame_rate'])
+
+def _trim_first_frame_fast(video_path: str) -> str:
+    """
+    Removes exactly the first frame of a video without re-encoding.
+    This is the frame-accurate and fast method.
+    """
+    gr.Info("Preparing video segment...")
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output_file:
+        output_path = tmp_output_file.name
+
+    try:
+        framerate = _get_framerate(video_path)
+        if framerate == 0: raise ValueError("Framerate cannot be zero.")
+        start_time = 1 / framerate
+
+        # The key is placing -ss AFTER -i for accuracy, combined with -c copy for speed.
+        (
+            ffmpeg
+            .input(video_path, ss=start_time)
+            .output(output_path, c='copy', avoid_negative_ts='make_zero')
+            .run(overwrite_output=True, quiet=True)
+        )
+        return output_path
+    except Exception as e:
+        raise RuntimeError(f"FFmpeg trim error: {e}")
+
+def _combine_videos_simple(video1_path: str, video2_path: str) -> str:
     """
-    Handles all image generation tasks based on the number of input images.
-    Returns: (output_image_path, video_button_visible, video_output_visible)
+    Combines two videos using the fast concat demuxer. Assumes video2 is already trimmed.
     """
-    if not verify_pro_status(oauth_token):
-        raise gr.Error("Access Denied. This service is for PRO users only.")
+    gr.Info("Stitching videos...")
+    with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix=".txt") as tmp_list_file:
+        tmp_list_file.write(f"file '{os.path.abspath(video1_path)}'\n")
+        tmp_list_file.write(f"file '{os.path.abspath(video2_path)}'\n")
+        list_file_path = tmp_list_file.name
+
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output_file:
+        output_path = tmp_output_file.name
 
     try:
-        # Dynamically build the 'contents' list for the API
-        contents = []
-        if images:
-            # If there are images, open them and add to contents
-            for image_path in images:
-                print(image_path)
-                contents.append(Image.open(image_path[0]))
-
-        # Always add the prompt to the contents
+        (
+            ffmpeg
+            .input(list_file_path, format='concat', safe=0)
+            .output(output_path, c='copy')
+            .run(overwrite_output=True, quiet=True)
+        )
+        return output_path
+    except ffmpeg.Error as e:
+        raise RuntimeError(f"FFmpeg combine error: {e.stderr.decode()}")
+    finally:
+        if os.path.exists(list_file_path):
+            os.remove(list_file_path)
+
+def _generate_video_segment(input_image_path: str, output_image_path: str, prompt: str, token: str) -> str:
+    """Generates a single video segment using the external service."""
+    gr.Info("Generating new video segment...")
+    video_client = Client("multimodalart/wan-2-2-first-last-frame", hf_token=token)
+    result = video_client.predict(
+        start_image_pil=handle_file(input_image_path),
+        end_image_pil=handle_file(output_image_path),
+        prompt=prompt, api_name="/generate_video"
+    )
+    return result[0]["video"]
+
+def unified_image_generator(prompt: str, images: Optional[List[str]], previous_video_path: Optional[str], oauth_token: Optional[gr.OAuthToken]) -> tuple:
+    """
+    Handles image generation and determines the visibility of video creation buttons.
+    """
+    if not verify_pro_status(oauth_token): raise gr.Error("Access Denied.")
+    try:
+        contents = [Image.open(image_path[0]) for image_path in images] if images else []
         contents.append(prompt)
-
-        response = None
-        for attempt in range(0, 3):
-            try:
-                response = client.models.generate_content(
-                    model=GEMINI_MODEL_NAME,
-                    contents=contents,
-                )
-                break
-            except Exception as e:
-                if attempt == 2:
-                    raise gr.Error(f"The Gemini API returned an error: {e}")
-                time.sleep(1)
-
+        response = client.models.generate_content(model=GEMINI_MODEL_NAME, contents=contents)
         image_data = _extract_image_data_from_response(response)
+        if not image_data: raise ValueError("No image data in response.")
 
-        if not image_data:
-            raise ValueError("No image data found in the model response.")
-
-        # Save the generated image to a temporary file to return its path
-        pil_image = Image.open(BytesIO(image_data))
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmpfile:
-            pil_image.save(tmpfile.name)
-            output_path = tmpfile.name
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
+            Image.open(BytesIO(image_data)).save(tmp.name)
+            output_path = tmp.name
 
-        # Determine if video button should be shown (only if exactly 1 input image)
-        show_video_button = images and len(images) == 1
+        can_create_video = bool(images and len(images) == 1)
+        can_extend_video = can_create_video and bool(previous_video_path)
 
-        # Return output image path, video button visibility, and hide video output
-        return output_path, gr.update(visible=show_video_button), gr.update(visible=False)
-
+        return (
+            output_path,
+            gr.update(visible=can_create_video),
+            gr.update(visible=can_extend_video),
+            gr.update(visible=False)
+        )
     except Exception as e:
         raise gr.Error(f"Image generation failed: {e}")
 
-def create_video_transition(
-    input_image_gallery: List[str],
-    prompt_input: str,
-    output_image: str,
-    oauth_token: Optional[gr.OAuthToken] = None
-) -> tuple:
-    """
-    Creates a video transition between the input and output images.
-    Returns: (video_path, video_visible)
-    """
-    if not verify_pro_status(oauth_token):
-        raise gr.Error("Access Denied. This service is for PRO users only.")
-
-    if not input_image_gallery or not output_image:
-        raise gr.Error("Both input and output images are required for video creation.")
-
+def create_new_video(input_image_gallery: List[str], prompt_input: str, output_image: str, oauth_token: Optional[gr.OAuthToken]) -> tuple:
+    """Starts a NEW video chain, overwriting any previous video state."""
+    if not verify_pro_status(oauth_token): raise gr.Error("Access Denied.")
+    if not input_image_gallery or not output_image: raise gr.Error("Input/output images required.")
    try:
-        video_client = Client("multimodalart/wan-2-2-first-last-frame", hf_token=oauth_token.token)
-
-        input_image_path = input_image_gallery[0][0]
-
-        result = video_client.predict(
-            start_image_pil=handle_file(input_image_path),
-            end_image_pil=handle_file(output_image),
-            prompt=prompt_input,
-            api_name="/generate_video"
-        )
-        print(result)
-        return result[0]["video"]
-
+        new_segment_path = _generate_video_segment(input_image_gallery[0][0], output_image, prompt_input, oauth_token.token)
+        return new_segment_path, new_segment_path
     except Exception as e:
         raise gr.Error(f"Video creation failed: {e}")
 
-# --- Gradio App UI ---
+def extend_existing_video(input_image_gallery: List[str], prompt_input: str, output_image: str, previous_video_path: str, oauth_token: Optional[gr.OAuthToken]) -> tuple:
+    """Extends an existing video with a new segment."""
+    if not verify_pro_status(oauth_token): raise gr.Error("Access Denied.")
+    if not previous_video_path: raise gr.Error("No previous video to extend.")
+    if not input_image_gallery or not output_image: raise gr.Error("Input/output images required.")
+    try:
+        new_segment_path = _generate_video_segment(input_image_gallery[0][0], output_image, prompt_input, oauth_token.token)
+        trimmed_segment_path = _trim_first_frame_fast(new_segment_path)
+        final_video_path = _combine_videos_simple(previous_video_path, trimmed_segment_path)
+        return final_video_path, final_video_path
+    except Exception as e:
+        raise gr.Error(f"Video extension failed: {e}")
+
 css = '''
 #sub_title{margin-top: -35px !important}
 .tab-wrapper{margin-bottom: -33px !important}
@@ -158,77 +177,68 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
         <img class="logo-dark" src='https://huggingface.co/spaces/multimodalart/nano-banana/resolve/main/nano_banana_pros.png' style='margin: 0 auto; max-width: 500px' />
         <img class="logo-light" src='https://huggingface.co/spaces/multimodalart/nano-banana/resolve/main/nano_banana_pros_light.png' style='margin: 0 auto; max-width: 500px' />
     ''')
-
     gr.HTML("<h3 style='text-align:center'>Hugging Face PRO users can use Google's Nano Banana (Gemini 2.5 Flash Image Preview) on this Space. <a href='http://huggingface.co/subscribe/pro?source=nana_banana' target='_blank'>Subscribe to PRO</a></h3>", elem_id="sub_title")
-
     pro_message = gr.Markdown(visible=False)
     main_interface = gr.Column(visible=False)
+    previous_video_state = gr.State(None)
 
     with main_interface:
         with gr.Row():
             with gr.Column(scale=1):
-                with gr.Group():
-                    image_input_gallery = gr.Gallery(
-                        label="Upload one or more images here. Leave empty for text-to-image",
-                        file_types=["image"],
-                        height="auto"
-                    )
-
-                prompt_input = gr.Textbox(
-                    label="Prompt",
-                    placeholder="Turns this photo into a masterpiece"
-                )
-                generate_button = gr.Button("Generate", variant="primary")
-
+                image_input_gallery = gr.Gallery(label="Upload one or more images here. Leave empty for text-to-image", file_types=["image"], height="auto")
+                prompt_input = gr.Textbox(label="Prompt", placeholder="Turns this photo into a masterpiece")
+                generate_button = gr.Button("Generate", variant="primary")
             with gr.Column(scale=1):
                 output_image = gr.Image(label="Output", interactive=False, elem_id="output", type="filepath")
-                use_image_button = gr.Button("♻️ Use this Image for Next Edit")
-                create_video_button = gr.Button("Create a video between the two images 🎥", variant="primary", visible=False)
+                use_image_button = gr.Button("♻️ Use this Image for Next Edit", variant="primary")
+                with gr.Row():
+                    create_video_button = gr.Button("Create video between the two images 🎥", variant="secondary", visible=False)
+                    extend_video_button = gr.Button("Extend previous video with new scene 🎞️", variant="secondary", visible=False)
                 with gr.Group(visible=False) as video_group:
                     video_output = gr.Video(label="Generated Video", show_download_button=True, autoplay=True)
                     gr.Markdown("Generate more with [Wan 2.2 first-last-frame](https://huggingface.co/spaces/multimodalart/wan-2-2-first-last-frame)", elem_id="wan_ad")
                 gr.Markdown("## Thank you for being a PRO! 🤗")
-
+
     login_button = gr.LoginButton()
-
-    # --- Event Handlers ---
+
     gr.on(
         triggers=[generate_button.click, prompt_input.submit],
-        fn=lambda: [gr.update(visible=False), gr.update(visible=False)],
-        inputs=[],
-        outputs=[create_video_button, video_group],
-    ).then(
         fn=unified_image_generator,
-        inputs=[prompt_input, image_input_gallery],
-        outputs=[output_image, create_video_button, video_group],
+        inputs=[prompt_input, image_input_gallery, previous_video_state],
+        outputs=[output_image, create_video_button, extend_video_button, video_group]
     )
 
     use_image_button.click(
-        lambda img_path: [img_path] if img_path else None,
+        fn=lambda img: (
+            [img] if img else None,
+            None,
+            gr.update(visible=False),
+            gr.update(visible=False),
+            gr.update(visible=False)
+        ),
         inputs=[output_image],
-        outputs=[image_input_gallery]
+        outputs=[image_input_gallery, output_image, create_video_button, extend_video_button, video_group]
     )
-
-    # Video creation handler
+
     create_video_button.click(
-        fn=lambda: gr.update(visible=True),
-        inputs=[],
-        outputs=[video_group],
+        fn=lambda: gr.update(visible=True), outputs=[video_group]
    ).then(
-        fn=create_video_transition,
+        fn=create_new_video,
        inputs=[image_input_gallery, prompt_input, output_image],
-        outputs=[video_output],
+        outputs=[video_output, previous_video_state],
    )
 
-    # --- Access Control Logic ---
-    def control_access(
-        profile: Optional[gr.OAuthProfile] = None,
-        oauth_token: Optional[gr.OAuthToken] = None
-    ):
-        if not profile:
-            return gr.update(visible=False), gr.update(visible=False)
-        if verify_pro_status(oauth_token):
-            return gr.update(visible=True), gr.update(visible=False)
+    extend_video_button.click(
+        fn=lambda: gr.update(visible=True), outputs=[video_group]
+    ).then(
+        fn=extend_existing_video,
+        inputs=[image_input_gallery, prompt_input, output_image, previous_video_state],
+        outputs=[video_output, previous_video_state],
+    )
+
+    def control_access(profile: Optional[gr.OAuthProfile] = None, oauth_token: Optional[gr.OAuthToken] = None):
+        if not profile: return gr.update(visible=False), gr.update(visible=False)
+        if verify_pro_status(oauth_token): return gr.update(visible=True), gr.update(visible=False)
         else:
            message = (
                "## ✨ Exclusive Access for PRO Users\n\n"
@@ -237,9 +247,7 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
                 "### [**Become a PRO Today!**](http://huggingface.co/subscribe/pro?source=nana_banana)"
             )
             return gr.update(visible=False), gr.update(visible=True, value=message)
-
     demo.load(control_access, inputs=None, outputs=[main_interface, pro_message])
 
 if __name__ == "__main__":
-    demo.queue(max_size=None, default_concurrency_limit=None)
-    demo.launch()
+    demo.queue(max_size=None, default_concurrency_limit=None).launch()
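For context on the video-chaining helpers added above: the extend path generates a new segment, trims its first frame (which duplicates the previous segment's last frame), and concatenates the two clips without re-encoding. Below is a rough standalone sketch of that same ffmpeg-python idea, not the commit's code: the file names are placeholders, the Fraction-based framerate parsing is my own substitution, and it assumes the ffmpeg/ffprobe binaries and the ffmpeg-python package are available.

import os
import tempfile
from fractions import Fraction

import ffmpeg


def trim_first_frame(video_path: str) -> str:
    # Read the average frame rate (e.g. "30000/1001") and seek one frame into the input.
    probe = ffmpeg.probe(video_path)
    stream = next(s for s in probe["streams"] if s["codec_type"] == "video")
    framerate = float(Fraction(stream["avg_frame_rate"]))
    out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
    # Stream-copy (no re-encode); avoid_negative_ts keeps timestamps valid after the cut.
    ffmpeg.input(video_path, ss=1 / framerate).output(
        out_path, c="copy", avoid_negative_ts="make_zero"
    ).run(overwrite_output=True, quiet=True)
    return out_path


def concat_videos(first_path: str, second_path: str) -> str:
    # The concat demuxer reads a text file listing the inputs; both clips must share codecs.
    with tempfile.NamedTemporaryFile("w", delete=False, suffix=".txt") as list_file:
        list_file.write(f"file '{os.path.abspath(first_path)}'\n")
        list_file.write(f"file '{os.path.abspath(second_path)}'\n")
    out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
    ffmpeg.input(list_file.name, format="concat", safe=0).output(
        out_path, c="copy"
    ).run(overwrite_output=True, quiet=True)
    os.remove(list_file.name)
    return out_path


# Example with placeholder file names: extend an existing clip with a new segment,
# dropping the new segment's first frame since it repeats the previous last frame.
# combined = concat_videos("previous.mp4", trim_first_frame("new_segment.mp4"))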