File size: 9,487 Bytes
ef15707
132e8c4
 
 
 
 
1a6f91c
 
 
 
e349e43
 
 
 
 
132e8c4
 
ef15707
 
 
 
 
 
 
 
 
 
 
 
 
d35cde0
 
 
132e8c4
 
 
 
 
 
d35cde0
 
 
132e8c4
 
 
 
 
 
 
 
 
 
 
d35cde0
 
 
ef15707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a6f91c
ef15707
1a6f91c
ef15707
1a6f91c
 
 
e349e43
1a6f91c
 
 
 
 
e349e43
ef15707
e349e43
 
 
ef15707
e349e43
1a6f91c
 
 
e349e43
 
1a6f91c
 
 
 
 
 
 
ef15707
1a6f91c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132e8c4
 
 
 
 
 
 
 
ef15707
 
 
1a6f91c
ef15707
132e8c4
 
 
 
d35cde0
 
ef15707
132e8c4
d35cde0
 
132e8c4
d35cde0
132e8c4
ef15707
d35cde0
 
ef15707
 
 
d35cde0
 
ef15707
 
 
d35cde0
 
 
 
132e8c4
e349e43
ef15707
 
e349e43
132e8c4
1a6f91c
ef15707
 
 
 
 
 
 
 
 
 
 
 
1a6f91c
 
 
 
e349e43
ef15707
 
1a6f91c
e349e43
ef15707
132e8c4
1a6f91c
 
 
 
 
132e8c4
d35cde0
 
 
 
 
1a6f91c
d35cde0
 
ef15707
 
 
 
 
 
 
 
1a6f91c
132e8c4
 
e349e43
1a6f91c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
from typing import Dict, Any, Union, Optional, Tuple
import torch
from diffusers import LTXPipeline, LTXImageToVideoPipeline
from PIL import Image
import base64
import io
import tempfile
import numpy as np
from moviepy.editor import ImageSequenceClip
import os
import logging

