|
from dataclasses import dataclass |
|
from pathlib import Path |
|
import pathlib |
|
from typing import Dict, Any, Optional, Tuple |
|
import asyncio |
|
import base64 |
|
import io |
|
import logging |
|
import random |
|
import traceback |
|
import os |
|
import numpy as np |
|
import torch |
|
from diffusers import LTXPipeline, LTXImageToVideoPipeline |
|
from PIL import Image |
|
|
|
from varnish import Varnish |
|
|
|
|
|
# Module-wide logging: INFO and above, emitted through this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Upper bounds applied by GenerationConfig.validate_and_adjust: output
# resolution is clamped to MAX_WIDTH x MAX_HEIGHT, and the frame count
# (always of the form 8k + 1 there) is capped at MAX_FRAMES (257 == 8 * 32 + 1).
MAX_WIDTH, MAX_HEIGHT, MAX_FRAMES = 1280, 720, 257
|
|
|
|
|
def apply_dirty_hack_to_patch_file_extensions_and_bypass_filter(directory):
    """Recursively rename every ``*.wut`` file under *directory* to ``*.pth``.

    Only the trailing ``.wut`` extension is swapped; any other occurrence of
    ``.wut`` earlier in a file name is left untouched. A rename that fails is
    reported on stdout and skipped, so one bad file does not abort the walk.

    Args:
        directory (str): Path to the directory to process. It is made absolute
            before walking; a path that does not exist yields no work.
    """
    old_ext, new_ext = ".wut", ".pth"
    directory = os.path.abspath(directory)

    for root, _, files in os.walk(directory):
        for filename in files:
            if not filename.endswith(old_ext):
                continue

            old_path = os.path.join(root, filename)
            # Swap only the suffix: str.replace would also rewrite any '.wut'
            # appearing earlier in the name (e.g. 'a.wut.b.wut').
            new_filename = filename[: -len(old_ext)] + new_ext
            new_path = os.path.join(root, new_filename)

            try:
                os.rename(old_path, new_path)
                print(f"Renamed: {old_path} -> {new_path}")
            except OSError as e:
                print(f"Error renaming {old_path}: {e}")
|
|
|
def print_directory_structure(startpath):
    """Log an indented tree of the directories and files under *startpath*.

    Each directory is logged as ``name/`` and each file as ``name``, indented
    four spaces per nesting level relative to *startpath*. Directories and
    files are visited in sorted order so the output is deterministic.
    """
    for root, dirs, files in os.walk(startpath):
        # Sorting ``dirs`` in place makes os.walk descend in sorted order.
        dirs.sort()

        # Depth relative to startpath. relpath (rather than str.replace) stays
        # correct when startpath's text recurs deeper in the tree or when
        # startpath ends with a path separator.
        rel = os.path.relpath(root, startpath)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1

        indent = " " * 4 * level
        logger.info("%s%s/", indent, os.path.basename(root))

        subindent = " " * 4 * (level + 1)
        for f in sorted(files):
            logger.info("%s%s", subindent, f)
|
|
|
# Runs at import time so the files are renamed before EndpointHandler loads
# anything from /repository. Going by the helper's name, the '.wut' extension
# presumably exists to get the weights past an upload filter.
_REPOSITORY_DIR = "/repository"

logger.info('💡 Applying a dirty hack (patch "%s" to fix file extensions):', _REPOSITORY_DIR)
apply_dirty_hack_to_patch_file_extensions_and_bypass_filter(_REPOSITORY_DIR)

