hysts (HF staff) committed
Commit 2d41ed4
1 Parent(s): 8a407e4

flash-attn

Files changed (1)
  1. app.py +13 -12
app.py CHANGED
@@ -1,7 +1,8 @@
+import os
+import shlex
 import subprocess
 
-subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-               shell=True)
+subprocess.run(shlex.split('pip install flash-attn --no-build-isolation'), env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"})
 
 import gradio as gr
 import torch
@@ -42,7 +43,7 @@ model.vae.enable_tiling()
 
 def generate_video(prompt, height, width, duration, guidance_scale, video_guidance_scale):
     temp = 16 if duration == "5s" else 31
-
+
     with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
         frames = model.generate(
             prompt=prompt,
@@ -55,7 +56,7 @@ def generate_video(prompt, height, width, duration, guidance_scale, video_guidan
             video_guidance_scale=video_guidance_scale,
             output_type="pil",
         )
-
+
     output_path = "generated_video.mp4"
     export_to_video(frames, output_path, fps=24)
     return output_path
@@ -70,7 +71,7 @@ def generate_video_from_image(image, prompt, video_guidance_scale):
             video_guidance_scale=video_guidance_scale,
             output_type="pil",
         )
-
+
     output_path = "generated_video_from_image.mp4"
     export_to_video(frames, output_path, fps=24)
     return output_path
@@ -78,7 +79,7 @@ def generate_video_from_image(image, prompt, video_guidance_scale):
 # Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Pyramid Flow Video Generation Demo")
-
+
     with gr.Tab("Text-to-Video"):
         with gr.Row():
             with gr.Column():
@@ -91,7 +92,7 @@ with gr.Blocks() as demo:
                 txt_generate = gr.Button("Generate Video")
             with gr.Column():
                 txt_output = gr.Video(label="Generated Video")
-
+
     with gr.Tab("Image-to-Video"):
         with gr.Row():
            with gr.Column():
@@ -101,13 +102,13 @@ with gr.Blocks() as demo:
                 img_generate = gr.Button("Generate Video")
             with gr.Column():
                 img_output = gr.Video(label="Generated Video")
-
-    txt_generate.click(generate_video,
-                       inputs=[txt_prompt, txt_height, txt_width, txt_duration, txt_guidance_scale, txt_video_guidance_scale],
+
+    txt_generate.click(generate_video,
+                       inputs=[txt_prompt, txt_height, txt_width, txt_duration, txt_guidance_scale, txt_video_guidance_scale],
                        outputs=txt_output)
-
+
     img_generate.click(generate_video_from_image,
                        inputs=[img_input, img_prompt, img_video_guidance_scale],
                        outputs=img_output)
 
-demo.launch()
+demo.launch()
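
For reference, subprocess.run() with an env= argument gives the child process only that mapping, so the new call merges os.environ with the extra flag to keep PATH and the rest of the environment available to pip, and shlex.split() turns the command string into an argv list so shell=True is no longer needed. A minimal standalone sketch of the same pattern (check=True is added here only so a failed install raises; the dict union on os.environ needs Python 3.9+):

import os
import shlex
import subprocess

# Split the command string into an argv list; no shell is involved.
cmd = shlex.split('pip install flash-attn --no-build-isolation')

# Merge the current environment with the extra variable (Python 3.9+ dict union).
# Passing only the extra variable would drop PATH and break the pip invocation.
env = os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'}

subprocess.run(cmd, env=env, check=True)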