grostaco committed on
Commit
9ff0cd2
·
1 Parent(s): 5aef4a3

feat: add segmentation

Browse files
Files changed (3) hide show
  1. app.py +25 -14
  2. lib/utils/model.py +53 -4
  3. pages/losses.py +42 -33
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
- from st_pages import Page, show_pages, add_page_title, Section
3
- from lib.utils.model import get_model, get_similarities
4
- from lib.utils.timer import timer
5
 
6
  add_page_title()
7
 
@@ -23,23 +23,31 @@ caption = st.text_input('Description Input')
23
 
24
  images = st.file_uploader('Upload images', accept_multiple_files=True)
25
  if images is not None:
26
-
27
- st.image(images) # type: ignore
28
 
29
  st.header('Options')
30
  st.subheader('Ranks', help='How many predictions the model is allowed to make')
31
 
32
- ranks = st.slider('slider_ranks', min_value=1, max_value=10, label_visibility='collapsed',value=5)
 
 
 
 
33
 
34
- button = st.button('Match most similar', disabled=len(images) == 0 or caption == '')
35
 
36
  if button:
 
 
 
 
37
  st.header('Results')
38
  with st.spinner('Loading model'):
39
  model = get_model()
40
 
41
- st.text(f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
42
-
 
43
  time = timer()
44
  with st.spinner('Computing and ranking similarities'):
45
  with timer() as t:
@@ -47,15 +55,16 @@ if button:
47
  elapsed = t()
48
 
49
  indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]
50
-
51
  c1, c2, c3 = st.columns(3)
52
  with c1:
53
  st.subheader('Rank')
54
  with c2:
55
  st.subheader('Image')
56
  with c3:
57
- st.subheader('Cosine Similarity', help='Due to the nature of the SDM loss, the higher the similarity, the more similar the match is')
58
-
 
59
  for i, idx in enumerate(indices):
60
  c1, c2, c3 = st.columns(3)
61
  with c1:
@@ -72,5 +81,7 @@ with st.sidebar:
72
 
73
  st.subheader('Useful Links')
74
  st.markdown('[arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)')
75
- st.markdown('[IRRA implementation (Pytorch Lightning + Transformers)](https://github.com/grostaco/modern-IRRA)')
76
- st.markdown('[IRRA implementation (PyTorch)](https://github.com/anosorae/IRRA/tree/main)')
 
 
 
1
  import streamlit as st
2
+ from st_pages import Page, show_pages, add_page_title, Section
3
+ from lib.utils.model import get_model, get_similarities, get_detr, segment_images
4
+ from lib.utils.timer import timer
5
 
6
  add_page_title()
7
 
 
23
 
24
  images = st.file_uploader('Upload images', accept_multiple_files=True)
25
  if images is not None:
26
+
27
+ st.image(images) # type: ignore
28
 
29
  st.header('Options')
30
  st.subheader('Ranks', help='How many predictions the model is allowed to make')
31
 
32
+ ranks = st.slider('slider_ranks', min_value=1, max_value=10,
33
+ label_visibility='collapsed', value=5)
34
+ do_segment = st.checkbox('Segment images with DETR', value=False)
35
+ button = st.button('Match most similar', disabled=len(
36
+ images) == 0 or caption == '')
37
 
 
38
 
39
  if button:
40
+ if do_segment:
41
+ detr, processor = get_detr()
42
+ images = segment_images(detr, processor, images)
43
+
44
  st.header('Results')
45
  with st.spinner('Loading model'):
46
  model = get_model()
47
 
48
+ st.text(
49
+ f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
50
+
51
  time = timer()
52
  with st.spinner('Computing and ranking similarities'):
53
  with timer() as t:
 
55
  elapsed = t()
56
 
57
  indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]
58
+
59
  c1, c2, c3 = st.columns(3)
60
  with c1:
61
  st.subheader('Rank')
62
  with c2:
63
  st.subheader('Image')
64
  with c3:
65
+ st.subheader('Cosine Similarity',
66
+ help='Due to the nature of the SDM loss, the higher the similarity, the more similar the match is')
67
+
68
  for i, idx in enumerate(indices):
69
  c1, c2, c3 = st.columns(3)
70
  with c1:
 
81
 
82
  st.subheader('Useful Links')
83
  st.markdown('[arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)')
84
+ st.markdown(
85
+ '[IRRA implementation (Pytorch Lightning + Transformers)](https://github.com/grostaco/modern-IRRA)')
86
+ st.markdown(
87
+ '[IRRA implementation (PyTorch)](https://github.com/anosorae/IRRA/tree/main)')
lib/utils/model.py CHANGED
@@ -1,15 +1,20 @@
1
- import streamlit as st
2
  import yaml