logger.info('💡 Printing directory structure of "%s":', _REPOSITORY_DIR)
print_directory_structure(_REPOSITORY_DIR)
|
|
|
@dataclass
class GenerationConfig:
    """Configuration for a single video-generation request.

    Call ``validate_and_adjust`` before use: it snaps resolution and duration
    onto the values this handler accepts and resolves a random seed.
    """

    width: int = 768
    height: int = 512
    fps: int = 24
    duration_sec: float = 4.0
    num_inference_steps: int = 30
    guidance_scale: float = 7.5
    upscale_factor: float = 2.0
    enable_interpolation: bool = False
    seed: int = -1  # -1 => pick a random seed in validate_and_adjust

    @property
    def num_frames(self) -> int:
        """Number of frames for ``duration_sec`` at ``fps``, plus one.

        The product ``duration_sec * fps`` is rounded to 6 decimals before
        truncation. Without this, a duration set by ``validate_and_adjust`` as
        ``(n - 1) / fps`` can map back to ``n - 1`` frames instead of ``n``
        through float error (e.g. ``8 / 49 * 49 == 7.999999999999999``).
        """
        return int(round(self.duration_sec * self.fps, 6)) + 1

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Clamp and snap the parameters in place.

        - width/height: rounded to the nearest multiple of 32, then clamped
          to ``[32, MAX_WIDTH]`` / ``[32, MAX_HEIGHT]``.
        - duration: rounded down so ``num_frames`` has the form ``8k + 1``,
          with ``1 <= num_frames <= MAX_FRAMES``.
        - seed: ``-1`` is replaced by a random value in ``[0, 2**32 - 1]``.

        Returns:
            ``self``, so the call can be chained after construction.

        Raises:
            ValueError: if ``fps`` is not positive (no duration can be derived
                from a frame count).
        """
        if self.fps <= 0:
            raise ValueError(f"fps must be positive, got {self.fps}")

        self.width = max(32, min(MAX_WIDTH, round(self.width / 32) * 32))
        self.height = max(32, min(MAX_HEIGHT, round(self.height / 32) * 32))

        # Frame count must be 8k + 1. k is clamped at 0 so a zero or negative
        # duration yields a single frame instead of a negative count.
        k = max(0, (self.num_frames - 1) // 8)
        num_frames = min(k * 8 + 1, MAX_FRAMES)
        self.duration_sec = (num_frames - 1) / self.fps

        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self
|
|
|
class EndpointHandler:
    """Handles video generation requests using LTX models and Varnish post-processing."""

    def __init__(self, model_path: str = ""):
        """Initialize the handler with LTX models and Varnish.

        Args:
            model_path: Path to LTX model weights.
        """
        self.text_to_video = LTXPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
        ).to("cuda")

        # Build the image-to-video pipeline from the modules already loaded for
        # text-to-video instead of loading the same checkpoint a second time,
        # which would keep two copies of the weights on the GPU.
        self.image_to_video = LTXImageToVideoPipeline.from_pipe(self.text_to_video)

        self.varnish = Varnish(
            device="cuda" if torch.cuda.is_available() else "cpu",
            output_format="mp4",
            output_codec="h264",
            output_quality=23,
            enable_mmaudio=False,
            model_base_dir="/repository/varnish",
        )

    async def process_frames(
        self,
        frames: torch.Tensor,
        config: GenerationConfig
    ) -> tuple[str, dict]:
        """Post-process generated frames using Varnish.

        Upscaling is requested only when ``config.upscale_factor > 1``; the
        output frame rate is kept equal to the input frame rate.

        Args:
            frames: Generated video frames tensor.
            config: Generation configuration.

        Returns:
            Tuple of (video data URI, metadata dictionary).
        """
        result = await self.varnish(
            input_data=frames,
            input_fps=config.fps,
            upscale_factor=config.upscale_factor if config.upscale_factor > 1 else None,
            enable_interpolation=config.enable_interpolation,
            output_fps=config.fps
        )

        video_uri = await result.write(
            output_type="data-uri",
            output_format="mp4",
            output_codec="h264",
            output_quality=23
        )

        # Dimensions/timing come from Varnish's output (post-upscale and
        # post-interpolation); the rest echoes the request configuration.
        metadata = {
            "width": result.metadata.width,
            "height": result.metadata.height,
            "num_frames": result.metadata.frame_count,
            "fps": result.metadata.fps,
            "duration": result.metadata.duration,
            "num_inference_steps": config.num_inference_steps,
            "seed": config.seed,
            "upscale_factor": config.upscale_factor,
            "interpolation_enabled": config.enable_interpolation
        }

        return video_uri, metadata

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process incoming requests for video generation.

        Args:
            data: Request data containing:
                - inputs (str): Text prompt (required)
                - image (optional): Base64-encoded image, raw or as a
                  ``data:...;base64,`` URI; when present, image-to-video is used
                - width (optional): Video width
                - height (optional): Video height
                - fps (optional): Frames per second
                - duration_sec (optional): Video duration
                - num_inference_steps (optional): Inference steps
                - guidance_scale (optional): Guidance scale
                - upscale_factor (optional): Upscaling factor
                - enable_interpolation (optional): Enable frame interpolation
                - seed (optional): Random seed

        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type ("video/mp4")
                - metadata: Generation metadata

        Raises:
            ValueError: if ``inputs`` is missing or empty. Errors raised while
                validating the configuration also propagate unwrapped.
            RuntimeError: if generation or post-processing fails; the original
                exception is chained as ``__cause__``.
        """
        prompt = data.get("inputs")
        if not prompt:
            raise ValueError("No prompt provided in the 'inputs' field")

        # Missing keys fall back to the dataclass defaults.
        config = GenerationConfig(
            width=data.get("width", GenerationConfig.width),
            height=data.get("height", GenerationConfig.height),
            fps=data.get("fps", GenerationConfig.fps),
            duration_sec=data.get("duration_sec", GenerationConfig.duration_sec),
            num_inference_steps=data.get("num_inference_steps", GenerationConfig.num_inference_steps),
            guidance_scale=data.get("guidance_scale", GenerationConfig.guidance_scale),
            upscale_factor=data.get("upscale_factor", GenerationConfig.upscale_factor),
            enable_interpolation=data.get("enable_interpolation", GenerationConfig.enable_interpolation),
            seed=data.get("seed", GenerationConfig.seed)
        ).validate_and_adjust()

        try:
            with torch.no_grad():
                # Seed Python, NumPy and torch RNGs from the resolved seed.
                random.seed(config.seed)
                np.random.seed(config.seed)
                generator = torch.manual_seed(config.seed)

                generation_kwargs = {
                    "prompt": prompt,
                    "height": config.height,
                    "width": config.width,
                    "num_frames": config.num_frames,
                    "guidance_scale": config.guidance_scale,
                    "num_inference_steps": config.num_inference_steps,
                    "output_type": "pt",
                    "generator": generator
                }

                image_data = data.get("image")
                if image_data:
                    # Strip a "data:<mime>;base64," prefix if present.
                    if image_data.startswith('data:'):
                        image_data = image_data.split(',', 1)[1]
                    image_bytes = base64.b64decode(image_data)
                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                    generation_kwargs["image"] = image
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    frames = self.text_to_video(**generation_kwargs).frames

                logger.info(f"Original frames shape: {frames.shape}")

                # squeeze(0) drops a leading dimension only when its size is 1;
                # any other 5-D shape fails the 4-D check below.
                if len(frames.shape) == 5:
                    frames = frames.squeeze(0)

                logger.info(f"Processed frames shape: {frames.shape}")

                if len(frames.shape) != 4:
                    raise ValueError(f"Expected tensor of shape [frames, channels, height, width], got shape {frames.shape}")

                # Drive the async Varnish post-processing from this synchronous
                # entry point, creating and installing a loop if none is set.
                # NOTE(review): run_until_complete raises if the returned loop
                # is already running (handler invoked from async code) —
                # confirm the serving runtime calls __call__ from a plain thread.
                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

                video_uri, metadata = loop.run_until_complete(
                    self.process_frames(frames, config)
                )

                return {
                    "video": video_uri,
                    "content-type": "video/mp4",
                    "metadata": metadata
                }

        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            logger.error(message)
            # Chain the original exception so its type and traceback survive.
            raise RuntimeError(message) from e