Update app.py
Browse files
app.py
CHANGED
@@ -19,6 +19,11 @@ from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Auto
|
|
19 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
20 |
torch.set_float32_matmul_precision('high')
|
21 |
|
|
|
|
|
|
|
|
|
|
|
22 |
def load_image_from_link(url: str) -> Image.Image:
|
23 |
"""Downloads an image from a URL and returns a Pillow Image."""
|
24 |
response = requests.get(url)
|
@@ -29,10 +34,6 @@ def load_image_from_link(url: str) -> Image.Image:
|
|
29 |
# Gaussian Blur Functions
|
30 |
def run_rmbg(image: Image.Image, threshold=0.5):
|
31 |
"""Runs the RMBG-2.0 model on the image and returns a binary mask."""
|
32 |
-
model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
|
33 |
-
model.to(device)
|
34 |
-
model.eval()
|
35 |
-
|
36 |
image_size = (1024, 1024)
|
37 |
transform_image = transforms.Compose([
|
38 |
transforms.Resize(image_size),
|
@@ -43,7 +44,7 @@ def run_rmbg(image: Image.Image, threshold=0.5):
|
|
43 |
input_images = transform_image(image).unsqueeze(0).to(device)
|
44 |
|
45 |
with torch.no_grad():
|
46 |
-
preds = model(input_images)
|
47 |
if isinstance(preds, list):
|
48 |
mask_logits = preds[-1]
|
49 |
else:
|
@@ -72,15 +73,11 @@ def apply_background_blur(image: Image.Image, mask: np.ndarray, sigma: float = 1
|
|
72 |
# Lens Blur Functions
|
73 |
def run_depth_estimation(image: Image.Image, target_size=(512, 512)):
|
74 |
"""Runs the Depth-Anything-V2-Small model and returns the depth map."""
|
75 |
-
image_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
|
76 |
-
model = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
|
77 |
-
model.to(device)
|
78 |
-
|
79 |
image_resized = image.resize(target_size, resample=Image.BILINEAR)
|
80 |
-
inputs = image_processor(images=image_resized, return_tensors="pt").to(device)
|
81 |
|
82 |
with torch.no_grad():
|
83 |
-
outputs = model(**inputs)
|
84 |
predicted_depth = outputs.predicted_depth
|
85 |
|
86 |
prediction = torch.nn.functional.interpolate(
|
|
|
19 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
20 |
torch.set_float32_matmul_precision('high')
|
21 |
|
22 |
+
# Load models at startup
|
23 |
+
rmbg_model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True).to(device).eval()
|
24 |
+
depth_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
|
25 |
+
depth_model = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf").to(device)
|
26 |
+
|
27 |
def load_image_from_link(url: str) -> Image.Image:
|
28 |
"""Downloads an image from a URL and returns a Pillow Image."""
|
29 |
response = requests.get(url)
|
|
|
34 |
# Gaussian Blur Functions
|
35 |
def run_rmbg(image: Image.Image, threshold=0.5):
|
36 |
"""Runs the RMBG-2.0 model on the image and returns a binary mask."""
|
|
|
|
|
|
|
|
|
37 |
image_size = (1024, 1024)
|
38 |
transform_image = transforms.Compose([
|
39 |
transforms.Resize(image_size),
|
|
|
44 |
input_images = transform_image(image).unsqueeze(0).to(device)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
+
preds = rmbg_model(input_images)
|
48 |
if isinstance(preds, list):
|
49 |
mask_logits = preds[-1]
|
50 |
else:
|
|
|
73 |
# Lens Blur Functions
|
74 |
def run_depth_estimation(image: Image.Image, target_size=(512, 512)):
|
75 |
"""Runs the Depth-Anything-V2-Small model and returns the depth map."""
|
|
|
|
|
|
|
|
|
76 |
image_resized = image.resize(target_size, resample=Image.BILINEAR)
|
77 |
+
inputs = depth_processor(images=image_resized, return_tensors="pt").to(device)
|
78 |
|
79 |
with torch.no_grad():
|
80 |
+
outputs = depth_model(**inputs)
|
81 |
predicted_depth = outputs.predicted_depth
|
82 |
|
83 |
prediction = torch.nn.functional.interpolate(
|