Perghect commited on
Commit
d0db5dd
·
verified ·
1 Parent(s): 05ac8a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -19,6 +19,11 @@ from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Auto
19
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
  torch.set_float32_matmul_precision('high')
21
 
 
 
 
 
 
22
  def load_image_from_link(url: str) -> Image.Image:
23
  """Downloads an image from a URL and returns a Pillow Image."""
24
  response = requests.get(url)
@@ -29,10 +34,6 @@ def load_image_from_link(url: str) -> Image.Image:
29
  # Gaussian Blur Functions
30
  def run_rmbg(image: Image.Image, threshold=0.5):
31
  """Runs the RMBG-2.0 model on the image and returns a binary mask."""
32
- model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
33
- model.to(device)
34
- model.eval()
35
-
36
  image_size = (1024, 1024)
37
  transform_image = transforms.Compose([
38
  transforms.Resize(image_size),
@@ -43,7 +44,7 @@ def run_rmbg(image: Image.Image, threshold=0.5):
43
  input_images = transform_image(image).unsqueeze(0).to(device)
44
 
45
  with torch.no_grad():
46
- preds = model(input_images)
47
  if isinstance(preds, list):
48
  mask_logits = preds[-1]
49
  else:
@@ -72,15 +73,11 @@ def apply_background_blur(image: Image.Image, mask: np.ndarray, sigma: float = 1
72
  # Lens Blur Functions
73
  def run_depth_estimation(image: Image.Image, target_size=(512, 512)):
74
  """Runs the Depth-Anything-V2-Small model and returns the depth map."""
75
- image_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
76
- model = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
77
- model.to(device)
78
-
79
  image_resized = image.resize(target_size, resample=Image.BILINEAR)
80
- inputs = image_processor(images=image_resized, return_tensors="pt").to(device)
81
 
82
  with torch.no_grad():
83
- outputs = model(**inputs)
84
  predicted_depth = outputs.predicted_depth
85
 
86
  prediction = torch.nn.functional.interpolate(
 
19
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
  torch.set_float32_matmul_precision('high')
21
 
22
+ # Load models at startup
23
+ rmbg_model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True).to(device).eval()
24
+ depth_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
25
+ depth_model = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf").to(device)
26
+
27
  def load_image_from_link(url: str) -> Image.Image:
28
  """Downloads an image from a URL and returns a Pillow Image."""
29
  response = requests.get(url)
 
34
  # Gaussian Blur Functions
35
  def run_rmbg(image: Image.Image, threshold=0.5):
36
  """Runs the RMBG-2.0 model on the image and returns a binary mask."""
 
 
 
 
37
  image_size = (1024, 1024)
38
  transform_image = transforms.Compose([
39
  transforms.Resize(image_size),
 
44
  input_images = transform_image(image).unsqueeze(0).to(device)
45
 
46
  with torch.no_grad():
47
+ preds = rmbg_model(input_images)
48
  if isinstance(preds, list):
49
  mask_logits = preds[-1]
50
  else:
 
73
  # Lens Blur Functions
74
  def run_depth_estimation(image: Image.Image, target_size=(512, 512)):
75
  """Runs the Depth-Anything-V2-Small model and returns the depth map."""
 
 
 
 
76
  image_resized = image.resize(target_size, resample=Image.BILINEAR)
77
+ inputs = depth_processor(images=image_resized, return_tensors="pt").to(device)
78
 
79
  with torch.no_grad():
80
+ outputs = depth_model(**inputs)
81
  predicted_depth = outputs.predicted_depth
82
 
83
  prediction = torch.nn.functional.interpolate(