multimodalart HF Staff committed on
Commit
aa0696d
·
verified ·
1 Parent(s): ae7dd23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -8
app.py CHANGED
@@ -50,6 +50,7 @@ def _resize_image(image_path: str, target_size: Tuple[int, int]) -> str:
50
  with Image.open(image_path) as img:
51
  if img.size == target_size:
52
  return image_path
 
53
  resized_img = img.resize(target_size, Image.Resampling.LANCZOS)
54
  suffix = os.path.splitext(image_path)[1] or ".png"
55
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
@@ -58,6 +59,7 @@ def _resize_image(image_path: str, target_size: Tuple[int, int]) -> str:
58
 
59
  def _trim_first_frame_fast(video_path: str) -> str:
60
  """Removes exactly the first frame of a video without re-encoding."""
 
61
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output_file:
62
  output_path = tmp_output_file.name
63
  try:
@@ -76,7 +78,7 @@ def _trim_first_frame_fast(video_path: str) -> str:
76
 
77
  def _combine_videos_simple(video1_path: str, video2_path: str) -> str:
78
  """Combines two videos using the fast concat demuxer."""
79
-
80
  with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix=".txt") as tmp_list_file:
81
  tmp_list_file.write(f"file '{os.path.abspath(video1_path)}'\n")
82
  tmp_list_file.write(f"file '{os.path.abspath(video2_path)}'\n")
@@ -99,6 +101,7 @@ def _combine_videos_simple(video1_path: str, video2_path: str) -> str:
99
 
100
  def _generate_video_segment(input_image_path: str, output_image_path: str, prompt: str, token: str) -> str:
101
  """Generates a single video segment using the external service."""
 
102
  video_client = Client("multimodalart/wan-2-2-first-last-frame", hf_token=token)
