Spaces: Running on Zero

update gpu
app.py CHANGED
@@ -3,7 +3,6 @@ from einops import rearrange
 import torch
 import torch.nn.functional as F
 from PIL import Image
-import torchvision.transforms as transforms
 from torch import nn
 import numpy as np
 import os
@@ -17,6 +16,15 @@ USE_CUDA = torch.cuda.is_available()

 print("CUDA is available:", USE_CUDA)

+def transform_images(images, resolution=(1024, 1024)):
+    images = [image.convert("RGB").resize(resolution) for image in images]
+    # Convert to torch tensor
+    images = [torch.tensor(np.array(image).transpose(2, 0, 1)).float() / 255 for image in images]
+    # Normalize
+    images = [(image - 0.5) / 0.5 for image in images]
+    images = torch.stack(images)
+    return images
+
 class MobileSAM(nn.Module):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -139,19 +147,12 @@ mobilesam = MobileSAM()

 def image_mobilesam_feature(
     images,
-    resolution=(1024, 1024),
     node_type="block",
     layer=-1,
 ):

-    transform = transforms.Compose(
-        [
-            transforms.Resize(resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ]
-    )
-
+    if USE_CUDA:
+        images = images.cuda()

     feat_extractor = mobilesam
     if USE_CUDA:
@@ -159,12 +160,9 @@ def image_mobilesam_feature(

     # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
-    for i
-        torch_image = transform(image)
-        if USE_CUDA:
-            torch_image = torch_image.cuda()
+    for i in range(images.shape[0]):
         attn_output, mlp_output, block_output = feat_extractor(
-            torch_image.unsqueeze(0)
+            images[i].unsqueeze(0)
         )
         out_dict = {
             "attn": attn_output,
@@ -251,18 +249,12 @@ sam = SAM()

 def image_sam_feature(
     images,
-    resolution=(1024, 1024),
     node_type="block",
     layer=-1,
 ):

-    transform = transforms.Compose(
-        [
-            transforms.Resize(resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ]
-    )
+    if USE_CUDA:
+        images = images.cuda()

     feat_extractor = sam
     if USE_CUDA:
@@ -270,12 +262,9 @@ def image_sam_feature(

     # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
-    for i
-        torch_image = transform(image)
-        if USE_CUDA:
-            torch_image = torch_image.cuda()
+    for i in range(images.shape[0]):
         attn_output, mlp_output, block_output = feat_extractor(
-            torch_image.unsqueeze(0)
+            images[i].unsqueeze(0)
         )
         out_dict = {
             "attn": attn_output,
@@ -338,27 +327,20 @@ class DiNOv2(torch.nn.Module):

 dinov2 = DiNOv2()

-def image_dino_feature(images,
+def image_dino_feature(images, node_type="block", layer=-1):

-    transform = transforms.Compose(
-        [
-            transforms.Resize(resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ]
-    )
+    if USE_CUDA:
+        images = images.cuda()

     feat_extractor = dinov2
     if USE_CUDA:
         feat_extractor = feat_extractor.cuda()

+    # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
-    for i
-        torch_image = transform(image)
-        if USE_CUDA:
-            torch_image = torch_image.cuda()
+    for i in range(images.shape[0]):
         attn_output, mlp_output, block_output = feat_extractor(
-            torch_image.unsqueeze(0)
+            images[i].unsqueeze(0)
         )
         out_dict = {
             "attn": attn_output,
@@ -443,33 +425,20 @@ class CLIP(torch.nn.Module):
 clip = CLIP()

 def image_clip_feature(
-    images,
+    images, node_type="block", layer=-1
 ):
-    if
-
-    else:
-        assert isinstance(images, Image.Image), "Input must be a PIL image."
-        images = [images]
-
-    transform = transforms.Compose(
-        [
-            transforms.Resize(resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ]
-    )
+    if USE_CUDA:
+        images = images.cuda()

     feat_extractor = clip
     if USE_CUDA:
         feat_extractor = feat_extractor.cuda()

+    # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
-    for i
-        torch_image = transform(image)
-        if USE_CUDA:
-            torch_image = torch_image.cuda()
+    for i in range(images.shape[0]):
         attn_output, mlp_output, block_output = feat_extractor(
-            torch_image.unsqueeze(0)
+            images[i].unsqueeze(0)
         )
         out_dict = {
             "attn": attn_output,
@@ -527,27 +496,35 @@ def compute_hash(*args, **kwargs):


 @spaces.GPU(duration=30)
-def run_model_on_image(
+def run_model_on_image(images, model_name="sam", node_type="block", layer=-1):
     global USE_CUDA
     USE_CUDA = True

     if model_name == "SAM(sam_vit_b)":
         if not USE_CUDA:
             gr.warning("GPU not detected. Running SAM on CPU, ~30s/image.")
-        result = image_sam_feature(
+        result = image_sam_feature(images, node_type=node_type, layer=layer)
     elif model_name == 'MobileSAM':
-        result = image_mobilesam_feature(
+        result = image_mobilesam_feature(images, node_type=node_type, layer=layer)
     elif model_name == "DiNO(dinov2_vitb14_reg)":
-        result = image_dino_feature(
+        result = image_dino_feature(images, node_type=node_type, layer=layer)
     elif model_name == "CLIP(openai/clip-vit-base-patch16)":
-        result = image_clip_feature(
+        result = image_clip_feature(images, node_type=node_type, layer=layer)
     else:
         raise ValueError(f"Model {model_name} not supported.")

     USE_CUDA = False
     return result

-def extract_features(images, model_name="sam", node_type="block", layer=-1):
+def extract_features(images, model_name="mobilesam", node_type="block", layer=-1):
+    resolution_dict = {
+        "mobilesam": (1024, 1024),
+        "sam(sam_vit_b)": (1024, 1024),
+        "dinov2(dinov2_vitb14_reg)": (448, 448),
+        "clip(openai/clip-vit-base-patch16)": (224, 224),
+    }
+    images = transform_images(images, resolution=resolution_dict[model_name])
+
     # Compute the cache key
     cache_key = compute_hash(images, model_name, node_type, layer)

@@ -556,7 +533,7 @@ def extract_features(images, model_name="sam", node_type="block", layer=-1):
         print("Cache hit!")
         return cache[cache_key]

-    result = run_model_on_image(images
+    result = run_model_on_image(images, model_name=model_name, node_type=node_type, layer=layer)

     # Store the result in the cache
     cache[cache_key] = result
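For reference, below is a minimal, self-contained sketch of the preprocessing path this commit introduces. The `transform_images` body is reproduced from the diff above; the dummy PIL images, the chosen resolution, and the printed checks are illustrative assumptions, not part of app.py.

```python
# Sketch of the preprocessing helper added in this commit (function body as in
# the diff above). The random images stand in for real Gradio uploads.
import numpy as np
import torch
from PIL import Image


def transform_images(images, resolution=(1024, 1024)):
    images = [image.convert("RGB").resize(resolution) for image in images]
    # Convert to torch tensor
    images = [torch.tensor(np.array(image).transpose(2, 0, 1)).float() / 255 for image in images]
    # Normalize
    images = [(image - 0.5) / 0.5 for image in images]
    images = torch.stack(images)
    return images


# Two dummy RGB images standing in for real uploads.
dummy = [
    Image.fromarray(np.random.randint(0, 256, (600, 800, 3), dtype=np.uint8))
    for _ in range(2)
]
batch = transform_images(dummy, resolution=(1024, 1024))
print(batch.shape)               # torch.Size([2, 3, 1024, 1024])
print(batch.min(), batch.max())  # values lie in [-1, 1] after the 0.5/0.5 normalization
```

In the updated app.py, `extract_features` calls this helper once, using the per-model sizes in `resolution_dict`, before handing the batch to the `@spaces.GPU`-decorated `run_model_on_image`; the per-model feature functions then only move the ready-made tensor to CUDA and feed `images[i].unsqueeze(0)` to the extractor, instead of running a torchvision transform per image.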