feat: add segmentation
Browse files- app.py +25 -14
- lib/utils/model.py +53 -4
- pages/losses.py +42 -33
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
-
from st_pages import Page, show_pages, add_page_title, Section
|
3 |
-
from lib.utils.model import get_model, get_similarities
|
4 |
-
from lib.utils.timer import timer
|
5 |
|
6 |
add_page_title()
|
7 |
|
@@ -23,23 +23,31 @@ caption = st.text_input('Description Input')
|
|
23 |
|
24 |
images = st.file_uploader('Upload images', accept_multiple_files=True)
|
25 |
if images is not None:
|
26 |
-
|
27 |
-
st.image(images)
|
28 |
|
29 |
st.header('Options')
|
30 |
st.subheader('Ranks', help='How many predictions the model is allowed to make')
|
31 |
|
32 |
-
ranks = st.slider('slider_ranks', min_value=1, max_value=10,
|
|
|
|
|
|
|
|
|
33 |
|
34 |
-
button = st.button('Match most similar', disabled=len(images) == 0 or caption == '')
|
35 |
|
36 |
if button:
|
|
|
|
|
|
|
|
|
37 |
st.header('Results')
|
38 |
with st.spinner('Loading model'):
|
39 |
model = get_model()
|
40 |
|
41 |
-
st.text(
|
42 |
-
|
|
|
43 |
time = timer()
|
44 |
with st.spinner('Computing and ranking similarities'):
|
45 |
with timer() as t:
|
@@ -47,15 +55,16 @@ if button:
|
|
47 |
elapsed = t()
|
48 |
|
49 |
indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]
|
50 |
-
|
51 |
c1, c2, c3 = st.columns(3)
|
52 |
with c1:
|
53 |
st.subheader('Rank')
|
54 |
with c2:
|
55 |
st.subheader('Image')
|
56 |
with c3:
|
57 |
-
st.subheader('Cosine Similarity',
|
58 |
-
|
|
|
59 |
for i, idx in enumerate(indices):
|
60 |
c1, c2, c3 = st.columns(3)
|
61 |
with c1:
|
@@ -72,5 +81,7 @@ with st.sidebar:
|
|
72 |
|
73 |
st.subheader('Useful Links')
|
74 |
st.markdown('[arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)')
|
75 |
-
st.markdown(
|
76 |
-
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from st_pages import Page, show_pages, add_page_title, Section
|
3 |
+
from lib.utils.model import get_model, get_similarities, get_detr, segment_images
|
4 |
+
from lib.utils.timer import timer
|
5 |
|
6 |
add_page_title()
|
7 |
|
|
|
23 |
|
24 |
images = st.file_uploader('Upload images', accept_multiple_files=True)
|
25 |
if images is not None:
|
26 |
+
|
27 |
+
st.image(images) # type: ignore
|
28 |
|
29 |
st.header('Options')
|
30 |
st.subheader('Ranks', help='How many predictions the model is allowed to make')
|
31 |
|
32 |
+
ranks = st.slider('slider_ranks', min_value=1, max_value=10,
|
33 |
+
label_visibility='collapsed', value=5)
|
34 |
+
do_segment = st.checkbox('Segment images with DETR', value=False)
|
35 |
+
button = st.button('Match most similar', disabled=len(
|
36 |
+
images) == 0 or caption == '')
|
37 |
|
|
|
38 |
|
39 |
if button:
|
40 |
+
if do_segment:
|
41 |
+
detr, processor = get_detr()
|
42 |
+
images = segment_images(detr, processor, images)
|
43 |
+
|
44 |
st.header('Results')
|
45 |
with st.spinner('Loading model'):
|
46 |
model = get_model()
|
47 |
|
48 |
+
st.text(
|
49 |
+
f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
|
50 |
+
|
51 |
time = timer()
|
52 |
with st.spinner('Computing and ranking similarities'):
|
53 |
with timer() as t:
|
|
|
55 |
elapsed = t()
|
56 |
|
57 |
indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]
|
58 |
+
|
59 |
c1, c2, c3 = st.columns(3)
|
60 |
with c1:
|
61 |
st.subheader('Rank')
|
62 |
with c2:
|
63 |
st.subheader('Image')
|
64 |
with c3:
|
65 |
+
st.subheader('Cosine Similarity',
|
66 |
+
help='Due to the nature of the SDM loss, the higher the similarity, the more similar the match is')
|
67 |
+
|
68 |
for i, idx in enumerate(indices):
|
69 |
c1, c2, c3 = st.columns(3)
|
70 |
with c1:
|
|
|
81 |
|
82 |
st.subheader('Useful Links')
|
83 |
st.markdown('[arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)')
|
84 |
+
st.markdown(
|
85 |
+
'[IRRA implementation (Pytorch Lightning + Transformers)](https://github.com/grostaco/modern-IRRA)')
|
86 |
+
st.markdown(
|
87 |
+
'[IRRA implementation (PyTorch)](https://github.com/anosorae/IRRA/tree/main)')
|
lib/utils/model.py
CHANGED
@@ -1,15 +1,20 @@
|
|
1 |
-
import streamlit as st
|
2 |
import yaml
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
5 |
|
|
|
|
|
6 |
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
|
7 |
from lib.IRRA.image import prepare_images
|
8 |
from lib.IRRA.model.build import build_model, IRRA
|
|
|
|
|
9 |
|
10 |
from easydict import EasyDict
|
11 |
|
12 |
-
|
|
|
13 |
def get_model():
|
14 |
args = yaml.load(open('model/configs.yaml'), Loader=yaml.FullLoader)
|
15 |
args = EasyDict(args)
|
@@ -17,7 +22,51 @@ def get_model():
|
|
17 |
|
18 |
model = build_model(args)
|
19 |
|
20 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
|
23 |
tokenizer = SimpleTokenizer()
|
@@ -30,5 +79,5 @@ def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
|
|
30 |
|
31 |
image_feats = F.normalize(image_feats, p=2, dim=1)
|
32 |
text_feats = F.normalize(text_feats, p=2, dim=1)
|
33 |
-
|
34 |
return text_feats @ image_feats.t()
|
|
|
1 |
+
import streamlit as st
|
2 |
import yaml
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
5 |
|
6 |
+
from transformers import DetrImageProcessor, DetrForObjectDetection
|
7 |
+
|
8 |
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
|
9 |
from lib.IRRA.image import prepare_images
|
10 |
from lib.IRRA.model.build import build_model, IRRA
|
11 |
+
from PIL import Image
|
12 |
+
from pathlib import Path
|
13 |
|
14 |
from easydict import EasyDict
|
15 |
|
16 |
+
|
17 |
+
@st.cache_resource
|
18 |
def get_model():
|
19 |
args = yaml.load(open('model/configs.yaml'), Loader=yaml.FullLoader)
|
20 |
args = EasyDict(args)
|
|
|
22 |
|
23 |
model = build_model(args)
|
24 |
|
25 |
+
return model
|
26 |
+
|
27 |
+
|
28 |
+
@st.cache_resource
def get_detr():
    """Load the pretrained DETR object detector together with its image processor.

    Both objects come from the ``facebook/detr-resnet-50`` checkpoint at the
    ``no_timm`` revision. The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(model, processor)`` tuple -- note that the model comes first.
    """
    checkpoint = "facebook/detr-resnet-50"
    revision = "no_timm"

    # Same load order as before: processor first, then the detection model.
    image_processor = DetrImageProcessor.from_pretrained(
        checkpoint, revision=revision)
    detector = DetrForObjectDetection.from_pretrained(
        checkpoint, revision=revision)

    return detector, image_processor
|
37 |
+
|
38 |
+
|
39 |
+
def segment_images(model, processor, images: list[str], *,
                   threshold: float = 0.9, min_size: float = 70,
                   output_dir: str = 'segments') -> list[str]:
    """Crop every confidently detected person out of ``images`` and save the crops.

    Each input is opened with PIL, passed through ``processor`` and ``model``,
    and every detection labelled ``'person'`` whose box is larger than
    ``min_size`` pixels in both width and height is cropped and written to
    ``output_dir`` as ``img_<n>.jpg``.

    Args:
        model: Detection model; called as ``model(**inputs)`` and must expose
            ``config.id2label``.
        processor: Matching image processor; must be callable and provide
            ``post_process_object_detection``.
        images: Items accepted by ``PIL.Image.open`` (file paths or binary
            file-like objects).
        threshold: Minimum detection score forwarded to post-processing.
        min_size: A box is kept only if both its width and its height are
            strictly greater than this many pixels.
        output_dir: Directory the crops are written to; created if missing.

    Returns:
        POSIX-style paths of the saved crops, in the order they were detected.
        The list is empty when nothing qualifies.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): numbering restarts at img_0 on every call, so crops left in
    # ``output_dir`` by an earlier call are overwritten.
    segments: list[str] = []

    for source in images:
        # The crops are saved as .jpg, and JPEG cannot store alpha or palette
        # images, so normalise to RGB up front. ``convert`` returns a fully
        # loaded copy, which lets the source handle be released right away.
        with Image.open(source) as opened:
            image = opened.convert('RGB')

        inputs = processor(images=image, return_tensors="pt")
        # Pure inference: do not record an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL's ``size`` is (width, height); reversed here to (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=threshold)[0]

        for label, box in zip(results["labels"], results["boxes"]):
            if model.config.id2label[label.item()] != 'person':
                continue

            left, top, right, bottom = (round(v, 2) for v in box.tolist())
            if right - left <= min_size or bottom - top <= min_size:
                continue

            file = out_dir / f'img_{len(segments)}.jpg'
            image.crop((left, top, right, bottom)).save(file)
            segments.append(file.as_posix())

    return segments
|
69 |
+
|
70 |
|
71 |
def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
|
72 |
tokenizer = SimpleTokenizer()
|
|
|
79 |
|
80 |
image_feats = F.normalize(image_feats, p=2, dim=1)
|
81 |
text_feats = F.normalize(text_feats, p=2, dim=1)
|
82 |
+
|
83 |
return text_feats @ image_feats.t()
|
pages/losses.py
CHANGED
@@ -4,36 +4,45 @@ from st_pages import add_indentation
|
|
4 |
add_indentation()
|
5 |
|
6 |
st.title('Loss functions')
|
7 |
-
st.
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
''')
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
st.
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
add_indentation()

# Page explaining the loss functions used to align text and image features.
st.title('Loss functions')
# The IRR acronym is expanded to match the paper title quoted in the same paragraph.
st.markdown('In order to align textual and visual features, multiple loss functions are employed. '
            'The most notable loss function was proposed in [arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501) '
            'with the introduction of the SDM loss and the usage of the IRR (Implicit Relation Reasoning) loss.')

# SDM: prose interleaved with the formulas it introduces.
with st.expander('SDM Loss'):
    st.markdown('''
    The similarity distribution matching (SDM) loss, which is the KL divergence
    of the image to text and text to image to the label distribution.

    We define $f^v$ and $f^t$ to be the global representation of the visual and textual features respectively.
    The cosine similarity $sim(u, v) = \\frac{u \\cdot v}{|u||v|}$ will be used to compute the probability of the labels.

    We define $y_{i, j}=1$ if the visual feature $f^v_i$ matches the textual feature $f^t_j$, else $y_{i, j}=0$.
    The predicted label distribution can be formulated by''')
    st.latex(r'''
    p_{i} = \sigma(sim(f^v_i, f^t))
    ''')

    st.markdown('''
    We can define the image to text loss as
    ''')

    st.latex(r'''
    \mathcal{L}_{i2t} = KL(\mathbf{p_i} || \mathbf{q_i})
    ''')

    st.markdown('Where $\\mathbf{q_i}$, the true probability distribution, is defined as')

    st.latex(r'''
    q_{i, j} = \frac{y_{i, j}}{\sum_{k=1}^{N} y_{i, k}}
    ''')

    st.markdown('It should be noted that the reason this computation is needed is because there could be multiple correct labels.')

    st.markdown('The SDM loss can be formulated as')
    st.latex(r'''
    \mathcal{L}_{sdm} = \mathcal{L}_{i2t} + \mathcal{L}_{t2i}
    ''')

# Placeholders: these sections have no content yet.
with st.expander('IRR (MLM) Loss'):
    ...
with st.expander('ID Loss'):
    ...
|