huzey committed
Commit e3b132f
Parent: 07462e7

update app.py

Files changed (1)
  1. app.py +18 -8
app.py CHANGED
@@ -10,6 +10,10 @@ import time
 
 import gradio as gr
 
+use_cuda = torch.cuda.is_available()
+
+print("CUDA is available:", use_cuda)
+
 
 class SAM(torch.nn.Module):
     def __init__(self, checkpoint="/data/sam_model/sam_vit_b_01ec64.pth", **kwargs):
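This hunk introduces the module-level use_cuda flag that every change below keys off. A minimal sketch of the pattern in isolation, assuming only that torch is importable (pick_device is an illustrative helper, not part of app.py):

    import torch

    # True only when a CUDA driver and at least one GPU are visible to PyTorch.
    use_cuda = torch.cuda.is_available()
    print("CUDA is available:", use_cuda)

    def pick_device() -> torch.device:
        # Fall back to CPU so the same script runs on GPU-less hosts.
        return torch.device("cuda" if use_cuda else "cpu")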
@@ -50,7 +54,8 @@ class SAM(torch.nn.Module):
 
         self.image_encoder = sam.image_encoder
         self.image_encoder.eval()
-        # self.image_encoder = self.image_encoder.cuda()
+        if use_cuda:
+            self.image_encoder = self.image_encoder.cuda()
 
     @torch.no_grad()
     def forward(self, x: torch.Tensor) -> torch.Tensor:
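The guard moves the encoder to the GPU only when one exists, so the commented-out .cuda() call no longer has to be toggled by hand. An equivalent design is to compute a torch.device once and call .to(device), which behaves identically on CPU-only hosts; a hedged sketch with a stand-in module (Encoder is not the SAM encoder):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class Encoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(16, 16)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)

    # .to(device) moves module parameters in place; a no-op on CPU-only machines.
    encoder = Encoder().eval().to(device)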
@@ -98,8 +103,9 @@ def image_sam_feature(
     outputs = []
     for i, image in enumerate(images):
         torch_image = transform(image)
+        if use_cuda:
+            torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
-            # torch_image.unsqueeze(0).cuda()
             torch_image.unsqueeze(0)
         )
         out_dict = {
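Inputs must live on the same device as the model's parameters; once the encoder sits on the GPU, feeding it a CPU tensor raises a RuntimeError about mixed devices, which is why the loop now moves torch_image before the call. The same two-line guard is repeated in the DINOv2 and CLIP loops below. A sketch of the pitfall and the fix (model and x are placeholders):

    import torch

    use_cuda = torch.cuda.is_available()
    model = torch.nn.Linear(4, 4)
    if use_cuda:
        model = model.cuda()

    x = torch.randn(1, 4)  # always created on CPU
    if use_cuda:
        x = x.cuda()       # without this line, model(x) fails on GPU hosts
    y = model(x)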
@@ -120,7 +126,8 @@ class DiNOv2(torch.nn.Module):
         self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
         self.dinov2.requires_grad_(False)
         self.dinov2.eval()
-        # self.dinov2 = self.dinov2.cuda()
+        if use_cuda:
+            self.dinov2 = self.dinov2.cuda()
 
     def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
         def attn_residual_func(x):
@@ -173,8 +180,9 @@ def image_dino_feature(images, resolution=(448, 448), node_type="block", layer=-
     outputs = []
     for i, image in enumerate(images):
         torch_image = transform(image)
+        if use_cuda:
+            torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
-            # torch_image.unsqueeze(0).cuda()
             torch_image.unsqueeze(0)
         )
         out_dict = {
@@ -199,8 +207,9 @@ class CLIP(torch.nn.Module):
         model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
         # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
         self.model = model.eval()
-        # self.model = self.model.cuda()
-
+        if use_cuda:
+            self.model = self.model.cuda()
+
     def new_forward(
         self,
         hidden_states: torch.Tensor,
@@ -277,8 +286,9 @@ def image_clip_feature(
     outputs = []
    for i, image in enumerate(images):
         torch_image = transform(image)
+        if use_cuda:
+            torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
-            # torch_image.unsqueeze(0).cuda()
             torch_image.unsqueeze(0)
         )
         out_dict = {
@@ -321,7 +331,7 @@ def compute_ncut(
     eigvecs, eigvals = NCUT(
         num_eig=num_eig,
         num_sample=num_sample_ncut,
-        # device="cuda:0",
+        device="cuda" if use_cuda else "cpu",
         affinity_focal_gamma=affinity_focal_gamma,
         knn=knn_ncut,
     ).fit_transform(features.reshape(-1, features.shape[-1]))
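The previously commented-out value was "cuda:0"; the new code passes "cuda", which PyTorch resolves to the current CUDA device (device 0 unless changed), so the two are equivalent on a single-GPU host. Assuming the NCUT constructor forwards this string to torch, the selection logic looks like:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    t = torch.zeros(2, 2, device=device)  # lands on cuda:0 when a GPU is present
    print(t.device)                       # "cuda:0" or "cpu"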
 