initial commit
- README.md +44 -13
- app.py +470 -0
- images/image_0.jpg +0 -0
- images/image_1.jpg +0 -0
- images/image_2.jpg +0 -0
- images/image_3.jpg +0 -0
- images/image_4.jpg +0 -0
- images/image_5.jpg +0 -0
- images/ncut_0.jpg +0 -0
- images/ncut_1.jpg +0 -0
- images/ncut_2.jpg +0 -0
- images/ncut_3.jpg +0 -0
- images/ncut_4.jpg +0 -0
- images/ncut_5.jpg +0 -0
- requirements.txt +3 -0
README.md
CHANGED
@@ -1,13 +1,44 @@

Documentation: [https://ncut-pytorch.readthedocs.io/](https://ncut-pytorch.readthedocs.io/)

## NCUT: Nyström Normalized Cut

**Normalized Cut**, a.k.a. spectral clustering, is a graph-based method for analyzing data grouping in the affinity eigenvector space. It was widely used for unsupervised segmentation in the 2000s.

**Nyström Normalized Cut** is a new approximation algorithm developed for large-scale graph cuts: a graph of one million nodes can be processed in under 10 s (CPU) or 2 s (GPU).

## Gallery

TODO

## Installation

Install from PyPI. Our package is based on [PyTorch](https://pytorch.org/get-started/locally/), presuming you already have PyTorch installed:

```shell
pip install ncut-pytorch
```

[Install PyTorch](https://pytorch.org/get-started/locally/) if you haven't:

```shell
pip install torch
```

## Why NCUT

Normalized cut offers two advantages:

1. soft-cluster assignments as eigenvectors
2. hierarchical clustering by varying the number of eigenvectors

Please see [NCUT and t-SNE/UMAP](compare.md) for a full comparison.
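
A minimal usage sketch, based on how `app.py` in this commit calls the library (the `NCUT` and `rgb_from_tsne_3d` arguments mirror the demo; `features` stands in for any `(n_samples, n_features)` tensor, e.g. flattened backbone features):

```python
import torch
from ncut_pytorch import NCUT, rgb_from_tsne_3d

features = torch.rand(6 * 64 * 64, 768)  # e.g. 6 images of 64x64 patch features

# soft-cluster assignments: one eigenvector per column
eigvecs, eigvals = NCUT(
    num_eig=100, num_sample=10000, affinity_focal_gamma=0.3, knn=10
).fit_transform(features)

# fewer eigenvectors give whole objects, more give object parts,
# e.g. slice eigvecs[:, :20] for a coarser grouping

# map eigenvectors to RGB colors via 3D t-SNE for visualization
X_3d, rgb = rgb_from_tsne_3d(eigvecs, num_sample=1000, perplexity=500, knn=10)
```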

> paper in prep, Yang 2024
>
> AlignedCut: Visual Concepts Discovery on Brain-Guided Universal Feature Space, Huzheng Yang, James Gee\*, Jianbo Shi\*, 2024
>
> Normalized Cuts and Image Segmentation, Jianbo Shi and Jitendra Malik, 2000
app.py
ADDED
@@ -0,0 +1,470 @@
# %%
import os  # used for checkpoint caching and for saving outputs below
from typing import Optional, Tuple

from einops import rearrange
import torch
from PIL import Image
import torchvision.transforms as transforms
from torch import nn
import numpy as np

import gradio as gr

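
# Feature extractors: each backbone's transformer block `forward` is replaced with a
# patched version that stores the attention output, MLP output, and block output on
# the block module itself, so the per-layer features can be collected after a single
# forward pass. The same patching pattern is used for SAM, DiNOv2, and CLIP below.
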
class SAM(torch.nn.Module):
    def __init__(self, checkpoint="/data/sam_model/sam_vit_b_01ec64.pth", **kwargs):
        super().__init__(**kwargs)
        from segment_anything import sam_model_registry, SamPredictor
        from segment_anything.modeling.sam import Sam

        sam: Sam = sam_model_registry["vit_b"](checkpoint=checkpoint)

        from segment_anything.modeling.image_encoder import (
            window_partition,
            window_unpartition,
        )

        def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
            shortcut = x
            x = self.norm1(x)
            # Window partition
            if self.window_size > 0:
                H, W = x.shape[1], x.shape[2]
                x, pad_hw = window_partition(x, self.window_size)

            x = self.attn(x)
            # Reverse window partition
            if self.window_size > 0:
                x = window_unpartition(x, self.window_size, pad_hw, (H, W))
            self.attn_output = x.clone()

            x = shortcut + x
            mlp_output = self.mlp(self.norm2(x))
            self.mlp_output = mlp_output.clone()
            x = x + mlp_output
            self.block_output = x.clone()

            return x

        setattr(sam.image_encoder.blocks[0].__class__, "forward", new_block_forward)

        self.image_encoder = sam.image_encoder
        self.image_encoder.eval()
        # self.image_encoder = self.image_encoder.cuda()

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
            out = self.image_encoder(x)

            attn_outputs, mlp_outputs, block_outputs = [], [], []
            for i, blk in enumerate(self.image_encoder.blocks):
                attn_outputs.append(blk.attn_output)
                mlp_outputs.append(blk.mlp_output)
                block_outputs.append(blk.block_output)
            attn_outputs = torch.stack(attn_outputs)
            mlp_outputs = torch.stack(mlp_outputs)
            block_outputs = torch.stack(block_outputs)
            return attn_outputs, mlp_outputs, block_outputs


def image_sam_feature(
    images,
    resolution=(1024, 1024),
    node_type="block",
    layer=-1,
):

    transform = transforms.Compose(
        [
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    checkpoint = "sam_vit_b_01ec64.pth"
    if not os.path.exists(checkpoint):
        checkpoint_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
        import requests

        r = requests.get(checkpoint_url)
        with open(checkpoint, 'wb') as f:
            f.write(r.content)

    feat_extractor = SAM(checkpoint=checkpoint)

    outputs = []
    for i, image in enumerate(images):
        torch_image = transform(image)
        attn_output, mlp_output, block_output = feat_extractor(
            # torch_image.unsqueeze(0).cuda()
            torch_image.unsqueeze(0)
        )
        out_dict = {
            "attn": attn_output,
            "mlp": mlp_output,
            "block": block_output,
        }
        out = out_dict[node_type]
        out = out[layer]
        outputs.append(out.cpu())
    outputs = torch.cat(outputs, dim=0)
    return outputs

class DiNOv2(torch.nn.Module):
    def __init__(self, ver="dinov2_vitb14_reg"):
        super().__init__()
        self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
        self.dinov2.requires_grad_(False)
        self.dinov2.eval()
        # self.dinov2 = self.dinov2.cuda()

        def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
            def attn_residual_func(x):
                return self.ls1(self.attn(self.norm1(x)))

            def ffn_residual_func(x):
                return self.ls2(self.mlp(self.norm2(x)))

            attn_output = attn_residual_func(x)
            self.attn_output = attn_output.clone()
            x = x + attn_output
            mlp_output = ffn_residual_func(x)
            self.mlp_output = mlp_output.clone()
            x = x + mlp_output
            block_output = x
            self.block_output = block_output.clone()
            return x

        setattr(self.dinov2.blocks[0].__class__, "forward", new_block_forward)

    @torch.no_grad()
    def forward(self, x):

        out = self.dinov2(x)

        attn_outputs, mlp_outputs, block_outputs = [], [], []
        for i, blk in enumerate(self.dinov2.blocks):
            attn_outputs.append(blk.attn_output)
            mlp_outputs.append(blk.mlp_output)
            block_outputs.append(blk.block_output)

        attn_outputs = torch.stack(attn_outputs)
        mlp_outputs = torch.stack(mlp_outputs)
        block_outputs = torch.stack(block_outputs)
        return attn_outputs, mlp_outputs, block_outputs


def image_dino_feature(images, resolution=(448, 448), node_type="block", layer=-1):

    transform = transforms.Compose(
        [
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    feat_extractor = DiNOv2()

    outputs = []
    for i, image in enumerate(images):
        torch_image = transform(image)
        attn_output, mlp_output, block_output = feat_extractor(
            # torch_image.unsqueeze(0).cuda()
            torch_image.unsqueeze(0)
        )
        out_dict = {
            "attn": attn_output,
            "mlp": mlp_output,
            "block": block_output,
        }
        out = out_dict[node_type]
        out = out[layer]
        outputs.append(out.cpu())
    outputs = torch.cat(outputs, dim=0)
    # drop the CLS token and 4 register tokens, reshape to (batch, height, width, channels)
    outputs = rearrange(outputs[:, 5:, :], "b (h w) c -> b h w c", h=32, w=32)
    return outputs


class CLIP(torch.nn.Module):
    def __init__(self):
        super().__init__()

        from transformers import CLIPProcessor, CLIPModel

        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        self.model = model.eval()
        # self.model = self.model.cuda()

        def new_forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: torch.Tensor,
            causal_attention_mask: torch.Tensor,
            output_attentions: Optional[bool] = False,
        ) -> Tuple[torch.FloatTensor]:

            residual = hidden_states

            hidden_states = self.layer_norm1(hidden_states)
            hidden_states, attn_weights = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
            )
            self.attn_output = hidden_states.clone()
            hidden_states = residual + hidden_states

            residual = hidden_states
            hidden_states = self.layer_norm2(hidden_states)
            hidden_states = self.mlp(hidden_states)
            self.mlp_output = hidden_states.clone()

            hidden_states = residual + hidden_states

            outputs = (hidden_states,)

            if output_attentions:
                outputs += (attn_weights,)

            self.block_output = hidden_states.clone()
            return outputs

        setattr(self.model.vision_model.encoder.layers[0].__class__, "forward", new_forward)

    @torch.no_grad()
    def forward(self, x):

        out = self.model.vision_model(x)

        attn_outputs, mlp_outputs, block_outputs = [], [], []
        for i, blk in enumerate(self.model.vision_model.encoder.layers):
            attn_outputs.append(blk.attn_output)
            mlp_outputs.append(blk.mlp_output)
            block_outputs.append(blk.block_output)

        attn_outputs = torch.stack(attn_outputs)
        mlp_outputs = torch.stack(mlp_outputs)
        block_outputs = torch.stack(block_outputs)
        return attn_outputs, mlp_outputs, block_outputs


def image_clip_feature(
    images, resolution=(224, 224), node_type="block", layer=-1
):
    if isinstance(images, list):
        assert isinstance(images[0], Image.Image), "Input must be a list of PIL images."
    else:
        assert isinstance(images, Image.Image), "Input must be a PIL image."
        images = [images]

    transform = transforms.Compose(
        [
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    feat_extractor = CLIP()

    outputs = []
    for i, image in enumerate(images):
        torch_image = transform(image)
        attn_output, mlp_output, block_output = feat_extractor(
            # torch_image.unsqueeze(0).cuda()
            torch_image.unsqueeze(0)
        )
        out_dict = {
            "attn": attn_output,
            "mlp": mlp_output,
            "block": block_output,
        }
        out = out_dict[node_type]
        out = out[layer]
        outputs.append(out.cpu())
    outputs = torch.cat(outputs, dim=0)
    return outputs


def extract_features(images, model_name="sam", node_type="block", layer=-1):
    if model_name == "SAM(sam_vit_b)":
        return image_sam_feature(images, node_type=node_type, layer=layer)
    elif model_name == "DiNO(dinov2_vitb14_reg)":
        return image_dino_feature(images, node_type=node_type, layer=layer)
    elif model_name == "CLIP(openai/clip-vit-base-patch16)":
        return image_clip_feature(images, node_type=node_type, layer=layer)
    else:
        raise ValueError(f"Model {model_name} not supported.")

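
# NCUT pipeline: flatten the selected features to (n_pixels, n_channels), compute the
# Nyström Normalized Cut eigenvectors, then map the eigenvectors to RGB via 3D t-SNE,
# so every pixel gets a color that reflects its soft cluster assignment.
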
def compute_ncut(
    features,
    num_eig=100,
    num_sample_ncut=10000,
    affinity_focal_gamma=0.3,
    knn_ncut=10,
    knn_tsne=10,
    num_sample_tsne=1000,
    perplexity=500,
):
    from ncut_pytorch import NCUT, rgb_from_tsne_3d

    eigvecs, eigvals = NCUT(
        num_eig=num_eig,
        num_sample=num_sample_ncut,
        # device="cuda:0",
        affinity_focal_gamma=affinity_focal_gamma,
        knn=knn_ncut,
    ).fit_transform(features.reshape(-1, features.shape[-1]))
    X_3d, rgb = rgb_from_tsne_3d(
        eigvecs,
        num_sample=num_sample_tsne,
        perplexity=perplexity,
        knn=knn_tsne,
    )
    rgb = rgb.reshape(features.shape[:3] + (3,))
    return rgb


def dont_use_too_much_green(image_rgb):
    # make the central 40% of the image lead with red:
    # sort the color channels by their mean value in the center crop
    x1, x2 = int(image_rgb.shape[1] * 0.3), int(image_rgb.shape[1] * 0.7)
    y1, y2 = int(image_rgb.shape[2] * 0.3), int(image_rgb.shape[2] * 0.7)
    sum_values = image_rgb[:, x1:x2, y1:y2].mean((0, 1, 2))
    sorted_indices = sum_values.argsort(descending=True)
    image_rgb = image_rgb[:, :, :, sorted_indices]
    return image_rgb


def to_pil_images(images):
    return [
        Image.fromarray((image * 255).cpu().numpy().astype(np.uint8)).resize((256, 256), Image.NEAREST)
        for image in images
    ]


def main_fn(
    images,
    model_name="SAM(sam_vit_b)",
    node_type="block",
    layer=-1,
    num_eig=100,
    affinity_focal_gamma=0.3,
    num_sample_ncut=10000,
    knn_ncut=10,
    num_sample_tsne=1000,
    knn_tsne=10,
    perplexity=500,
):
    if perplexity >= num_sample_tsne:
        # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
        gr.Warning(
            "Perplexity must be less than the number of samples for t-SNE.\n"
            f"Setting perplexity to {num_sample_tsne-1}."
        )
        perplexity = num_sample_tsne - 1

    images = [image[0] for image in images]
    features = extract_features(
        images, model_name=model_name, node_type=node_type, layer=layer
    )
    rgb = compute_ncut(
        features,
        num_eig=num_eig,
        num_sample_ncut=num_sample_ncut,
        affinity_focal_gamma=affinity_focal_gamma,
        knn_ncut=knn_ncut,
        knn_tsne=knn_tsne,
        num_sample_tsne=num_sample_tsne,
        perplexity=perplexity,
    )
    rgb = dont_use_too_much_green(rgb)
    return to_pil_images(rgb)


default_images = ['/workspace/output/gradio/image_0.jpg', '/workspace/output/gradio/image_1.jpg', '/workspace/output/gradio/image_2.jpg', '/workspace/output/gradio/image_3.jpg', '/workspace/output/gradio/image_4.jpg', '/workspace/output/gradio/image_5.jpg']
default_outputs = ['/workspace/output/gradio/ncut_0.jpg', '/workspace/output/gradio/ncut_1.jpg', '/workspace/output/gradio/ncut_2.jpg', '/workspace/output/gradio/ncut_3.jpg', '/workspace/output/gradio/ncut_4.jpg', '/workspace/output/gradio/ncut_5.jpg']

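
# Gradio UI: the input gallery feeds PIL images to main_fn; model, node type, layer,
# number of eigenvectors, and affinity focal gamma are the main controls, with the
# Nyström / t-SNE sampling parameters exposed as additional inputs.
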
demo = gr.Interface(
    main_fn,
    [
        gr.Gallery(value=default_images, label="Select images", show_label=False, elem_id="images", columns=[3], rows=[1], object_fit="contain", height="auto", type="pil"),
        gr.Dropdown(["SAM(sam_vit_b)", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)"], label="Model", value="SAM(sam_vit_b)", elem_id="model_name"),
        gr.Dropdown(["attn", "mlp", "block"], label="Node type", value="block", elem_id="node_type", info="attn: attention output, mlp: MLP output, block: sum of the residual stream"),
        gr.Slider(0, 11, step=1, label="Layer", value=11, elem_id="layer", info="which layer of the image backbone features"),
        gr.Slider(1, 1000, step=1, label="Number of eigenvectors", value=100, elem_id="num_eig", info="increase for more object parts, decrease for whole objects"),
        gr.Slider(0.01, 1, step=0.01, label="Affinity focal gamma", value=0.3, elem_id="affinity_focal_gamma", info="decrease for more aggressive cleaning of the affinity matrix"),
    ],
    gr.Gallery(value=default_outputs, label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto"),
    additional_inputs=[
        gr.Slider(100, 30000, step=100, label="num_sample (NCUT)", value=10000, elem_id="num_sample_ncut", info="for Nyström approximation"),
        gr.Slider(1, 100, step=1, label="KNN (NCUT)", value=10, elem_id="knn_ncut", info="for Nyström approximation"),
        gr.Slider(100, 10000, step=100, label="num_sample (t-SNE)", value=1000, elem_id="num_sample_tsne", info="for Nyström approximation; increasing slows down t-SNE quite a lot"),
        gr.Slider(1, 100, step=1, label="KNN (t-SNE)", value=10, elem_id="knn_tsne", info="for Nyström approximation"),
        gr.Slider(10, 1000, step=10, label="Perplexity (t-SNE)", value=500, elem_id="perplexity", info="for t-SNE"),
    ]
)

demo.launch(share=True)

# %%


# # %%
# from ncut_pytorch import NCUT, rgb_from_tsne_3d

# i_layer = -1
# inp = block_outputs[i_layer]
# eigvecs, eigvals = NCUT(
#     num_eig=1000, num_sample=10000, device="cuda:0", affinity_focal_gamma=0.3, knn=10
# ).fit_transform(inp.reshape(-1, inp.shape[-1]))
# print(eigvecs.shape, eigvals.shape)
# # %%
# X_3d, rgb = rgb_from_tsne_3d(
#     eigvecs[:, :100], num_sample=1000, perplexity=500, knn=10, seed=42
# )
# # %%
# image_rgb = rgb.reshape(*inp.shape[:-1], 3)
# # make the central 20% of the image lead with red
# x1, x2 = int(image_rgb.shape[1] * 0.4), int(image_rgb.shape[1] * 0.6)
# y1, y2 = int(image_rgb.shape[2] * 0.4), int(image_rgb.shape[2] * 0.6)
# sum_values = image_rgb[:, x1:x2, y1:y2].mean((0, 1, 2))
# sorted_indices = sum_values.argsort(descending=True)
# image_rgb = image_rgb[:, :, :, sorted_indices]

# import matplotlib.pyplot as plt

# fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# for i, ax in enumerate(axes.flat):
#     ax.imshow(image_rgb[i])
#     ax.axis("off")

# %%
save_dir = "/workspace/output/gradio"

os.makedirs(save_dir, exist_ok=True)

images = ['/workspace/guitars/lespual1.png', '/workspace/guitars/lespual2.png', '/workspace/guitars/lespual3.png', '/workspace/guitars/lespual4.png', '/workspace/guitars/lespual5.png', '/workspace/guitars/acoustic1.png']
images = [Image.open(image).convert("RGB") for image in images]
for i, image in enumerate(images):
    image = image.resize((512, 512))
    image.save(os.path.join(save_dir, f"image_{i}.jpg"), "JPEG", quality=70)
# %%
images = [(image, '') for image in images]
image_rgb = main_fn(images)
# %%
for i, rgb in enumerate(image_rgb):
    rgb = rgb.resize((512, 512), Image.NEAREST)
    rgb.save(os.path.join(save_dir, f"ncut_{i}.jpg"), "JPEG", quality=70)
# %%
# main_fn already returns PIL images, so they can be saved as PNG directly
for i, rgb in enumerate(image_rgb):
    rgb.save(os.path.join(save_dir, f"ncut_{i}.png"))
# %%
images/image_0.jpg
ADDED
images/image_1.jpg
ADDED
images/image_2.jpg
ADDED
images/image_3.jpg
ADDED
images/image_4.jpg
ADDED
images/image_5.jpg
ADDED
images/ncut_0.jpg
ADDED
images/ncut_1.jpg
ADDED
images/ncut_2.jpg
ADDED
images/ncut_3.jpg
ADDED
images/ncut_4.jpg
ADDED
images/ncut_5.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,3 @@
ncut-pytorch
transformers
segment-anything @ git+https://github.com/facebookresearch/segment-anything.git