3
  import torch
4
  import torch.nn.functional as F
5
 
 
 
6
  from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
7
  from lib.IRRA.image import prepare_images
8
  from lib.IRRA.model.build import build_model, IRRA
 
 
9
 
10
  from easydict import EasyDict
11
 
12
- @st.cache_resource
 
13
  def get_model():
14
  args = yaml.load(open('model/configs.yaml'), Loader=yaml.FullLoader)
15
  args = EasyDict(args)
@@ -17,7 +22,51 @@ def get_model():
17
 
18
  model = build_model(args)
19
 
20
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
23
  tokenizer = SimpleTokenizer()
@@ -30,5 +79,5 @@ def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
30
 
31
  image_feats = F.normalize(image_feats, p=2, dim=1)
32
  text_feats = F.normalize(text_feats, p=2, dim=1)
33
-
34
  return text_feats @ image_feats.t()
 
1
+ import streamlit as st
2
  import yaml
3
  import torch
4
  import torch.nn.functional as F
5
 
6
+ from transformers import DetrImageProcessor, DetrForObjectDetection
7
+
8
  from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
9
  from lib.IRRA.image import prepare_images
10
  from lib.IRRA.model.build import build_model, IRRA
11
+ from PIL import Image
12
+ from pathlib import Path
13
 
14
  from easydict import EasyDict
15
 
16
+
17
+ @st.cache_resource
18
  def get_model():
19
  args = yaml.load(open('model/configs.yaml'), Loader=yaml.FullLoader)
20
  args = EasyDict(args)
 
22
 
23
  model = build_model(args)
24
 
25
+ return model
26
+
27
+
28
@st.cache_resource
def get_detr():
    """Load the DETR object detector together with its image processor.

    Both objects are created with ``from_pretrained`` from the
    ``facebook/detr-resnet-50`` checkpoint at revision ``no_timm``. The
    function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(model, processor)`` tuple -- note the model comes first.
    """
    checkpoint = "facebook/detr-resnet-50"
    pretrained_kwargs = {"revision": "no_timm"}

    image_processor = DetrImageProcessor.from_pretrained(
        checkpoint, **pretrained_kwargs)
    detector = DetrForObjectDetection.from_pretrained(
        checkpoint, **pretrained_kwargs)

    return detector, image_processor
