1inkusFace committed on
Commit df19679 · verified · 1 Parent(s): 5fddbe1

Update app.py

Files changed (1)
  1. app.py +6 -21
app.py CHANGED
@@ -7,6 +7,12 @@ import random
 from PIL import Image
 import torch
 import asyncio # Import asyncio
+from skyreelsinfer import TaskType
+from skyreelsinfer.offload import OffloadConfig
+from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
+from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError
+from diffusers.utils import export_to_video
+from diffusers.utils import load_image
 
 # os.environ["CUDA_VISIBLE_DEVICES"] = "" # Uncomment if needed
 os.environ["SAFETENSORS_FAST_GPU"] = "1"
@@ -18,13 +24,7 @@ os.putenv("HF_HUB_ENABLE_HF_TRANSFER", "1")
 predictor_state = gr.State(None)
 device="cuda:0" if torch.cuda.is_available() else "cpu" # Pass device to the constructor
 
-
 def init_predictor(task_type: str):
-    from skyreelsinfer import TaskType
-    from skyreelsinfer.offload import OffloadConfig
-    from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
-    from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError
-
     try:
         predictor = SkyReelsVideoInfer(
             task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
@@ -45,22 +45,16 @@ def init_predictor(task_type: str):
         print(f"Error loading model: {e}")
         return None
 
-
 # Make generate_video async
 async def generate_video(prompt, image_file, predictor):
-    from diffusers.utils import export_to_video
-    from diffusers.utils import load_image
-
     if image_file is None:
         return gr.Error("Error: For i2v, provide an image.")
     if not isinstance(prompt, str) or not prompt.strip():
         return gr.Error("Error: Please provide a prompt.")
     if predictor is None:
         return gr.Error("Error: Model not loaded.")
-
     random.seed(time.time())
     seed = int(random.randrange(4294967294))
-
     kwargs = {
         "prompt": prompt,
         "height": 256,
@@ -73,7 +67,6 @@ async def generate_video(prompt, image_file, predictor):
         "negative_prompt": "bad quality, blur",
         "cfg_for": False,
     }
-
     try:
         # Load the image and move it to the correct device *before* inference
         image = load_image(image=image_file.name)
@@ -81,26 +74,21 @@ async def generate_video(prompt, image_file, predictor):
         kwargs["image"] = image
     except Exception as e:
         return gr.Error(f"Image loading error: {e}")
-
     try:
         output = predictor.inference(kwargs)
         frames = output
     except Exception as e:
         return gr.Error(f"Inference error: {e}"), None # Return None for predictor on error
-
     save_dir = "./result/i2v" # Consistent directory
     os.makedirs(save_dir, exist_ok=True)
     video_out_file = os.path.join(save_dir, f"{prompt[:100]}_{int(seed)}.mp4")
     print(f"Generating video: {video_out_file}")
-
     try:
         export_to_video(frames, video_out_file, fps=24)
     except Exception as e:
         return gr.Error(f"Video export error: {e}"), None # Return None for predictor
-
     return video_out_file, predictor # Return updated predictor
 
-
 def display_image(file):
     if file is not None:
         return Image.open(file.name)
@@ -118,20 +106,17 @@ async def main():
         prompt_textbox = gr.Text(label="Prompt")
         generate_button = gr.Button("Generate")
         output_video = gr.Video(label="Output Video")
-
         image_file.change(
             display_image,
             inputs=[image_file],
             outputs=[image_file_preview]
         )
-
         generate_button.click(
             fn=generate_video,
            inputs=[prompt_textbox, image_file, predictor_state],
            outputs=[output_video, predictor_state], # Output predictor_state
        )
     predictor_state.value = await load_model() # load and set predictor
-
     await demo.launch()
 
 if __name__ == "__main__":
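
The click handler above threads the loaded predictor through gr.State so it persists between calls. Below is a minimal, self-contained sketch of that state-passing pattern; the greet function and component names are made up for illustration and are not the app's load_model/generate_video code.

import gradio as gr

def greet(name, counter):
    # The state value arrives as a plain Python object and is sent back
    # out through `outputs`, just like predictor_state in the diff above.
    counter = (counter or 0) + 1
    return f"Hello {name} (call #{counter})", counter

with gr.Blocks() as demo:
    counter_state = gr.State(None)  # analogous to predictor_state = gr.State(None)
    name_box = gr.Textbox(label="Name")
    out_box = gr.Textbox(label="Greeting")
    btn = gr.Button("Greet")
    btn.click(fn=greet,
              inputs=[name_box, counter_state],
              outputs=[out_box, counter_state])

if __name__ == "__main__":
    demo.launch()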
 
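For reference, a minimal sketch of the two diffusers helpers this commit promotes to module-level imports, load_image and export_to_video, used outside the app. The file names and dummy frames are assumptions for illustration, and the expected frame format can vary between diffusers versions.

import numpy as np
from diffusers.utils import export_to_video, load_image

# load_image accepts a local path or URL and returns an RGB PIL.Image,
# matching load_image(image=image_file.name) in generate_video above.
image = load_image("input.png")  # hypothetical input file

# export_to_video writes a list of frames to an .mp4; here 24 black
# 256x256 frames (float arrays in [0, 1]) stand in for model output.
frames = [np.zeros((256, 256, 3), dtype=np.float32) for _ in range(24)]
export_to_video(frames, "out.mp4", fps=24)  # same fps the app passes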