import os
import shutil
import multiprocessing
import subprocess
import nltk
import gradio as gr
import matplotlib.pyplot as plt
import gc
from huggingface_hub import snapshot_download, hf_hub_download
from typing import List
import numpy as np
import random
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, CLIPFeatureExtractor
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from moviepy.editor import VideoFileClip, CompositeVideoClip, TextClip
import moviepy.editor as mpy
from PIL import Image, ImageDraw, ImageFont
from mutagen.mp3 import MP3
from gtts import gTTS
from pydub import AudioSegment
import uuid
from safetensors.torch import load_file
import textwrap
# -------------------------------------------------------------------
# No more ImageMagick dependency!
# -------------------------------------------------------------------
print("ImageMagick dependency removed. Using Pillow for text rendering.")

# Ensure NLTK's 'punkt_tab' (and other data) is present
nltk.download('punkt_tab', quiet=True)
nltk.download('punkt', quiet=True)
# -------------------------------------------------------------------
# GPU / Environment Setup
# -------------------------------------------------------------------
def log_gpu_memory():
    """Log GPU memory usage."""
    if torch.cuda.is_available():
        print(subprocess.check_output('nvidia-smi').decode('utf-8'))
    else:
        print("CUDA is not available. Cannot log GPU memory.")


def check_gpu_availability():
    """Print GPU availability and device details."""
    if torch.cuda.is_available():
        print(f"CUDA devices: {torch.cuda.device_count()}")
        print(f"Current device: {torch.cuda.current_device()}")
        print(torch.cuda.get_device_properties(torch.cuda.current_device()))
    else:
        print("CUDA is not available. Running on CPU.")


check_gpu_availability()

# Ensure the proper multiprocessing start method
multiprocessing.set_start_method("spawn", force=True)
# -------------------------------------------------------------------
# Constants & Model Setup
# -------------------------------------------------------------------
dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE_720 = 720  # Maximum image dimension; output is capped at 720p
MAX_IMAGE_SIZE = MAX_IMAGE_SIZE_720

# Resolutions above 720p are commented out because generation is capped at 720p.
RESOLUTIONS = {
    "16:9": [
        {"resolution": "360p", "width": 640, "height": 360},
        {"resolution": "480p", "width": 854, "height": 480},
        {"resolution": "720p", "width": 1280, "height": 720},
        # {"resolution": "1080p", "width": 1920, "height": 1080}
    ],
    "4:3": [
        {"resolution": "360p", "width": 480, "height": 360},
        {"resolution": "480p", "width": 640, "height": 480},
        {"resolution": "720p", "width": 960, "height": 720},
        # {"resolution": "1080p", "width": 1440, "height": 1080}
    ],
    "1:1": [
        {"resolution": "360p", "width": 360, "height": 360},
        {"resolution": "480p", "width": 480, "height": 480},
        {"resolution": "720p", "width": 720, "height": 720},
        # {"resolution": "1080p", "width": 1080, "height": 1080},
        # {"resolution": "1920p", "width": 1920, "height": 1920}
    ],
    "9:16": [
        {"resolution": "360p", "width": 360, "height": 640},
        {"resolution": "480p", "width": 480, "height": 854},
        {"resolution": "720p", "width": 720, "height": 1280},
        # {"resolution": "1080p", "width": 1080, "height": 1920}
    ],
}
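# NOTE: RESOLUTIONS is not referenced elsewhere in this script. The helper below
# is only an illustrative sketch of how a (width, height) pair could be looked up
# from it (e.g. to validate the Gradio sliders); `get_resolution` is not part of
# the original pipeline and is never called.
def get_resolution(aspect_ratio: str, resolution: str):
    """Return (width, height) for a given aspect ratio and resolution label, or None."""
    for entry in RESOLUTIONS.get(aspect_ratio, []):
        if entry["resolution"] == resolution:
            return entry["width"], entry["height"]
    return None
# Example: get_resolution("16:9", "720p") -> (1280, 720)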
DESCRIPTION = (
    "Video Story Generator with Audio\n"
    "PS: The video is generated with Artificial Intelligence using AnimateDiff, DistilBART, and gTTS."
)
TITLE = "Video Story Generator with Audio (AnimateDiff, DistilBART, and gTTS)"
def load_text_summarization_model():
    """Load the tokenizer and model for text summarization (device placement happens later)."""
    print("Loading text summarization model...")
    tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
    model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6")
    return tokenizer, model


tokenizer, model = load_text_summarization_model()
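# Illustrative quick check of the summarizer (commented out so nothing extra runs
# at startup); get_output_video() below performs this same summarization flow:
#   sample = tokenizer("A long story to condense ...", max_length=1024,
#                      truncation=True, return_tensors="pt")
#   summary_ids = model.generate(sample["input_ids"])
#   print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])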
# Base models for AnimateDiff-Lightning
bases = {
    "Cartoon": "frankjoshua/toonyou_beta6",
    "Realistic": "emilianJR/epiCRealism",
    "3d": "Lykon/DreamShaper",
    "Anime": "Yntec/mistoonAnime2"
}

# Keep track of what is loaded to avoid reloading on every call
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None

# Initialize the AnimateDiff pipeline
if not torch.cuda.is_available():
    raise NotImplementedError("No GPU detected!")

pipe = AnimateDiffPipeline.from_pretrained(
    bases[base_loaded],
    torch_dtype=dtype
).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
    beta_schedule="linear"
)
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
# -------------------------------------------------------------------
# Function: Generate Short Animation
# -------------------------------------------------------------------
def generate_short_animation(
    prompt_text: str,
    base: str = "Realistic",
    motion: str = "",
    step: int = 4,
    seed: int = 42,
    width: int = 512,
    height: int = 512,
) -> str:
    """
    Generate a short animated video (MP4) from a prompt using AnimateDiff-Lightning.
    Returns the local path to the resulting MP4.
    """
    global step_loaded
    global base_loaded
    global motion_loaded

    # 1) Reload the step-distilled weights if the step count changed
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(
            load_file(hf_hub_download(repo, ckpt), device=device),
            strict=False
        )
        step_loaded = step

    # 2) Reload the base model if it changed
    if base_loaded != base:
        pipe.unet.load_state_dict(
            torch.load(
                hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"),
                map_location=device
            ),
            strict=False
        )
        base_loaded = base

    # 3) Unload/load the motion LoRA if it changed
    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion:
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])  # LoRA weighting can be adjusted
        motion_loaded = motion

    # 4) Generate frames
    print(f"[INFO] Generating short animation for prompt: '{prompt_text}' ...")
    generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
    output = pipe(
        prompt=prompt_text,
        guidance_scale=1.2,
        num_inference_steps=step,
        generator=generator,
        width=width,
        height=height
    )

    # 5) Export the frames to a short MP4
    short_mp4_path = f"short_{uuid.uuid4().hex}.mp4"
    export_to_video(output.frames[0], short_mp4_path, fps=10)
    return short_mp4_path
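# Illustrative call (commented out; in this app the function is only invoked from
# get_output_video below). The prompt here is made up for the example:
#   clip_path = generate_short_animation(
#       prompt_text="A girl walking through a supermarket, cinematic lighting",
#       base="Realistic", motion="", step=4, seed=42, width=640, height=480)
#   print(clip_path)  # e.g. short_<uuid>.mp4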
# -------------------------------------------------------------------
# Function: Merge MP3 files
# -------------------------------------------------------------------
def merge_audio_files(mp3_names: List[str]) -> str:
    """
    Merges a list of MP3 files into a single MP3 file.
    Returns the path to the merged MP3 file.
    """
    combined = AudioSegment.empty()
    for f_name in mp3_names:
        audio = AudioSegment.from_mp3(f_name)
        combined += audio
    export_path = f"merged_audio_{uuid.uuid4().hex}.mp3"  # Dynamic output path for the merged audio
    combined.export(export_path, format="mp3")
    print(f"DEBUG: Audio files merged and saved to {export_path}")
    return export_path
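# Illustrative usage (commented out; the file names are placeholders, not files
# shipped with this Space):
#   merged = merge_audio_files(["sentence_1.mp3", "sentence_2.mp3"])
#   print(MP3(merged).info.length)  # total duration in seconds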
# -------------------------------------------------------------------
# Function: Overlay Subtitles on a Video
# -------------------------------------------------------------------
def add_subtitles_to_video(input_video_path: str, text: str, duration: float) -> str:
    """
    Overlays `text` as subtitles over the entire `input_video_path` for `duration` seconds using Pillow.
    Returns the path to the newly generated MP4 with subtitles.
    """
    base_clip = VideoFileClip(input_video_path)
    final_dur = max(duration, base_clip.duration)

    def make_frame(t):
        frame_pil = Image.fromarray(base_clip.get_frame(t))
        draw = ImageDraw.Draw(frame_pil)
        try:
            font = ImageFont.truetype("arial.ttf", 40)  # Change the font size if needed
        except IOError:
            font = ImageFont.load_default()  # Use the default font if Arial is not found

        # Compute the text size using `textbbox()`
        bbox = draw.textbbox((0, 0), text, font=font)
        textwidth, textheight = bbox[2] - bbox[0], bbox[3] - bbox[1]
        x = (frame_pil.width - textwidth) / 2
        y = frame_pil.height - 70 - textheight  # Position near the bottom
        draw.text((x, y), text, font=font, fill=(255, 255, 0))  # Yellow color
        return np.array(frame_pil)

    # Build the subtitled clip frame by frame (no `size` argument needed)
    subtitled_clip = mpy.VideoClip(make_frame, duration=final_dur)

    # Composite the subtitled clip over the original video
    final_clip = CompositeVideoClip([base_clip, subtitled_clip.set_position((0, 0))])
    final_clip = final_clip.set_duration(final_dur)

    out_path = f"sub_{uuid.uuid4().hex}.mp4"
    final_clip.write_videofile(out_path, fps=24, logger=None)

    # Cleanup
    base_clip.close()
    final_clip.close()
    subtitled_clip.close()
    return out_path
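# Illustrative usage (commented out; "short_demo.mp4" is a placeholder path):
#   subtitled = add_subtitles_to_video("short_demo.mp4",
#                                      "Once, there was a girl called Laura.",
#                                      duration=3.5)
#   print(subtitled)  # sub_<uuid>.mp4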
# -------------------------------------------------------------------
# Main Function: Generate Output Video
# -------------------------------------------------------------------
def get_output_video(text, base_model_name, motion_name, num_inference_steps_backend, randomize_seed, seed, width, height):
    """
    Summarize the user prompt, generate a short animated video for each sentence,
    overlay subtitles, and merge everything into a final video with a single audio track.
    """
    print("DEBUG: Starting get_output_video function...")

    # Summarize the input text
    print("DEBUG: Summarizing text...")
    device_local = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device_local)  # Move the summarization model to GPU/CPU as needed
    inputs = tokenizer(
        text,
        max_length=1024,
        truncation=True,
        return_tensors="pt"
    ).to(device_local)
    summary_ids = model.generate(inputs["input_ids"])
    summary = tokenizer.batch_decode(
        summary_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )
    plot = list(summary[0].split('.'))  # Split the summary into sentences
    print(f"DEBUG: Summary generated: {plot}")

    # Prepare the seed based on the randomize_seed checkbox
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else seed

    # Generate a short video and an audio track for each sentence
    short_videos = []
    mp3_names = []
    mp3_lengths = []
    result_no_audio = f"result_no_audio_{uuid.uuid4().hex}.mp4"  # Dynamic filename for the silent video
    movie_final = f'result_final_{uuid.uuid4().hex}.mp4'  # Dynamic filename for the final video
    merged_audio_path = ""  # Stores the merged audio path for cleanup

    try:  # try/finally block to ensure cleanup
        for i, sentence in enumerate(plot[:-1]):
            # 1) Generate a short video for this sentence
            prompt_for_animation = f"Generate a realistic video about this: {sentence}"
            print(f"DEBUG: Generating short video {i+1} of {len(plot)-1} ...")
            short_mp4_path = generate_short_animation(
                prompt_text=prompt_for_animation,
                base=base_model_name,
                motion=motion_name,
                step=int(num_inference_steps_backend),
                seed=current_seed + i,  # Increment the seed per sentence for variation
                width=width,
                height=height
            )

            # 2) Generate audio for the sentence
            audio_filename = f'audio_{uuid.uuid4().hex}_{i}.mp3'  # Dynamic audio filename
            tts_obj = gTTS(text=sentence, lang='en', slow=False)
            tts_obj.save(audio_filename)
            audio_info = MP3(audio_filename)
            audio_duration = audio_info.info.length
            mp3_names.append(audio_filename)
            mp3_lengths.append(audio_duration)

            # 3) Overlay subtitles on top of the short video (using Pillow)
            final_clip_duration = audio_duration + 0.5  # half-second pad
            short_subtitled_path = add_subtitles_to_video(
                input_video_path=short_mp4_path,
                text=sentence.strip(),
                duration=final_clip_duration
            )
            short_videos.append(short_subtitled_path)

            # Clean up the original short clip (without subtitles)
            os.remove(short_mp4_path)

        # ----------------------------------------------------------------
        # Merge all MP3 files into one
        # ----------------------------------------------------------------
        merged_audio_path = merge_audio_files(mp3_names)

        # ----------------------------------------------------------------
        # Concatenate all short subtitled videos
        # ----------------------------------------------------------------
        print("DEBUG: Concatenating all short videos into a single clip...")
        clip_objects = []
        for vid_path in short_videos:
            clip = mpy.VideoFileClip(vid_path)
            clip_objects.append(clip)
        final_concat = mpy.concatenate_videoclips(clip_objects, method="compose")
        final_concat.write_videofile(result_no_audio, fps=24, logger=None)

        # ----------------------------------------------------------------
        # Combine the concatenated video with the merged audio
        # ----------------------------------------------------------------
        def combine_audio(vidname, audname, outname, fps=24):
            print(f"DEBUG: Combining audio for video: '{vidname}'")
            my_clip = mpy.VideoFileClip(vidname)
            audio_background = mpy.AudioFileClip(audname)
            final_clip = my_clip.set_audio(audio_background)
            final_clip.write_videofile(outname, fps=fps, logger=None)
            my_clip.close()
            final_clip.close()

        combine_audio(result_no_audio, merged_audio_path, movie_final)
    finally:  # Cleanup always executes
        print("DEBUG: Cleaning up temporary files...")
        # Remove the short subtitled videos
        for path_ in short_videos:
            os.remove(path_)
        # Remove the MP3 segments
        for f_mp3 in mp3_names:
            os.remove(f_mp3)
        # Remove the merged audio
        if os.path.exists(merged_audio_path):
            os.remove(merged_audio_path)
        # Remove the intermediate no-audio MP4
        if os.path.exists(result_no_audio):
            os.remove(result_no_audio)

    print("DEBUG: get_output_video function completed successfully.")
    return movie_final
# -------------------------------------------------------------------
# Example text (the user can override it)
# -------------------------------------------------------------------
text = (
    "Once, there was a girl called Laura who went to the supermarket to buy the ingredients to make a cake. "
    "Because today is her birthday, her friends come to her house and help her prepare the cake."
)
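# Illustrative programmatic call, bypassing the Gradio UI (commented out because
# it would trigger a full generation run; the argument values mirror the UI defaults):
#   final_path = get_output_video(
#       text, base_model_name="Realistic", motion_name="",
#       num_inference_steps_backend=4, randomize_seed=False, seed=42,
#       width=640, height=480)
#   print(final_path)  # result_final_<uuid>.mp4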
# -------------------------------------------------------------------
# Gradio Interface
# -------------------------------------------------------------------
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(
        """
        # Video Generator ⚡ from stories with Artificial Intelligence

        The user provides a story, which is summarized with the DistilBART model.
        Frames are then generated with AnimateDiff and AnimateDiff-Lightning,
        subtitles are rendered with Pillow, and the narration is created with gTTS.
        These are combined into the final video.

        **Credits**: Developed by [ruslanmv.com](https://ruslanmv.com).
        """
    )
    with gr.Group():
        with gr.Row():
            input_start_text = gr.Textbox(value=text, label='Prompt')
        with gr.Row():
            select_base = gr.Dropdown(
                label='Base model',
                choices=["Cartoon", "Realistic", "3d", "Anime"],
                value=base_loaded,
                interactive=True
            )
            select_motion = gr.Dropdown(
                label='Motion',
                choices=[
                    ("Default", ""),
                    ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
                    ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
                    ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
                    ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
                    ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
                    ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
                    ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                    ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                ],
                value="",  # default: no motion LoRA
                interactive=True
            )
            select_step = gr.Dropdown(
                label='Inference steps',
                choices=[('1-Step', 1), ('2-Step', 2), ('4-Step', 4), ('8-Step', 8)],
                value=4,
                interactive=True
            )
            button_gen_video = gr.Button(
                scale=1,
                variant='primary',
                value="Generate Video"
            )
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=42,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE_720,  # Limited to a maximum of 720 pixels (720p)
                step=1,
                value=640,  # Default width for 480p 4:3
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE_720,  # Limited to a maximum of 720 pixels (720p)
                step=1,
                value=480,  # Default height for 480p 4:3
            )
    with gr.Column():
        # output_interpolation = gr.Video(label="Generated Video")
        output_interpolation = gr.Video(value="video.mp4", label="Generated Video")  # Set a default video

    button_gen_video.click(
        fn=get_output_video,
        inputs=[input_start_text, select_base, select_motion, select_step, randomize_seed, seed, width, height],
        outputs=output_interpolation
    )

demo.queue().launch(debug=True, share=False)