jbilcke-hf HF staff committed on
Commit
d35cde0
1 Parent(s): ef15707

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -16
handler.py CHANGED
@@ -28,12 +28,18 @@ class EndpointHandler:
28
  MAX_HEIGHT = 720
29
  MAX_FRAMES = 257
30
 
 
 
 
31
  def __init__(self, path: str = ""):
32
  """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.
33
 
34
  Args:
35
  path (str): Path to the model weights directory
36
  """
 
 
 
37
  # Load both pipelines with bfloat16 precision as recommended in docs
38
  self.text_to_video = LTXPipeline.from_pretrained(
39
  path,
@@ -45,9 +51,9 @@ class EndpointHandler:
45
  torch_dtype=torch.bfloat16
46
  ).to("cuda")
47
 
48
- # Enable memory optimizations
49
- self.text_to_video.enable_model_cpu_offload()
50
- self.image_to_video.enable_model_cpu_offload()
51
 
52
  def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
53
  """Validate and adjust resolution to meet constraints.
@@ -158,28 +164,30 @@ class EndpointHandler:
158
 
159
  Returns:
160
  Dict[str, Any]: Dictionary containing:
161
- - video: Base64 encoded MP4 video
162
- - content-type: MIME type of the video
163
  - metadata: Dictionary with actual values used for generation
164
  """
165
- # Extract and validate prompt
166
- prompt = data.get("prompt")
167
  if not prompt:
168
- raise ValueError("'prompt' is required in the input data")
169
 
170
  # Get and validate resolution
171
- width = data.get("width", self.DEFAULT_WIDTH)
172
- height = data.get("height", self.DEFAULT_HEIGHT)
173
  width, height = self._validate_and_adjust_resolution(width, height)
174
 
175
  # Get and validate frames and FPS
176
- num_frames = data.get("num_frames", self.DEFAULT_NUM_FRAMES)
177
- fps = data.get("fps", self.DEFAULT_FPS)
178
  num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)
179
 
180
  # Get other parameters with defaults
181
- guidance_scale = data.get("guidance_scale", 7.5)
182
- num_inference_steps = data.get("num_inference_steps", self.DEFAULT_NUM_STEPS)
 
 
183
 
184
  logger.info(f"Generating video with prompt: '{prompt}'")
185
  logger.info(f"Parameters: size={width}x{height}, num_frames={num_frames}, fps={fps}")
@@ -216,9 +224,14 @@ class EndpointHandler:
216
  # Encode video to base64
217
  video_base64 = base64.b64encode(video_content).decode('utf-8')
218
 
 
 
 
 
 
219
  return {
220
- "video": video_base64,
221
- "content-type": "video/mp4",
222
  "metadata": {
223
  "width": width,
224
  "height": height,
 
28
  MAX_HEIGHT = 720
29
  MAX_FRAMES = 257
30
 
31
+ ENABLE_CPU_OFFLOAD = True
32
+ EXPERIMENTAL_STUFF = False
33
+
34
  def __init__(self, path: str = ""):
35
  """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.
36
 
37
  Args:
38
  path (str): Path to the model weights directory
39
  """
40
+ if EXPERIMENTAL_STUFF:
41
+ torch.backends.cuda.matmul.allow_tf32 = True
42
+
43
  # Load both pipelines with bfloat16 precision as recommended in docs
44
  self.text_to_video = LTXPipeline.from_pretrained(
45
  path,
 
51
  torch_dtype=torch.bfloat16
52
  ).to("cuda")
53
 
54
+ if ENABLE_CPU_OFFLOAD:
55
+ self.text_to_video.enable_model_cpu_offload()
56
+ self.image_to_video.enable_model_cpu_offload()
57
 
58
  def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
59
  """Validate and adjust resolution to meet constraints.
 
164
 
165
  Returns:
166
  Dict[str, Any]: Dictionary containing:
167
+ - video: video encoded in Base64 (h.264 MP4 video). This is a data-uri (prefixed with "data:").
168
+ - content-type: MIME type of the video (right now always "video/mp4")
169
  - metadata: Dictionary with actual values used for generation
170
  """
171
+ # Get inputs from request data
172
+ prompt = data.pop("inputs", None)
173
  if not prompt:
174
+ raise ValueError("No prompt provided in the 'inputs' field")
175
 
176
  # Get and validate resolution
177
+ width = data.pop("width", self.DEFAULT_WIDTH)
178
+ height = data.pop("height", self.DEFAULT_HEIGHT)
179
  width, height = self._validate_and_adjust_resolution(width, height)
180
 
181
  # Get and validate frames and FPS
182
+ num_frames = data.pop("num_frames", self.DEFAULT_NUM_FRAMES)
183
+ fps = data.pop("fps", self.DEFAULT_FPS)
184
  num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)
185
 
186
  # Get other parameters with defaults
187
+ guidance_scale = data.pop("guidance_scale", 7.5)
188
+ num_inference_steps = data.pop("num_inference_steps", self.DEFAULT_NUM_STEPS)
189
+ seed = data.pop("seed", -1)
190
+ seed = None if seed == -1 else int(seed)
191
 
192
  logger.info(f"Generating video with prompt: '{prompt}'")
193
  logger.info(f"Parameters: size={width}x{height}, num_frames={num_frames}, fps={fps}")
 
224
  # Encode video to base64
225
  video_base64 = base64.b64encode(video_content).decode('utf-8')
226
 
227
+ content_type = "video/mp4"
228
+
229
+ # Add MP4 data URI prefix
230
+ video_data_uri = f"data:{content_type};base64,{video_base64}"
231
+
232
  return {
233
+ "video": video_data_uri,
234
+ "content-type": content_type,
235
  "metadata": {
236
  "width": width,
237
  "height": height,