Update handler.py
Browse files- handler.py +24 -7
handler.py
CHANGED
@@ -8,6 +8,11 @@ import tempfile
|
|
8 |
import numpy as np
|
9 |
from moviepy.editor import ImageSequenceClip
|
10 |
import os
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
class EndpointHandler:
|
13 |
def __init__(self, path: str = ""):
|
@@ -34,22 +39,28 @@ class EndpointHandler:
|
|
34 |
# Set default FPS
|
35 |
self.fps = 24
|
36 |
|
37 |
-
def _create_video_file(self,
|
38 |
"""Convert frames to an MP4 video file.
|
39 |
|
40 |
Args:
|
41 |
-
|
42 |
fps (int): Frames per second for the output video
|
43 |
|
44 |
Returns:
|
45 |
bytes: MP4 video file content
|
46 |
"""
|
47 |
-
#
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
49 |
video_np = (video_np * 255).astype(np.uint8)
|
50 |
|
51 |
# Get dimensions
|
52 |
-
height, width = video_np.shape
|
|
|
53 |
|
54 |
# Create temporary file
|
55 |
output_path = tempfile.mktemp(suffix=".mp4")
|
@@ -103,6 +114,9 @@ class EndpointHandler:
|
|
103 |
guidance_scale = data.get("guidance_scale", 7.5)
|
104 |
num_inference_steps = data.get("num_inference_steps", 50)
|
105 |
|
|
|
|
|
|
|
106 |
# Check if image is provided for image-to-video generation
|
107 |
image_data = data.get("image")
|
108 |
|
@@ -112,6 +126,7 @@ class EndpointHandler:
|
|
112 |
# Decode base64 image
|
113 |
image_bytes = base64.b64decode(image_data)
|
114 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
|
|
115 |
|
116 |
# Generate video from image
|
117 |
output = self.image_to_video(
|
@@ -121,8 +136,9 @@ class EndpointHandler:
|
|
121 |
guidance_scale=guidance_scale,
|
122 |
num_inference_steps=num_inference_steps,
|
123 |
output_type="pt"
|
124 |
-
).frames[0]
|
125 |
else:
|
|
|
126 |
# Generate video from text only
|
127 |
output = self.text_to_video(
|
128 |
prompt=prompt,
|
@@ -130,7 +146,7 @@ class EndpointHandler:
|
|
130 |
guidance_scale=guidance_scale,
|
131 |
num_inference_steps=num_inference_steps,
|
132 |
output_type="pt"
|
133 |
-
).frames[0]
|
134 |
|
135 |
# Convert frames to video file
|
136 |
video_content = self._create_video_file(output, fps=fps)
|
@@ -144,4 +160,5 @@ class EndpointHandler:
|
|
144 |
}
|
145 |
|
146 |
except Exception as e:
|
|
|
147 |
raise RuntimeError(f"Error generating video: {str(e)}")
|
|
|
8 |
import numpy as np
|
9 |
from moviepy.editor import ImageSequenceClip
|
10 |
import os
|
11 |
+
import logging
|
12 |
+
|
13 |
+
# Configure logging
|
14 |
+
logging.basicConfig(level=logging.INFO)
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
|
17 |
class EndpointHandler:
|
18 |
def __init__(self, path: str = ""):
|
|
|
39 |
# Set default FPS
|
40 |
self.fps = 24
|
41 |
|
42 |
+
def _create_video_file(self, frames: torch.Tensor, fps: int = 24) -> bytes:
|
43 |
"""Convert frames to an MP4 video file.
|
44 |
|
45 |
Args:
|
46 |
+
frames (torch.Tensor): Generated frames tensor
|
47 |
fps (int): Frames per second for the output video
|
48 |
|
49 |
Returns:
|
50 |
bytes: MP4 video file content
|
51 |
"""
|
52 |
+
# Log frame information
|
53 |
+
num_frames = frames.shape[1] # Shape should be [1, num_frames, channels, height, width]
|
54 |
+
duration = num_frames / fps
|
55 |
+
logger.info(f"Creating video with {num_frames} frames at {fps} FPS (duration: {duration:.2f} seconds)")
|
56 |
+
|
57 |
+
# Convert tensor to numpy array - remove batch dimension and rearrange to [num_frames, height, width, channels]
|
58 |
+
video_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().float().numpy()
|
59 |
video_np = (video_np * 255).astype(np.uint8)
|
60 |
|
61 |
# Get dimensions
|
62 |
+
_, height, width, _ = video_np.shape
|
63 |
+
logger.info(f"Video dimensions: {width}x{height}")
|
64 |
|
65 |
# Create temporary file
|
66 |
output_path = tempfile.mktemp(suffix=".mp4")
|
|
|
114 |
guidance_scale = data.get("guidance_scale", 7.5)
|
115 |
num_inference_steps = data.get("num_inference_steps", 50)
|
116 |
|
117 |
+
logger.info(f"Generating video with prompt: '{prompt}'")
|
118 |
+
logger.info(f"Parameters: num_frames={num_frames}, fps={fps}, guidance_scale={guidance_scale}, num_inference_steps={num_inference_steps}")
|
119 |
+
|
120 |
# Check if image is provided for image-to-video generation
|
121 |
image_data = data.get("image")
|
122 |
|
|
|
126 |
# Decode base64 image
|
127 |
image_bytes = base64.b64decode(image_data)
|
128 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
129 |
+
logger.info("Using image-to-video generation mode")
|
130 |
|
131 |
# Generate video from image
|
132 |
output = self.image_to_video(
|
|
|
136 |
guidance_scale=guidance_scale,
|
137 |
num_inference_steps=num_inference_steps,
|
138 |
output_type="pt"
|
139 |
+
).frames # Remove [0] to keep all frames
|
140 |
else:
|
141 |
+
logger.info("Using text-to-video generation mode")
|
142 |
# Generate video from text only
|
143 |
output = self.text_to_video(
|
144 |
prompt=prompt,
|
|
|
146 |
guidance_scale=guidance_scale,
|
147 |
num_inference_steps=num_inference_steps,
|
148 |
output_type="pt"
|
149 |
+
).frames # Remove [0] to keep all frames
|
150 |
|
151 |
# Convert frames to video file
|
152 |
video_content = self._create_video_file(output, fps=fps)
|
|
|
160 |
}
|
161 |
|
162 |
except Exception as e:
|
163 |
+
logger.error(f"Error generating video: {str(e)}")
|
164 |
raise RuntimeError(f"Error generating video: {str(e)}")
|