# Module logger, plus a root-logging configuration at INFO level so that the
# handler's progress messages are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class EndpointHandler:
    """Inference-endpoint handler for LTX-Video text-to-video and image-to-video generation.

    A request is a dictionary (see ``__call__``). The response holds the generated
    clip as an H.264 MP4 data URI, plus the generation parameters actually used.
    """

    # Default configuration
    DEFAULT_FPS = 24
    DEFAULT_DURATION = 4  # seconds
    DEFAULT_NUM_FRAMES = (DEFAULT_DURATION * DEFAULT_FPS) + 1  # 97 frames (8k + 1 form)
    DEFAULT_NUM_STEPS = 25
    DEFAULT_WIDTH = 768
    DEFAULT_HEIGHT = 512
    DEFAULT_GUIDANCE_SCALE = 7.5

    # Constraints
    MAX_WIDTH = 1280
    MAX_HEIGHT = 720
    MAX_FRAMES = 257
    # Width and height handed to the pipelines are always multiples of this value.
    RESOLUTION_MULTIPLE = 32

    ENABLE_CPU_OFFLOAD = True
    EXPERIMENTAL_STUFF = False

    def __init__(self, path: str = ""):
        """Load the text-to-video and image-to-video LTX pipelines.

        Args:
            path (str): Path to the model weights directory.
        """
        # Class attributes must be read through ``self``: bare names inside a method
        # are not resolved against the class body and would raise NameError.
        if self.EXPERIMENTAL_STUFF:
            torch.backends.cuda.matmul.allow_tf32 = True

        # Load both pipelines with bfloat16 precision as recommended in docs
        self.text_to_video = LTXPipeline.from_pretrained(path, torch_dtype=torch.bfloat16)
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(path, torch_dtype=torch.bfloat16)

        if self.ENABLE_CPU_OFFLOAD:
            # Model offload manages device placement itself. Moving the pipelines to
            # CUDA first would put every weight on the GPU only to move it back.
            self.text_to_video.enable_model_cpu_offload()
            self.image_to_video.enable_model_cpu_offload()
        else:
            self.text_to_video.to("cuda")
            self.image_to_video.to("cuda")

    def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
        """Snap a requested resolution to the grid and clamp it to the allowed range.

        Args:
            width (int): Requested width (any numeric value, or a numeric string).
            height (int): Requested height (any numeric value, or a numeric string).

        Returns:
            Tuple[int, int]: Adjusted (width, height). Each is a multiple of
            RESOLUTION_MULTIPLE, at least one multiple and at most MAX_WIDTH / MAX_HEIGHT
            rounded down to the grid.
        """
        step = self.RESOLUTION_MULTIPLE
        # The upper bounds are snapped *down* to the grid as well. Clamping to the raw
        # MAX_HEIGHT of 720 would otherwise produce a height that is not a multiple of 32.
        max_width = max((self.MAX_WIDTH // step) * step, step)
        max_height = max((self.MAX_HEIGHT // step) * step, step)

        # Round to the nearest multiple of the grid step.
        width = round(float(width) / step) * step
        height = round(float(height) / step) * step

        width = min(max(width, step), max_width)
        height = min(max(height, step), max_height)
        return int(width), int(height)

    def _validate_and_adjust_frames(self, num_frames: Optional[int] = None, fps: Optional[int] = None) -> Tuple[int, int]:
        """Validate and adjust the frame count and FPS.

        Args:
            num_frames (Optional[int]): Requested number of frames. A falsy value
                (None or 0) selects DEFAULT_NUM_FRAMES.
            fps (Optional[int]): Requested frames per second. A falsy value selects
                DEFAULT_FPS.

        Returns:
            Tuple[int, int]: Adjusted (num_frames, fps). num_frames is rounded down to
            the 8k + 1 form, is at least 1, and is capped at MAX_FRAMES.

        Raises:
            ValueError: If fps is negative.
        """
        fps = fps or self.DEFAULT_FPS
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")

        num_frames = int(num_frames or self.DEFAULT_NUM_FRAMES)

        # Round down to the 8k + 1 form, never below a single frame (k = 0).
        k = max((num_frames - 1) // 8, 0)
        # Keep the cap itself in 8k + 1 form so the invariant survives a MAX_FRAMES change.
        max_frames = ((self.MAX_FRAMES - 1) // 8) * 8 + 1
        num_frames = min(k * 8 + 1, max_frames)

        return num_frames, fps

    @staticmethod
    def _parse_seed(seed: Any) -> Optional[int]:
        """Convert a request seed to an int, or None when a random seed is requested.

        Args:
            seed (Any): None, an int, or anything ``int()`` accepts. A missing value,
                None, -1 or any other negative value means "random".

        Returns:
            Optional[int]: The non-negative seed, or None.
        """
        if seed is None:
            return None
        seed = int(seed)
        return None if seed < 0 else seed

    @staticmethod
    def _make_generator(seed: Optional[int]) -> Optional[torch.Generator]:
        """Build a seeded torch generator, or return None to let the pipeline pick its own noise.

        A CPU generator is used so that seeded results do not depend on whether the
        models run on the GPU or are CPU-offloaded.
        """
        if seed is None:
            return None
        return torch.Generator(device="cpu").manual_seed(seed)

    @staticmethod
    def _decode_base64_payload(payload: Union[str, bytes]) -> bytes:
        """Decode base64 image data, accepting an optional ``data:<mime>;base64,`` prefix.

        Without stripping the prefix, ``b64decode`` would silently keep its
        alphabet-valid characters ("data", "image", "png", "base64", ...) and corrupt
        the decoded bytes.
        """
        if isinstance(payload, str) and payload.startswith("data:"):
            _, _, payload = payload.partition(",")
        return base64.b64decode(payload)

    def _create_video_file(self, frames: torch.Tensor, fps: int = DEFAULT_FPS) -> bytes:
        """Encode generated frames as an H.264 MP4 and return the file content.

        Args:
            frames (torch.Tensor): Pipeline output with ``output_type="pt"``. It is
                treated as (1, num_frames, channels, height, width) with values in
                [0, 1] (inferred from the squeeze/permute below; confirm against the
                diffusers output format).
            fps (int): Frames per second for the output video.

        Returns:
            bytes: MP4 video file content.
        """
        num_frames = frames.shape[1]
        duration = num_frames / fps
        logger.info(f"Creating video with {num_frames} frames at {fps} FPS (duration: {duration:.2f} seconds)")

        # Drop the batch axis and move channels last: (F, C, H, W) -> (F, H, W, C).
        video_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().float().numpy()
        # Clip before casting: float -> uint8 conversion does not saturate out-of-range values.
        video_np = np.clip(video_np * 255.0, 0, 255).round().astype(np.uint8)

        _, height, width, _ = video_np.shape
        logger.info(f"Video dimensions: {width}x{height}")

        # mkstemp creates the file atomically. The deprecated mktemp only returns a
        # name, leaving a window for another process to claim it.
        fd, output_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)

        clip = None
        try:
            clip = ImageSequenceClip(list(video_np), fps=fps)
            clip.write_videofile(output_path, codec="libx264", audio=False)

            with open(output_path, "rb") as f:
                return f.read()
        finally:
            if clip is not None:
                clip.close()
            if os.path.exists(output_path):
                os.remove(output_path)

            # Release the frame buffer and any cached GPU memory.
            del video_np
            torch.cuda.empty_cache()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a video from a text prompt, optionally conditioned on an image.

        Args:
            data (Dict[str, Any]): Request payload. Recognised keys are popped from it:
                - inputs (str): Text prompt for video generation (required).
                - image (Optional[str]): Base64-encoded image for image-to-video
                  generation. A ``data:<mime>;base64,`` prefix is accepted.
                - width (Optional[int]): Video width (default: 768), snapped to a
                  multiple of 32.
                - height (Optional[int]): Video height (default: 512), snapped to a
                  multiple of 32.
                - num_frames (Optional[int]): Number of frames (default: 97), rounded
                  down to the 8k + 1 form.
                - fps (Optional[int]): Frames per second (default: 24).
                - num_inference_steps (Optional[int]): Number of inference steps (default: 25).
                - guidance_scale (Optional[float]): Guidance scale (default: 7.5).
                - seed (Optional[int]): RNG seed for reproducible output. A missing value,
                  None or any negative value (e.g. -1) means random.

        Returns:
            Dict[str, Any]: Dictionary containing:
                - video: the video (H.264 MP4) as a base64 data URI (prefixed with "data:").
                - content-type: MIME type of the video (currently always "video/mp4").
                - metadata: the values actually used for generation (including "seed",
                  which is None for a random seed).

        Raises:
            ValueError: If no prompt is provided or fps is negative.
            RuntimeError: If generation or encoding fails. The original exception is chained.
        """
        prompt = data.pop("inputs", None)
        if not prompt:
            raise ValueError("No prompt provided in the 'inputs' field")

        # Get and validate resolution
        width = data.pop("width", self.DEFAULT_WIDTH)
        height = data.pop("height", self.DEFAULT_HEIGHT)
        width, height = self._validate_and_adjust_resolution(width, height)

        # Get and validate frames and FPS
        num_frames = data.pop("num_frames", self.DEFAULT_NUM_FRAMES)
        fps = data.pop("fps", self.DEFAULT_FPS)
        num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)

        # Get other parameters with defaults
        guidance_scale = data.pop("guidance_scale", self.DEFAULT_GUIDANCE_SCALE)
        num_inference_steps = data.pop("num_inference_steps", self.DEFAULT_NUM_STEPS)
        seed = self._parse_seed(data.pop("seed", -1))
        image_data = data.pop("image", None)

        logger.info(f"Generating video with prompt: '{prompt}'")
        logger.info(f"Parameters: size={width}x{height}, num_frames={num_frames}, fps={fps}")
        logger.info(f"Additional params: guidance_scale={guidance_scale}, num_inference_steps={num_inference_steps}")

        try:
            with torch.no_grad():
                generation_kwargs = {
                    "prompt": prompt,
                    "height": height,
                    "width": width,
                    "num_frames": num_frames,
                    "guidance_scale": guidance_scale,
                    "num_inference_steps": num_inference_steps,
                    "output_type": "pt",
                    # Previously the seed was parsed but never used, so results were
                    # not reproducible.
                    "generator": self._make_generator(seed),
                }

                if image_data:
                    image_bytes = self._decode_base64_payload(image_data)
                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                    logger.info("Using image-to-video generation mode")
                    generation_kwargs["image"] = image
                    output = self.image_to_video(**generation_kwargs).frames
                else:
                    logger.info("Using text-to-video generation mode")
                    output = self.text_to_video(**generation_kwargs).frames

                video_content = self._create_video_file(output, fps=fps)
                video_base64 = base64.b64encode(video_content).decode('utf-8')

                content_type = "video/mp4"
                video_data_uri = f"data:{content_type};base64,{video_base64}"

                return {
                    "video": video_data_uri,
                    "content-type": content_type,
                    "metadata": {
                        "width": width,
                        "height": height,
                        "num_frames": num_frames,
                        "fps": fps,
                        "duration": num_frames / fps,
                        "num_inference_steps": num_inference_steps,
                        "seed": seed,
                    }
                }

        except Exception as e:
            # logger.exception records the traceback; "from e" keeps the cause attached.
            logger.exception("Error generating video: %s", e)
            raise RuntimeError(f"Error generating video: {str(e)}") from e