jbilcke-hf (HF staff) committed f6dd4f3
Parent: 85f39ae

Update handler.py

Files changed (1): handler.py (+73 -84)

handler.py CHANGED
@@ -70,6 +70,14 @@ class EndpointHandler:
         self.text_to_video.enable_model_cpu_offload()
         self.image_to_video.enable_model_cpu_offload()
 
+        self.varnish = Varnish(
+            device="cuda" if torch.cuda.is_available() else "cpu",
+            output_format="mp4",
+            output_codec="h264",
+            output_quality=23,
+            enable_mmaudio=False
+        )
+
     def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
         """Validate and adjust resolution to meet constraints.
 
@@ -117,57 +125,44 @@ class EndpointHandler:
 
         return num_frames, fps
 
-    def _create_video_file(self, frames: torch.Tensor, fps: int = DEFAULT_FPS) -> bytes:
-        """Convert frames to an MP4 video file.
-
-        Args:
-            frames (torch.Tensor): Generated frames tensor
-            fps (int): Frames per second for the output video
-
-        Returns:
-            bytes: MP4 video file content
-        """
-        # Log frame information
-        num_frames = frames.shape[1]
-        duration = num_frames / fps
-        logger.info(f"Creating video with {num_frames} frames at {fps} FPS (duration: {duration:.2f} seconds)")
-
-        # Convert tensor to numpy array
-        video_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().float().numpy()
-        video_np = (video_np * 255).astype(np.uint8)
-
-        # Get dimensions
-        _, height, width, _ = video_np.shape
-        logger.info(f"Video dimensions: {width}x{height}")
-
-        # Create temporary file
-        output_path = tempfile.mktemp(suffix=".mp4")
-
-        try:
-            # Create video clip and write to file
-            clip = ImageSequenceClip(list(video_np), fps=fps)
-
-            # potential speed optimizations:
-            # there is a preset= field, to trade encoding speed with file size (but not quality)
-            # values are: ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow, placebo
-            #
-            # there is a threads= field, by default None, which can be set to 2, 3, 4 etc.
-            clip.write_videofile(output_path, codec="libx264", audio=False)
-
-            # Read the video file
-            with open(output_path, "rb") as f:
-                video_content = f.read()
-
-            return video_content
-
-        finally:
-            # Cleanup
-            if os.path.exists(output_path):
-                os.remove(output_path)
-
-            # Clear memory
-            del video_np
-            torch.cuda.empty_cache()
+    async def process_and_encode_video(
+        self,
+        frames: torch.Tensor,
+        fps: int,
+        upscale_factor: int = 0,
+        enable_interpolation: bool = False,
+        interpolation_exp: int = 1
+    ) -> tuple[str, dict]:
+        """Process video frames using Varnish and return base64 encoded result"""
+
+        # Process video with Varnish
+        result = await self.varnish(
+            input_data=frames,
+            input_fps=fps,
+            output_fps=fps,
+            enable_upscale=upscale_factor > 1,
+            upscale_factor=upscale_factor,
+            enable_interpolation=enable_interpolation,
+            interpolation_exp=interpolation_exp
+        )
+
+        # Get video as data URI
+        video_data_uri = await result.write(
+            output_type="data-uri",
+            output_format="mp4",
+            output_codec="h264",
+            output_quality=23
+        )
+
+        metadata = {
+            "width": result.metadata.width,
+            "height": result.metadata.height,
+            "num_frames": result.metadata.frame_count,
+            "fps": result.metadata.fps,
+            "duration": result.metadata.duration
+        }
+
+        return video_data_uri, metadata
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """Process the input data and generate video using LTX.
@@ -189,35 +184,32 @@ class EndpointHandler:
             - content-type: MIME type of the video (right now always "video/mp4")
            - metadata: Dictionary with actual values used for generation
         """
-        # Get inputs from request data
+
        prompt = data.get("inputs", None)
         if not prompt:
             raise ValueError("No prompt provided in the 'inputs' field")
 
-        # Get and validate resolution
+        # Get generation parameters
         width = data.get("width", self.DEFAULT_WIDTH)
         height = data.get("height", self.DEFAULT_HEIGHT)
         width, height = self._validate_and_adjust_resolution(width, height)
-
-        # Get and validate frames and FPS
+
         num_frames = data.get("num_frames", self.DEFAULT_NUM_FRAMES)
         fps = data.get("fps", self.DEFAULT_FPS)
         num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)
-
-        # Get other parameters with defaults
+
+        # Get post-processing parameters
+        upscale_factor = data.get("upscale_factor", 0)
+        enable_interpolation = data.get("enable_interpolation", False)
+        interpolation_exp = data.get("interpolation_exp", 1)
+
         guidance_scale = data.get("guidance_scale", 7.5)
         num_inference_steps = data.get("num_inference_steps", self.DEFAULT_NUM_STEPS)
-
         seed = data.get("seed", -1)
         seed = random.randint(0, 2**32 - 1) if seed == -1 else int(seed)
-
-        logger.info(f"Generating video with prompt: '{prompt}'")
-        logger.info(f"Video params: size={width}x{height}, num_frames={num_frames}, fps={fps}")
-        logger.info(f"Generation params: seed={seed}, guidance_scale={guidance_scale}, num_inference_steps={num_inference_steps}")
 
         try:
             with torch.no_grad():
-
                 random.seed(seed)
                 np.random.seed(seed)
                 generator.manual_seed(seed)
@@ -233,43 +225,40 @@ class EndpointHandler:
                     "generator": generator
                 }
 
-                # Check if image is provided for image-to-video generation
+                # Generate frames using appropriate pipeline
                 image_data = data.get("image")
                 if image_data:
                     if image_data.startswith('data:'):
                         image_data = image_data.split(',', 1)[1]
                     image_bytes = base64.b64decode(image_data)
                     image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-                    logger.info("Using image-to-video generation mode")
                     generation_kwargs["image"] = image
-                    output = self.image_to_video(**generation_kwargs).frames
+                    frames = self.image_to_video(**generation_kwargs).frames
                 else:
-                    logger.info("Using text-to-video generation mode")
-                    output = self.text_to_video(**generation_kwargs).frames
 
-                # Convert frames to video file
-                video_content = self._create_video_file(output, fps=fps)
-
-                # Encode video to base64
-                video_base64 = base64.b64encode(video_content).decode('utf-8')
-
-                content_type = "video/mp4"
-
-                # Add MP4 data URI prefix
-                video_data_uri = f"data:{content_type};base64,{video_base64}"
-
+                    frames = self.text_to_video(**generation_kwargs).frames
+
+                # Process and encode video
+                video_data_uri, metadata = await self.process_and_encode_video(
+                    frames=frames,
+                    fps=fps,
+                    upscale_factor=upscale_factor,
+                    enable_interpolation=enable_interpolation,
+                    interpolation_exp=interpolation_exp
+                )
+
+                # Add generation metadata
+                metadata.update({
+                    "num_inference_steps": num_inference_steps,
+                    "seed": seed,
+                    "upscale_factor": upscale_factor,
+                    "interpolation_enabled": enable_interpolation,
+                    "interpolation_exp": interpolation_exp
+                })
+
                 return {
                     "video": video_data_uri,
-                    "content-type": content_type,
-                    "metadata": {
-                        "width": width,
-                        "height": height,
-                        "num_frames": num_frames,
-                        "fps": fps,
-                        "duration": num_frames / fps,
-                        "num_inference_steps": num_inference_steps,
-                        "seed": seed
-                    }
+                    "content-type": "video/mp4",
+                    "metadata": metadata
                 }
 
         except Exception as e:
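
Note: as committed, __call__ remains a plain def while its body now uses
"await self.process_and_encode_video(...)", and await is a SyntaxError
outside an async def. A minimal sketch of one way to drive the async
Varnish helper from the synchronous entry point (the _run_async helper is
an illustrative assumption, not part of this commit):

    import asyncio

    def _run_async(coro):
        # Inference Endpoints invoke __call__ synchronously, so the event
        # loop has to be driven manually to run the async Varnish helper.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)

    # Inside __call__, instead of the bare await:
    # video_data_uri, metadata = _run_async(self.process_and_encode_video(
    #     frames=frames, fps=fps, upscale_factor=upscale_factor,
    #     enable_interpolation=enable_interpolation,
    #     interpolation_exp=interpolation_exp))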
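
For reference, a hypothetical client request exercising the post-processing
fields this commit adds (the endpoint URL and token are placeholders; the
payload keys and defaults match what __call__ reads via data.get(...)):

    import requests

    API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
    headers = {"Authorization": "Bearer <HF_API_TOKEN>"}              # placeholder

    payload = {
        "inputs": "a quiet lake at sunrise, cinematic drone shot",
        "width": 768,
        "height": 512,
        "upscale_factor": 0,            # > 1 enables Varnish upscaling
        "enable_interpolation": False,  # frame interpolation, off by default
        "interpolation_exp": 1,
        "guidance_scale": 7.5,
        "seed": -1,                     # -1 selects a random seed
    }

    response = requests.post(API_URL, headers=headers, json=payload)
    result = response.json()
    # result["video"] holds a "data:video/mp4;base64,..." URI and
    # result["metadata"] echoes the values actually used for generation.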
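
To turn the returned data URI back into a playable file, standard-library
decoding is enough (this assumes the "data:video/mp4;base64,..." shape the
handler produces):

    import base64

    def save_data_uri(data_uri: str, path: str = "output.mp4") -> None:
        # Split "data:video/mp4;base64,<payload>" at the first comma
        # and decode the base64 payload into raw MP4 bytes.
        header, _, b64_payload = data_uri.partition(",")
        if not header.startswith("data:video/mp4"):
            raise ValueError(f"unexpected data URI header: {header}")
        with open(path, "wb") as f:
            f.write(base64.b64decode(b64_payload))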