multimodalart HF staff committed on
Commit
b832af5
1 Parent(s): 702754c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -2
app.py CHANGED
@@ -15,6 +15,27 @@ MODEL_REPO = "rain1011/pyramid-flow-sd3"
15
  MODEL_VARIANT = "diffusion_transformer_768p"
16
  MODEL_DTYPE = "bf16"
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # Download and load the model
19
  def load_model():
20
  if not os.path.exists(MODEL_PATH):
@@ -67,12 +88,13 @@ def generate_video_from_image(image, prompt, duration, video_guidance_scale):
67
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
68
 
69
  target_size = (1280, 720)
70
- image = ImageOps.fit(image, target_size, method=Image.LANCZOS, centering=(0.5, 0.5))
 
71
 
72
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
73
  frames = model.generate_i2v(
74
  prompt=prompt,
75
- input_image=image,
76
  num_inference_steps=[10, 10, 10],
77
  temp=temp,
78
  guidance_scale=7.0,
 
15
  MODEL_VARIANT = "diffusion_transformer_768p"
16
  MODEL_DTYPE = "bf16"
17
 
18
def center_crop(image, target_width, target_height):
    """Crop ``image`` around its centre to the target aspect ratio.

    The image is not resized: only the longer dimension (relative to the
    ``target_width:target_height`` ratio) is trimmed equally on both sides,
    so the result has as close to the target aspect ratio as whole pixels
    allow.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, top, right, bottom))`` method (e.g. a PIL image).
        target_width: Width component of the desired aspect ratio.
        target_height: Height component of the desired aspect ratio.

    Returns:
        Whatever ``image.crop`` returns for the computed centred box.

    Raises:
        ValueError: If the target dimensions or the image dimensions are
            not strictly positive.
    """
    width, height = image.size
    if target_width <= 0 or target_height <= 0:
        raise ValueError("target_width and target_height must be positive")
    if width <= 0 or height <= 0:
        raise ValueError("image must have a non-zero width and height")

    # Compare the aspect ratios by cross-multiplication and derive the crop
    # size with floor division. Going through a float ratio and int() can
    # lose a pixel to rounding (e.g. int(49 * (1 / 49)) == 0).
    if width * target_height > height * target_width:
        # Image is wider than the target: trim the left and right edges.
        # Clamp to one pixel so extreme ratios never yield an empty box.
        new_width = max(1, int(height * target_width // target_height))
        left = (width - new_width) // 2
        box = (left, 0, left + new_width, height)
    else:
        # Image is taller than (or matches) the target: trim top and bottom.
        new_height = max(1, int(width * target_height // target_width))
        top = (height - new_height) // 2
        box = (0, top, width, top + new_height)

    return image.crop(box)
38
+
39
  # Download and load the model
40
  def load_model():
41
  if not os.path.exists(MODEL_PATH):
 
88
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
89
 
90
  target_size = (1280, 720)
91
+ cropped_image = center_crop(image, 1280, 720)
92
+ resized_image = cropped_image.resize((1280, 720))
93
 
94
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
95
  frames = model.generate_i2v(
96
  prompt=prompt,
97
+ input_image=resized_image,
98
  num_inference_steps=[10, 10, 10],
99
  temp=temp,
100
  guidance_scale=7.0,