Commit d35cde0
Parent(s): ef15707

Update handler.py

handler.py CHANGED (+29 -16)
@@ -28,12 +28,18 @@ class EndpointHandler:
     MAX_HEIGHT = 720
     MAX_FRAMES = 257
 
+    ENABLE_CPU_OFFLOAD = True
+    EXPERIMENTAL_STUFF = False
+
     def __init__(self, path: str = ""):
         """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.
 
         Args:
            path (str): Path to the model weights directory
         """
+        if EXPERIMENTAL_STUFF:
+            torch.backends.cuda.matmul.allow_tf32 = True
+
         # Load both pipelines with bfloat16 precision as recommended in docs
         self.text_to_video = LTXPipeline.from_pretrained(
             path,
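The two new toggles, ENABLE_CPU_OFFLOAD and EXPERIMENTAL_STUFF, gate behaviour added further down. The diff references them as bare names inside __init__; since the original indentation is not recoverable here they are shown above as class attributes, in which case the bare references would need self. qualification or module-level definitions to resolve. The experimental branch only flips PyTorch's TF32 switch for CUDA matmuls, as in this minimal sketch (the flag name comes from the diff; everything else is illustrative):

```python
import torch

EXPERIMENTAL_STUFF = False  # same toggle as in the diff, shown here at module level

if EXPERIMENTAL_STUFF:
    # Allow TensorFloat-32 for float32 CUDA matmuls on Ampere-or-newer GPUs:
    # faster matrix multiplies at slightly reduced precision.
    torch.backends.cuda.matmul.allow_tf32 = True
```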
@@ -45,9 +51,9 @@ class EndpointHandler:
             torch_dtype=torch.bfloat16
         ).to("cuda")
 
-
-
-
+        if ENABLE_CPU_OFFLOAD:
+            self.text_to_video.enable_model_cpu_offload()
+            self.image_to_video.enable_model_cpu_offload()
 
     def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
         """Validate and adjust resolution to meet constraints.
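enable_model_cpu_offload() is the diffusers API now applied to both pipelines: sub-models stay on the CPU and are moved to the GPU only while they run, which lowers peak VRAM at some latency cost. The diffusers docs recommend calling it instead of an explicit .to("cuda"), which the handler still performs beforehand. A standalone sketch of the same pattern, assuming the diffusers LTXPipeline and a model id that is an assumption, not taken from this commit:

```python
import torch
from diffusers import LTXPipeline

# "Lightricks/LTX-Video" is an assumed model id for illustration only.
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)

# Instead of pipe.to("cuda"): sub-models are shuttled to the GPU on demand.
pipe.enable_model_cpu_offload()

frames = pipe(prompt="a coastal town at sunset", num_frames=65).frames[0]
```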
@@ -158,28 +164,30 @@ class EndpointHandler:
 
         Returns:
             Dict[str, Any]: Dictionary containing:
-                - video:
-                - content-type: MIME type of the video
+                - video: video encoded in Base64 (h.264 MP4 video). This is a data-uri (prefixed with "data:").
+                - content-type: MIME type of the video (right now always "video/mp4")
                 - metadata: Dictionary with actual values used for generation
         """
-        #
-        prompt = data.
+        # Get inputs from request data
+        prompt = data.pop("inputs", None)
         if not prompt:
-            raise ValueError("
+            raise ValueError("No prompt provided in the 'inputs' field")
 
         # Get and validate resolution
-        width = data.
-        height = data.
+        width = data.pop("width", self.DEFAULT_WIDTH)
+        height = data.pop("height", self.DEFAULT_HEIGHT)
         width, height = self._validate_and_adjust_resolution(width, height)
 
         # Get and validate frames and FPS
-        num_frames = data.
-        fps = data.
+        num_frames = data.pop("num_frames", self.DEFAULT_NUM_FRAMES)
+        fps = data.pop("fps", self.DEFAULT_FPS)
         num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)
 
         # Get other parameters with defaults
-        guidance_scale = data.
-        num_inference_steps = data.
+        guidance_scale = data.pop("guidance_scale", 7.5)
+        num_inference_steps = data.pop("num_inference_steps", self.DEFAULT_NUM_STEPS)
+        seed = data.pop("seed", -1)
+        seed = None if seed == -1 else int(seed)
 
         logger.info(f"Generating video with prompt: '{prompt}'")
         logger.info(f"Parameters: size={width}x{height}, num_frames={num_frames}, fps={fps}")
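With this change every request parameter is read via data.pop(...) with a fallback to the class defaults, and a seed of -1 is translated to None (no fixed seed). A hypothetical client payload exercising these fields; only the JSON key names come from the diff, while the URL, token and values are placeholders:

```python
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer <token>"}                         # placeholder

payload = {
    "inputs": "a red panda walking through snowy woods",  # required prompt
    "width": 704,                # adjusted by _validate_and_adjust_resolution
    "height": 480,
    "num_frames": 121,           # adjusted together with fps by _validate_and_adjust_frames
    "fps": 24,
    "guidance_scale": 7.5,
    "num_inference_steps": 40,   # example value; the handler supplies its own default
    "seed": -1,                  # -1 is mapped to None (no fixed seed)
}

response = requests.post(ENDPOINT_URL, headers=HEADERS, json=payload)
result = response.json()
```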
@@ -216,9 +224,14 @@ class EndpointHandler:
         # Encode video to base64
         video_base64 = base64.b64encode(video_content).decode('utf-8')
 
+        content_type = "video/mp4"
+
+        # Add MP4 data URI prefix
+        video_data_uri = f"data:{content_type};base64,{video_base64}"
+
         return {
-            "video":
-            "content-type":
+            "video": video_data_uri,
+            "content-type": content_type,
             "metadata": {
                 "width": width,
                 "height": height,
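The response now wraps the Base64-encoded MP4 in a data URI and reports the MIME type separately, so a client has to strip the "data:" prefix before decoding. A small sketch, assuming result is the parsed JSON from the request example above:

```python
import base64

# "video" looks like "data:video/mp4;base64,AAAA..."; split off the prefix.
header, b64_payload = result["video"].split(",", 1)
assert header == f"data:{result['content-type']};base64"

with open("generated.mp4", "wb") as f:
    f.write(base64.b64decode(b64_payload))

print(result["metadata"])  # the actual values used for generation
```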
|
|