37
+
38
+
39
def segment_images(model, processor, images: list[str], *,
                   threshold: float = 0.9, min_size: float = 70,
                   output_dir: str = 'segments') -> list[str]:
    """Crop every confidently detected person out of ``images`` and save the crops.

    Each input is opened with PIL, run through ``processor`` and ``model``,
    and post-processed with ``processor.post_process_object_detection``.
    Detections whose label (looked up in ``model.config.id2label``) is
    ``'person'`` and whose box is larger than ``min_size`` in both width and
    height are cropped and written to ``output_dir`` as ``img_<n>.jpg``.

    Args:
        model: Detection model; called as ``model(**inputs)`` and must expose
            ``config.id2label``.
        processor: Matching image processor providing ``__call__`` and
            ``post_process_object_detection``.
        images: Anything ``PIL.Image.open`` accepts (paths or file-like
            objects).
        threshold: Minimum detection score passed to the post-processor.
        min_size: Boxes must be strictly wider AND taller than this many
            pixels to be kept.
        output_dir: Directory the crops are written to; created if missing.

    Returns:
        POSIX-style paths of the saved crops, in detection order.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    segments: list[str] = []
    # Advances once per saved crop so no two crops of one call share a file name.
    crop_index = 0

    for source in images:
        # JPEG cannot store alpha or palette modes; normalising to RGB keeps
        # RGBA / P-mode uploads (typical for PNG) from failing on save().
        with Image.open(source) as opened:
            picture = opened.convert('RGB')

        inputs = processor(images=picture, return_tensors="pt")
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL reports (width, height); the post-processor expects (height, width).
        target_sizes = torch.tensor([picture.size[::-1]])
        detections = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=threshold)[0]

        for label, box in zip(detections["labels"], detections["boxes"]):
            if model.config.id2label[label.item()] != 'person':
                continue

            left, top, right, bottom = (round(v, 2) for v in box.tolist())
            if right - left <= min_size or bottom - top <= min_size:
                continue

            target = out_dir / f'img_{crop_index}.jpg'
            picture.crop((left, top, right, bottom)).save(target)
            segments.append(target.as_posix())
            crop_index += 1

    return segments
69
+
70
 
71
  def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
72
  tokenizer = SimpleTokenizer()
 
79
 
80
  image_feats = F.normalize(image_feats, p=2, dim=1)
81
  text_feats = F.normalize(text_feats, p=2, dim=1)
82
+
83
  return text_feats @ image_feats.t()
pages/losses.py CHANGED
@@ -4,36 +4,45 @@ from st_pages import add_indentation
4
  add_indentation()
5
 
6
  st.title('Loss functions')
7
- st.subheader('SDM Loss')
8
- st.markdown('''
9
- The similarity distribution matching (SDM) loss, which is the KL divergence
10
- of the image to text and text to image to the label distribution.
11
-
12
- We define $f^v$ and $f^t$ to be the global representation of the visual and textual features respectively.
13
- The cosine similarity $sim(u, v) = \\frac{u \\cdot v}{|u||v|}$ will be used to compute the probability of the labels.
14
-
15
- We define $y_{i, j}=1$ if the visual feature $f^v_i$ matches the textual feature $f^t_j$, else $y_{i, j}=0$.
16
- The predicted label distribution can be formulated by''')
17
- st.latex(r'''
18
- p_{i} = \sigma(sim(f^v_i, f^t))
19
- ''')
20
-
21
- st.markdown('''
22
- We can define the image to text loss as
23
- ''')
24
-
25
- st.latex(r'''
26
- \mathcal{L}_{i2t} = KL(\mathbf{p_i} || \mathbf{q_i})
27
- ''')
28
-
29
- st.markdown('Where $\\mathbf{q_i}$, the true probability distribution, is defined as')
30
-
31
- st.latex(r'''
32
- q_{i, j} = \frac{y_{i, j}}{\sum_{k=1}^{N} y_{i, k}}
33
- ''')
34
-
35
- st.markdown('It should be noted that the reason this computation is needed is because there could be multiple correct labels.')
36
-
37
-
38
- st.subheader('IRR (MLM) Loss')
39
- st.subheader('ID Loss')
 
 
 
 
 
 
 
 
 
 
4
  add_indentation()
5
 
6
  st.title('Loss functions')
7
# Introductory paragraph for the loss-function page. IRR is expanded as
# "Implicit Relation Reasoning", matching the paper title linked in the text.
st.markdown('In order to align textual and visual features, multiple loss functions are employed. '
            'The most notable loss function was proposed in [arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501) '
            'with the introduction of the SDM loss and the usage of the IRR (Implicit Relation Reasoning) loss.')
10
+ with st.expander('SDM Loss'):
11
+ st.markdown('''
12
+ The similarity distribution matching (SDM) loss, which is the KL divergence
13
+ of the image to text and text to image to the label distribution.
14
+
15
+ We define $f^v$ and $f^t$ to be the global representation of the visual and textual features respectively.
16
+ The cosine similarity $sim(u, v) = \\frac{u \\cdot v}{|u||v|}$ will be used to compute the probability of the labels.
17
+
18
+ We define $y_{i, j}=1$ if the visual feature $f^v_i$ matches the textual feature $f^t_j$, else $y_{i, j}=0$.
19
+ The predicted label distribution can be formulated by''')
20
+ st.latex(r'''
21
+ p_{i} = \sigma(sim(f^v_i, f^t))
22
+ ''')
23
+
24
+ st.markdown('''
25
+ We can define the image to text loss as
26
+ ''')
27
+
28
+ st.latex(r'''
29
+ \mathcal{L}_{i2t} = KL(\mathbf{p_i} || \mathbf{q_i})
30
+ ''')
31
+
32
+ st.markdown('Where $\\mathbf{q_i}$, the true probability distribution, is defined as')
33
+
34
+ st.latex(r'''
35
+ q_{i, j} = \frac{y_{i, j}}{\sum_{k=1}^{N} y_{i, k}}
36
+ ''')
37
+
38
+ st.markdown('It should be noted that the reason this computation is needed is because there could be multiple correct labels.')
39
+
40
+ st.markdown('The SDM loss can be formulated as')
41
+ st.latex(r'''
42
+ \mathcal{L}_{sdm} = \mathcal{L}_{i2t} + \mathcal{L}_{t2i}
43
+ ''')
44
+
45
+ with st.expander('IRR (MLM) Loss'):
46
+ ...
47
+ with st.expander('ID Loss'):
48
+ ...