103
  result = video_client.predict(
104
  start_image_pil=handle_file(input_image_path),
@@ -107,7 +110,7 @@ def _generate_video_segment(input_image_path: str, output_image_path: str, promp
107
  )
108
  return result[0]["video"]
109
 
110
- def unified_image_generator(prompt: str, images: Optional[List[str]], previous_video_path: Optional[str], oauth_token: Optional[gr.OAuthToken]) -> tuple:
111
  if not verify_pro_status(oauth_token): raise gr.Error("Access Denied.")
112
  try:
113
  contents = [Image.open(image_path[0]) for image_path in images] if images else []
@@ -118,8 +121,14 @@ def unified_image_generator(prompt: str, images: Optional[List[str]], previous_v
118
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
119
  Image.open(BytesIO(image_data)).save(tmp.name)
120
  output_path = tmp.name
 
121
  can_create_video = bool(images and len(images) == 1)
122
- can_extend_video = can_create_video and bool(previous_video_path)
 
 
 
 
 
123
  return (output_path, gr.update(visible=can_create_video), gr.update(visible=can_extend_video), gr.update(visible=False))
124
  except Exception as e:
125
  raise gr.Error(f"Image generation failed: {e}")
@@ -129,7 +138,7 @@ def create_new_video(input_image_gallery: List[str], prompt_input: str, output_i
129
  if not input_image_gallery or not output_image: raise gr.Error("Input/output images required.")
130
  try:
131
  new_segment_path = _generate_video_segment(input_image_gallery[0][0], output_image, prompt_input, oauth_token.token)
132
- return new_segment_path, new_segment_path
133
  except Exception as e:
134
  raise gr.Error(f"Video creation failed: {e}")
135
 
@@ -144,7 +153,7 @@ def extend_existing_video(input_image_gallery: List[str], prompt_input: str, out
144
  new_segment_path = _generate_video_segment(resized_input_path, resized_output_path, prompt_input, oauth_token.token)
145
  trimmed_segment_path = _trim_first_frame_fast(new_segment_path)
146
  final_video_path = _combine_videos_simple(previous_video_path, trimmed_segment_path)
147
- return final_video_path, final_video_path
148
  except Exception as e:
149
  raise gr.Error(f"Video extension failed: {e}")
150
 
@@ -171,7 +180,9 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
171
  gr.HTML("<h3 style='text-align:center'>Hugging Face PRO users can use Google's Nano Banana (Gemini 2.5 Flash Image Preview) on this Space. <a href='http://huggingface.co/subscribe/pro?source=nana_banana' target='_blank'>Subscribe to PRO</a></h3>", elem_id="sub_title")
172
  pro_message = gr.Markdown(visible=False)
173
  main_interface = gr.Column(visible=False)
 
174
  previous_video_state = gr.State(None)
 
175
 
176
  with main_interface:
177
  with gr.Row():
@@ -195,7 +206,7 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
195
  gr.on(
196
  triggers=[generate_button.click, prompt_input.submit],
197
  fn=unified_image_generator,
198
- inputs=[prompt_input, image_input_gallery, previous_video_state],
199
  outputs=[output_image, create_video_button, extend_video_button, video_group]
200
  )
201
  use_image_button.click(
@@ -211,14 +222,14 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
211
  ).then(
212
  fn=create_new_video,
213
  inputs=[image_input_gallery, prompt_input, output_image],
214
- outputs=[video_output, previous_video_state],
215
  )
216
  extend_video_button.click(
217
  fn=lambda: gr.update(visible=True), outputs=[video_group]
218
  ).then(
219
  fn=extend_existing_video,
220
  inputs=[image_input_gallery, prompt_input, output_image, previous_video_state],
221
- outputs=[video_output, previous_video_state],
222
  )
223
 
224
  def control_access(profile: Optional[gr.OAuthProfile] = None, oauth_token: Optional[gr.OAuthToken] = None):
 
50
  with Image.open(image_path) as img:
51
  if img.size == target_size:
52
  return image_path
53
+ gr.Info(f"Resizing image to {target_size[0]}x{target_size[1]} to match previous video.")
54
  resized_img = img.resize(target_size, Image.Resampling.LANCZOS)
55
  suffix = os.path.splitext(image_path)[1] or ".png"
56
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
 
59
 
60
  def _trim_first_frame_fast(video_path: str) -> str:
61
  """Removes exactly the first frame of a video without re-encoding."""
62
+ gr.Info("Preparing video segment...")
63
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_output_file:
64
  output_path = tmp_output_file.name
65
  try:
 
78
 
79
  def _combine_videos_simple(video1_path: str, video2_path: str) -> str:
80
  """Combines two videos using the fast concat demuxer."""
81
+ gr.Info("Stitching videos...")
82
  with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix=".txt") as tmp_list_file:
83
  tmp_list_file.write(f"file '{os.path.abspath(video1_path)}'\n")
84
  tmp_list_file.write(f"file '{os.path.abspath(video2_path)}'\n")
 
101
 
102
  def _generate_video_segment(input_image_path: str, output_image_path: str, prompt: str, token: str) -> str:
103
  """Generates a single video segment using the external service."""
104
+ gr.Info("Generating new video segment...")
105
  video_client = Client("multimodalart/wan-2-2-first-last-frame", hf_token=token)
106
  result = video_client.predict(
107
  start_image_pil=handle_file(input_image_path),
 
110
  )
111
  return result[0]["video"]
112
 
113
+ def unified_image_generator(prompt: str, images: Optional[List[str]], previous_video_path: Optional[str], last_frame_path: Optional[str], oauth_token: Optional[gr.OAuthToken]) -> tuple:
114
  if not verify_pro_status(oauth_token): raise gr.Error("Access Denied.")
115
  try:
116
  contents = [Image.open(image_path[0]) for image_path in images] if images else []
 
121
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
122
  Image.open(BytesIO(image_data)).save(tmp.name)
123
  output_path = tmp.name
124
+
125
  can_create_video = bool(images and len(images) == 1)
126
+ can_extend_video = False
127
+ if can_create_video and previous_video_path and last_frame_path:
128
+ # The crucial check for continuity
129
+ if images[0][0] == last_frame_path:
130
+ can_extend_video = True
131
+
132
  return (output_path, gr.update(visible=can_create_video), gr.update(visible=can_extend_video), gr.update(visible=False))
133
  except Exception as e:
134
  raise gr.Error(f"Image generation failed: {e}")
 
138
  if not input_image_gallery or not output_image: raise gr.Error("Input/output images required.")
139
  try:
140
  new_segment_path = _generate_video_segment(input_image_gallery[0][0], output_image, prompt_input, oauth_token.token)
141
+ return new_segment_path, new_segment_path, output_image
142
  except Exception as e:
143
  raise gr.Error(f"Video creation failed: {e}")
144
 
 
153
  new_segment_path = _generate_video_segment(resized_input_path, resized_output_path, prompt_input, oauth_token.token)
154
  trimmed_segment_path = _trim_first_frame_fast(new_segment_path)
155
  final_video_path = _combine_videos_simple(previous_video_path, trimmed_segment_path)
156
+ return final_video_path, final_video_path, output_image
157
  except Exception as e:
158
  raise gr.Error(f"Video extension failed: {e}")
159
 
 
180
  gr.HTML("<h3 style='text-align:center'>Hugging Face PRO users can use Google's Nano Banana (Gemini 2.5 Flash Image Preview) on this Space. <a href='http://huggingface.co/subscribe/pro?source=nana_banana' target='_blank'>Subscribe to PRO</a></h3>", elem_id="sub_title")
181
  pro_message = gr.Markdown(visible=False)
182
  main_interface = gr.Column(visible=False)
183
+
184
  previous_video_state = gr.State(None)
185
+ last_frame_of_video_state = gr.State(None)
186
 
187
  with main_interface:
188
  with gr.Row():
 
206
  gr.on(
207
  triggers=[generate_button.click, prompt_input.submit],
208
  fn=unified_image_generator,
209
+ inputs=[prompt_input, image_input_gallery, previous_video_state, last_frame_of_video_state],
210
  outputs=[output_image, create_video_button, extend_video_button, video_group]
211
  )
212
  use_image_button.click(
 
222
  ).then(
223
  fn=create_new_video,
224
  inputs=[image_input_gallery, prompt_input, output_image],
225
+ outputs=[video_output, previous_video_state, last_frame_of_video_state],
226
  )
227
  extend_video_button.click(
228
  fn=lambda: gr.update(visible=True), outputs=[video_group]
229
  ).then(
230
  fn=extend_existing_video,
231
  inputs=[image_input_gallery, prompt_input, output_image, previous_video_state],
232
+ outputs=[video_output, previous_video_state, last_frame_of_video_state],
233
  )
234
 
235
  def control_access(profile: Optional[gr.OAuthProfile] = None, oauth_token: Optional[gr.OAuthToken] = None):