Vivien committed on
Commit
74e4bcd
·
1 Parent(s): 383bcb1

Create app

Browse files
.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.npy filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .vscode/
README.md CHANGED
@@ -1,11 +1,12 @@
1
  ---
2
- title: Clip Slip
3
- emoji: 📚
4
- colorFrom: pink
5
- colorTo: gray
6
  sdk: streamlit
 
7
  app_file: app.py
8
- pinned: false
9
  ---
10
 
11
  # Configuration
 
1
  ---
2
+ title: Comparing CLIP and SLIP
3
+ emoji: 🖼️
4
+ colorFrom: indigo
5
+ colorTo: blue
6
  sdk: streamlit
7
+ sdk_version: 1.0.0
8
  app_file: app.py
9
+ pinned: true
10
  ---
11
 
12
  # Configuration
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import urllib.request
3
+ from collections import OrderedDict
4
+ from html import escape
5
+
6
+ import pandas as pd
7
+ import numpy as np
8
+
9
+ import torch
10
+ import torchvision.transforms as transforms
11
+
12
+ from transformers import CLIPProcessor, CLIPModel
13
+ import tokenizers
14
+ import regex
15
+
16
+ import streamlit as st
17
+
18
+ import models
19
+ from tokenizer import SimpleTokenizer
20
+
21
# True when a CUDA device is usable; read later to decide whether the SLIP
# model and its token tensors are moved to the GPU.
cuda_available = torch.cuda.is_available()

# Remote location of the SLIP "large, 100ep" checkpoint and the name of the
# local copy (the last path component of the URL).
model_url = "https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt"
model_filename = model_url.rsplit("/", 1)[-1]
25
+
26
+
27
def get_model(model):
    """Return the wrapped module when `model` is a parallel wrapper.

    Args:
        model: A module, possibly wrapped in ``torch.nn.DataParallel`` or
            ``torch.nn.parallel.DistributedDataParallel``.

    Returns:
        ``model.module`` for a wrapped model, otherwise ``model`` itself.
    """
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    if isinstance(model, parallel_wrappers):
        return model.module
    return model
34
+
35
+
36
@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load both models, the SLIP tokenizer and the precomputed image data.

    Downloads the SLIP checkpoint on first use, builds the SLIP ViT-L/16 model
    from it, loads the Hugging Face CLIP model and processor, and reads the
    image metadata and embeddings for both corpora (keys 0 and 1).

    Returns:
        Tuple ``(model, processor, slip_model, tokenizer, df, embeddings,
        slip_embeddings)`` where the last three are dicts keyed by corpus
        index. Only the CLIP ``embeddings`` are L2-normalised here.
    """
    # Load SLIP model from Facebook AI Research
    if not os.path.isfile(model_filename):
        # Download to a temporary name first so an interrupted transfer never
        # leaves a truncated file that a later run would mistake for a
        # complete checkpoint.
        partial_filename = model_filename + ".part"
        urllib.request.urlretrieve(model_url, partial_filename)
        os.replace(partial_filename, model_filename)
    # NOTE(review): torch.load unpickles the file; only safe because the
    # checkpoint comes from the fixed, trusted URL above.
    ckpt = torch.load(model_filename, map_location="cpu")
    # Drop the "module." prefix that (Distributed)DataParallel adds to keys.
    state_dict = OrderedDict(
        (k.replace("module.", ""), v) for k, v in ckpt["state_dict"].items()
    )
    old_args = ckpt["args"]
    slip_model = models.SLIP_VITL16(
        rand_embed=False,
        ssl_mlp_dim=old_args.ssl_mlp_dim,
        ssl_emb_dim=old_args.ssl_emb_dim,
    )
    if cuda_available:
        slip_model.cuda()
    slip_model.load_state_dict(state_dict, strict=True)
    slip_model = get_model(slip_model)
    # The model is only used for inference in this app.
    slip_model.eval()
    tokenizer = SimpleTokenizer()
    # Free the checkpoint before the next large allocations.
    del ckpt
    del state_dict
    # Load CLIP model from HuggingFace
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Load images' descriptions and embeddings
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    embeddings = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
    slip_embeddings = {
        0: np.load("embeddings_slip_large.npy"),
        1: np.load("embeddings2_slip_large.npy"),
    }
    # Scale every CLIP image embedding (row) to unit L2 norm.
    for k in embeddings:
        embeddings[k] = np.divide(
            embeddings[k], np.sqrt(np.sum(embeddings[k] ** 2, axis=1, keepdims=True))
        )
    return model, processor, slip_model, tokenizer, df, embeddings, slip_embeddings
80
+
81
+
82
# Everything the search needs, produced by the Streamlit-cached load() call.
(
    model,
    processor,
    slip_model,
    tokenizer,
    df,
    embeddings,
    slip_embeddings,
) = load()

# Attribution line appended to every image tooltip, keyed by corpus index.
source = {
    0: "\nSource: Unsplash",
    1: "\nSource: The Movie Database (TMDB)",
}
85
+
86
+
87
def get_html(url_list, url_list_slip, height=150):
    """Render the CLIP and SLIP results as two side-by-side HTML galleries.

    Args:
        url_list: Iterable of ``(url, title, link)`` tuples for the CLIP column.
        url_list_slip: Iterable of ``(url, title, link)`` tuples for the SLIP
            column.
        height: Thumbnail height in pixels.

    Returns:
        An HTML string. Every url, title and link is HTML-escaped; an image is
        wrapped in an anchor only when its link is a non-empty string.
    """

    def _thumbnails(items):
        # One <img> per result, optionally wrapped in a link to its source page.
        rendered = []
        for url, title, link in items:
            img = f"<img title='{escape(title)}' style='height: {height}px; margin: 5px' src='{escape(url)}'>"
            # pandas reads an empty CSV cell as NaN (a float), on which len()
            # would raise; treat anything but a non-empty string as "no link".
            if isinstance(link, str) and link:
                img = f"<a href='{escape(link)}' target='_blank'>{img}</a>"
            rendered.append(img)
        return "".join(rendered)

    column_style = (
        "margin-top: 20px; max-width: 1200px; display: flex; "
        "align-content: flex-start; flex-wrap: wrap; "
        "justify-content: space-evenly; width: 50%"
    )
    parts = [
        "<div style='display: flex; flex-wrap: wrap; justify-content: space-evenly;'>",
        f"<span style='{column_style}'>",
        "<div style='width: 100%; text-align: center;'><b>CLIP</b> (<a href='https://arxiv.org/abs/2103.00020'>Arxiv</a>, <a href='https://github.com/openai/CLIP'>GitHub</a>) from OpenAI</div>",
        _thumbnails(url_list),
        "</span>",
        f"<span style='{column_style}; border-left: solid; border-color: #ffc423; border-width: thin;'>",
        "<div style='width: 100%; text-align: center;'><b>SLIP</b> (<a href='https://arxiv.org/abs/2112.12750'>Arxiv</a>, <a href='https://github.com/facebookresearch/SLIP'>GitHub</a>) from Meta AI</div>",
        _thumbnails(url_list_slip),
        "</span></div>",
    ]
    return "".join(parts)
106
+
107
def compute_text_embeddings(list_of_strings):
    """Embed the given strings with the Hugging Face CLIP text encoder.

    Args:
        list_of_strings: List of query strings.

    Returns:
        The tensor returned by ``model.get_text_features`` for the padded,
        tokenised batch.
    """
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        return model.get_text_features(**inputs)
110
+
111
def compute_text_embeddings_slip(list_of_strings):
    """Embed the given strings with the SLIP text encoder.

    Args:
        list_of_strings: List of query strings.

    Returns:
        The tensor returned by ``slip_model.encode_text``; it lives on the GPU
        when ``cuda_available`` is true.
    """
    texts = tokenizer(list_of_strings)
    if cuda_available:
        texts = texts.cuda(non_blocking=True)
    # The tokenizer returns a 1-D tensor for a single string; restore the
    # batch dimension (77 is the tokenizer's default context length).
    texts = texts.view(-1, 77).contiguous()
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        return slip_model.encode_text(texts)
117
+
118
def image_search(query, corpus, n_results=24):
    """Return the best-matching images for `query` according to CLIP and SLIP.

    Args:
        query: Free-text search string.
        corpus: ``"Unsplash"`` selects corpus 0; any other value selects
            corpus 1.
        n_results: Maximum number of images returned per model.

    Returns:
        A pair ``(clip_results, slip_results)``; each is a list of
        ``(path, tooltip, link)`` tuples ordered from best to worst match.
    """
    k = 0 if corpus == "Unsplash" else 1
    # .cpu() is required before .numpy(): the SLIP model, and hence its output,
    # is on the GPU whenever CUDA is available.
    text_embeddings = compute_text_embeddings([query]).detach().cpu().numpy()
    text_embeddings_slip = (
        compute_text_embeddings_slip([query]).detach().cpu().numpy()
    )

    def _top_matches(image_embeddings, text_embedding):
        # Dot-product score of every image against the single query vector,
        # then the n_results highest scores in descending order.
        scores = (image_embeddings @ text_embedding.T)[:, 0]
        return np.argsort(scores)[-1 : -n_results - 1 : -1]

    def _describe(indices):
        # Look each row up once instead of once per column.
        rows = (df[k].iloc[i] for i in indices)
        return [
            (row["path"], row["tooltip"] + source[k], row["link"]) for row in rows
        ]

    results = _top_matches(embeddings[k], text_embeddings)
    results_slip = _top_matches(slip_embeddings[k], text_embeddings_slip)
    return _describe(results), _describe(results_slip)
146
+
147
+
148
# Markdown shown in the sidebar.
description = """
# Comparing CLIP and SLIP side by side

**Enter your query and hit enter**

CLIP and SLIP are ML models that encode images and texts as vectors so that the vectors of an image and its caption are similar. They can notably be used for zero-shot image classification, text-based image retrieval or image generation.

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, Meta AI's [SLIP](https://github.com/facebookresearch/SLIP) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
"""

# CSS overrides injected into the page: page width, horizontal radio buttons,
# paddings, sidebar width, and hiding Streamlit's menu and footer.
page_css = """
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>"""

st.markdown(page_css, unsafe_allow_html=True)
st.sidebar.markdown(description)

# The query box sits in the middle (widest) of three columns.
center_column = st.columns((1, 3, 1))[1]
query = center_column.text_input("", value="clouds at sunset")
corpus = st.radio("", ["Unsplash", "Movies"])

# Run a search only once the user has typed something.
if query:
    results, results_slip = image_search(query, corpus)
    gallery_html = get_html(results, results_slip)
    st.markdown(gallery_html, unsafe_allow_html=True)
bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
data.csv ADDED
The diff for this file is too large to render. See raw diff
 
data2.csv ADDED
The diff for this file is too large to render. See raw diff
 
embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f8c171e32276739be6b020592edc8a2c06e029ff6505a9d1d4efe3cafa073bd
3
+ size 51200128
embeddings2.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9664e980f31e81c4a34e07833539fea32795d83a4262c9828ceae445fa2e412a
3
+ size 16732288
embeddings2_slip_large.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5632813e4a27062f2a7bc3f2db23ac3f62d946b53d3b9144c1d5c7e8f9865f90
3
+ size 16732288
embeddings_slip_large.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98fd7411e6874bfd703c134470b9e5a82c0a7a403bb1cf1cac5851dc3871498f
3
+ size 51200128
losses.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ import utils
11
+
12
+
13
class CLIPLoss(nn.Module):
    """Symmetric image-to-text / text-to-image cross-entropy loss.

    ``forward`` expects a dict with ``'image_embed'``, ``'text_embed'`` and
    ``'logit_scale'`` and returns ``{'loss', 'clip_loss', 'clip_acc'}``.
    """

    def __init__(self):
        super().__init__()
        # Target indices, cached and rebuilt only when the batch size changes.
        self.labels = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        image_embed, text_embed, logit_scale = (
            outputs[key] for key in ('image_embed', 'text_embed', 'logit_scale')
        )
        local_batch_size = image_embed.size(0)

        if local_batch_size != self.last_local_batch_size:
            # Targets are offset by rank; presumably all_gather_batch stacks the
            # per-rank batches in rank order -- verify against utils.
            rank_offset = local_batch_size * utils.get_rank()
            self.labels = rank_offset + torch.arange(
                local_batch_size, device=image_embed.device
            )
            self.last_local_batch_size = local_batch_size

        # normalized features
        image_embed = F.normalize(image_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)

        # gather features from all GPUs
        gathered = utils.all_gather_batch([image_embed, text_embed])
        image_embed_all, text_embed_all = gathered

        # cosine similarity as logits
        logits_per_image = logit_scale * image_embed @ text_embed_all.t()
        logits_per_text = logit_scale * text_embed @ image_embed_all.t()

        image_loss = F.cross_entropy(logits_per_image, self.labels)
        text_loss = F.cross_entropy(logits_per_text, self.labels)
        loss = (image_loss + text_loss) / 2

        # compute accuracy (image -> text direction, as a percentage)
        with torch.no_grad():
            pred = torch.argmax(logits_per_image, dim=-1)
            correct = pred.eq(self.labels).sum()
            acc = 100 * correct / local_batch_size

        return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc}
53
+
54
+
55
class SIMCLRLoss(nn.Module):
    """
    SimCLR contrastive loss (https://arxiv.org/abs/2002.05709).

    ``forward`` reads the embeddings of two augmented views from
    ``outputs['aug1_embed']`` and ``outputs['aug2_embed']`` and returns
    ``{'loss', 'ssl_loss', 'ssl_acc'}``.

    Config params:
        temperature (float): the temperature to be applied on the logits
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.tau = temperature
        # Targets and self-similarity masks, cached per local batch size.
        self.labels = None
        self.masks = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        q_a = outputs['aug1_embed']
        q_b = outputs['aug2_embed']
        q_a = F.normalize(q_a, dim=-1, p=2)
        q_b = F.normalize(q_b, dim=-1, p=2)

        local_batch_size = q_a.size(0)

        k_a, k_b = utils.all_gather_batch_with_grad([q_a, q_b])

        if local_batch_size != self.last_local_batch_size:
            rank_offset = local_batch_size * utils.get_rank()
            self.labels = rank_offset + torch.arange(
                local_batch_size, device=q_a.device
            )
            total_batch_size = local_batch_size * utils.get_world_size()
            # Large value subtracted from same-view logits so a sample cannot
            # be matched with itself.
            self.masks = F.one_hot(self.labels, total_batch_size) * 1e9
            self.last_local_batch_size = local_batch_size

        def scaled_similarity(queries, keys):
            # Temperature-scaled dot products between queries and keys.
            return torch.matmul(queries, keys.transpose(0, 1)) / self.tau

        logits_aa = scaled_similarity(q_a, k_a) - self.masks
        logits_bb = scaled_similarity(q_b, k_b) - self.masks
        logits_ab = scaled_similarity(q_a, k_b)
        logits_ba = scaled_similarity(q_b, k_a)

        # Cross-view logits come first, so self.labels index the positives.
        logits_a = torch.cat([logits_ab, logits_aa], dim=1)
        logits_b = torch.cat([logits_ba, logits_bb], dim=1)
        loss_a = F.cross_entropy(logits_a, self.labels)
        loss_b = F.cross_entropy(logits_b, self.labels)
        loss = (loss_a + loss_b) / 2  # divide by 2 to average over all samples

        # compute accuracy (view a against view b, as a percentage)
        with torch.no_grad():
            pred = torch.argmax(logits_a, dim=-1)
            correct = pred.eq(self.labels).sum()
            acc = 100 * correct / local_batch_size

        return {'loss': loss, 'ssl_loss': loss, 'ssl_acc': acc}
110
+
111
+
112
class SLIPLoss(nn.Module):
    """CLIP loss plus a scaled self-supervised loss.

    Args:
        ssl_loss: Module returning a dict with ``'ssl_loss'`` and ``'ssl_acc'``.
        ssl_scale: Weight applied to the self-supervised loss in the total.
    """

    def __init__(self, ssl_loss, ssl_scale):
        super().__init__()
        self.clip_loss = CLIPLoss()
        self.ssl_loss = ssl_loss
        self.ssl_scale = ssl_scale

    def forward(self, outputs):
        clip_terms = self.clip_loss(outputs)
        ssl_terms = self.ssl_loss(outputs)

        clip_loss = clip_terms['clip_loss']
        ssl_loss = ssl_terms['ssl_loss']

        return {
            'loss': clip_loss + self.ssl_scale * ssl_loss,
            'clip_loss': clip_loss,
            'clip_acc': clip_terms['clip_acc'],
            'ssl_loss': ssl_loss,
            'ssl_acc': ssl_terms['ssl_acc'],
        }
models.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from github.com/openai/CLIP
8
+ from collections import OrderedDict
9
+
10
+ import numpy as np
11
+ import timm
12
+ import torch
13
+ from torch import nn
14
+
15
+ import losses
16
+
17
+
18
class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the input's dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
25
+
26
+
27
class QuickGELU(nn.Module):
    """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
30
+
31
+
32
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention and an MLP, each with a residual.

    Args:
        d_model: Feature width.
        n_head: Number of attention heads.
        attn_mask: Optional additive attention mask passed to every attention call.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        mlp_layers = [
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model)),
        ]
        self.mlp = nn.Sequential(OrderedDict(mlp_layers))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # The mask is a plain attribute, so it is moved to the input's dtype and
        # device here on every call.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attended = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)
        return attended[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
54
+
55
+
56
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks applied in sequence.

    Args:
        width: Feature width of every block.
        layers: Number of blocks.
        heads: Attention heads per block.
        attn_mask: Optional attention mask handed to every block.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
65
+
66
+
67
class CLIP(nn.Module):
    """Image encoder and text transformer projected into a shared embedding space.

    Args:
        embed_dim: Size of the shared embedding space.
        vision_width: Output width of ``vision_model``.
        vision_model: Module mapping images to ``vision_width`` features.
        context_length: Number of token positions of the text transformer.
        vocab_size: Size of the token embedding table.
        transformer_width: Feature width of the text transformer.
        transformer_heads: Attention heads of the text transformer.
        transformer_layers: Number of text transformer blocks.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 vision_width: int,
                 vision_model: nn.Module,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 **kwargs,
                 ):
        super().__init__()

        self.context_length = context_length
        self.vision_width = vision_width

        self.visual = vision_model

        causal_mask = self.build_attention_mask()
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=causal_mask,
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width)
        )
        self.ln_final = LayerNorm(transformer_width)

        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable log-temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Draw the text-tower and projection weights from scaled normal distributions."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        width = self.transformer.width
        depth = self.transformer.layers
        proj_std = (width ** -0.5) * ((2 * depth) ** -0.5)
        attn_std = width ** -0.5
        fc_std = (2 * width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
        nn.init.normal_(self.text_projection, std=width ** -0.5)

    def build_attention_mask(self):
        """Return the additive causal mask: -inf above the diagonal, 0 elsewhere."""
        size = self.context_length
        mask = torch.full((size, size), float("-inf"))
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask

    def encode_image(self, image):
        """Run the vision model and project its features into the shared space."""
        features = self.visual(image)
        return features @ self.image_projection

    def encode_text(self, text):
        """Embed token ids of shape [batch_size, n_ctx] into the shared space."""
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = self.transformer(x.permute(1, 0, 2))  # NLD -> LND for the transformer
        x = self.ln_final(x.permute(1, 0, 2))  # LND -> NLD

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_index = torch.arange(x.shape[0])
        eot_position = text.argmax(dim=-1)
        return x[batch_index, eot_position] @ self.text_projection

    def forward(self, image, text):
        image_embed = self.encode_image(image)
        text_embed = self.encode_text(text)

        return {
            'image_embed': image_embed,
            'text_embed': text_embed,
            'logit_scale': self.logit_scale.exp(),
        }
157
+
158
+
159
class SIMCLR(nn.Module):
    """Vision encoder followed by a three-layer projection MLP for SimCLR training.

    Args:
        vision_width: Output width of ``vision_model``.
        vision_model: Module mapping images to ``vision_width`` features.
        ssl_mlp_dim: Hidden width of the projection MLP.
        ssl_emb_dim: Output width of the projection MLP.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 # vision
                 vision_width: int,
                 vision_model: nn.Module,
                 # ssl
                 ssl_mlp_dim: int,
                 ssl_emb_dim: int,
                 **kwargs,
                 ):
        super().__init__()

        self.vision_width = vision_width
        self.visual = vision_model

        self.image_mlp = self._build_mlp(
            in_dim=vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim
        )

    def _build_mlp(self, in_dim, mlp_dim, out_dim):
        """Linear -> SyncBatchNorm -> ReLU, twice, then a final Linear."""
        layers = OrderedDict()
        layers["layer1"] = nn.Linear(in_dim, mlp_dim)
        layers["bn1"] = nn.SyncBatchNorm(mlp_dim)
        layers["relu1"] = nn.ReLU(inplace=True)
        layers["layer2"] = nn.Linear(mlp_dim, mlp_dim)
        layers["bn2"] = nn.SyncBatchNorm(mlp_dim)
        layers["relu2"] = nn.ReLU(inplace=True)
        layers["layer3"] = nn.Linear(mlp_dim, out_dim)
        return nn.Sequential(layers)

    def encode_image(self, image):
        """Return the raw vision features (the projection MLP is not applied)."""
        return self.visual(image)

    def forward(self, aug1, aug2):
        h1 = self.visual(aug1)
        h2 = self.visual(aug2)

        return {
            'aug1_embed': self.image_mlp(h1),
            'aug2_embed': self.image_mlp(h2),
        }
201
+
202
+
203
class SLIP(CLIP):
    """CLIP extended with a SimCLR-style projection MLP on the vision features.

    Args:
        ssl_mlp_dim: Hidden width of the projection MLP.
        ssl_emb_dim: Output width of the projection MLP.
        **kwargs: Forwarded to ``CLIP.__init__``.
    """

    def __init__(self,
                 ssl_mlp_dim: int,
                 ssl_emb_dim: int,
                 **kwargs,
                 ):
        super().__init__(**kwargs)

        self.image_mlp = self._build_mlp(
            in_dim=self.vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim
        )

    def _build_mlp(self, in_dim, mlp_dim, out_dim):
        """Linear -> SyncBatchNorm -> ReLU, twice, then a final Linear."""
        layers = OrderedDict()
        layers["layer1"] = nn.Linear(in_dim, mlp_dim)
        layers["bn1"] = nn.SyncBatchNorm(mlp_dim)
        layers["relu1"] = nn.ReLU(inplace=True)
        layers["layer2"] = nn.Linear(mlp_dim, mlp_dim)
        layers["bn2"] = nn.SyncBatchNorm(mlp_dim)
        layers["relu2"] = nn.ReLU(inplace=True)
        layers["layer3"] = nn.Linear(mlp_dim, out_dim)
        return nn.Sequential(layers)

    def forward(self, image, text, aug1, aug2):
        # Augmented views go through the vision model and the projection MLP;
        # the plain image and the text go through the CLIP encoders.
        aug1_embed = self.image_mlp(self.visual(aug1))
        aug2_embed = self.image_mlp(self.visual(aug2))

        image_embed = self.encode_image(image)
        text_embed = self.encode_text(text)

        return {
            'image_embed': image_embed,
            'text_embed': text_embed,
            'logit_scale': self.logit_scale.exp(),
            'aug1_embed': aug1_embed,
            'aug2_embed': aug2_embed,
        }
236
+
237
+
238
def get_loss(model, ssl_temp, ssl_scale):
    """Return the loss module matching a model name.

    Args:
        model: Model name; its prefix ('SLIP', 'CLIP' or 'SIMCLR') picks the loss.
        ssl_temp: Temperature for the SimCLR loss.
        ssl_scale: Weight of the SimCLR term inside the SLIP loss.

    Returns:
        The loss module, or ``None`` when the name matches no known prefix.
    """
    if model.startswith('SLIP'):
        return losses.SLIPLoss(losses.SIMCLRLoss(temperature=ssl_temp), ssl_scale)
    if model.startswith('CLIP'):
        return losses.CLIPLoss()
    if model.startswith('SIMCLR'):
        return losses.SIMCLRLoss(temperature=ssl_temp)
    return None
246
+
247
+
248
def get_metric_names(model):
    """Return the metric names reported by the loss of the given model name.

    Names starting with 'SLIP' report both CLIP and SSL metrics, names starting
    with 'CLIP' only the CLIP ones, and every other name only the SSL ones.
    """
    if model.startswith('SLIP'):
        return ['loss', 'clip_loss', 'ssl_loss', 'clip_acc', 'ssl_acc']
    if model.startswith('CLIP'):
        return ['loss', 'clip_loss', 'clip_acc']
    return ['loss', 'ssl_loss', 'ssl_acc']
255
+
256
+
257
@timm.models.registry.register_model
def vit_small_mocov3_patch16_224(**kwargs):
    """Create a ViT-S/16 with 12 attention heads and register it with timm.

    Extra keyword arguments are forwarded to timm's vision-transformer factory.
    """
    return timm.models.vision_transformer._create_vision_transformer(
        'vit_small_patch16_224',
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=12,
        **kwargs,
    )
263
+
264
+
265
def CLIP_VITS16(**kwargs):
    """Build CLIP with the 'vit_small_mocov3_patch16_224' image encoder (width 384)."""
    return CLIP(
        embed_dim=512,
        vision_width=384,
        vision_model=timm.create_model('vit_small_mocov3_patch16_224', num_classes=0),
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        **kwargs,
    )
271
+
272
+
273
def SIMCLR_VITS16(**kwargs):
    """Build SIMCLR with the 'vit_small_mocov3_patch16_224' image encoder (width 384)."""
    encoder = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0)
    return SIMCLR(vision_width=384, vision_model=encoder, **kwargs)
278
+
279
+
280
def SLIP_VITS16(**kwargs):
    """Build SLIP with the 'vit_small_mocov3_patch16_224' image encoder (width 384)."""
    return SLIP(
        embed_dim=512,
        vision_width=384,
        vision_model=timm.create_model('vit_small_mocov3_patch16_224', num_classes=0),
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        **kwargs,
    )
286
+
287
+
288
def CLIP_VITB16(**kwargs):
    """Build CLIP with timm's 'vit_base_patch16_224' image encoder (width 768)."""
    return CLIP(
        embed_dim=512,
        vision_width=768,
        vision_model=timm.create_model('vit_base_patch16_224', num_classes=0),
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        **kwargs,
    )
294
+
295
+
296
def SIMCLR_VITB16(**kwargs):
    """Build SIMCLR with timm's 'vit_base_patch16_224' image encoder (width 768)."""
    encoder = timm.create_model('vit_base_patch16_224', num_classes=0)
    return SIMCLR(vision_width=768, vision_model=encoder, **kwargs)
301
+
302
+
303
def SLIP_VITB16(**kwargs):
    """Build SLIP with timm's 'vit_base_patch16_224' image encoder (width 768)."""
    return SLIP(
        embed_dim=512,
        vision_width=768,
        vision_model=timm.create_model('vit_base_patch16_224', num_classes=0),
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        **kwargs,
    )
309
+
310
+
311
def CLIP_VITL16(**kwargs):
    """Build CLIP with timm's 'vit_large_patch16_224' image encoder (width 1024)."""
    return CLIP(
        embed_dim=512,
        vision_width=1024,
        vision_model=timm.create_model('vit_large_patch16_224', num_classes=0),
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        **kwargs,
    )
317
+
318
+
319
def SIMCLR_VITL16(**kwargs):
    """Build SIMCLR with timm's 'vit_large_patch16_224' image encoder (width 1024)."""
    encoder = timm.create_model('vit_large_patch16_224', num_classes=0)
    return SIMCLR(vision_width=1024, vision_model=encoder, **kwargs)
324
+
325
+
326
def SLIP_VITL16(**kwargs):
    """Build SLIP with timm's 'vit_large_patch16_224' image encoder (width 1024)."""
    return SLIP(
        embed_dim=512,
        vision_width=1024,
        vision_model=timm.create_model('vit_large_patch16_224', num_classes=0),
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        **kwargs,
    )
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torchvision
2
+ transformers
3
+ numpy
4
+ pandas
5
+ timm
6
+ ftfy
tokenizer.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from github.com/openai/CLIP
8
+ import gzip
9
+ import html
10
+ import os
11
+ from functools import lru_cache
12
+
13
+ import ftfy
14
+ import regex as re
15
+ import torch
16
+
17
+
18
@lru_cache()
def default_bpe():
    """Return the absolute path of the BPE vocabulary file located next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
21
+
22
+
23
@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping each of the 256 byte values to a distinct unicode character.

    The reversible BPE codes work on unicode strings, so every byte needs a
    character of its own. Bytes in the three ranges "!".."~", "¡".."¬" and
    "®".."ÿ" map to the character with the same code point. Every other byte
    (this includes whitespace and control characters, which the BPE code
    cannot handle) is mapped, in increasing order, to code points 256, 257, ...
    This avoids spending a significant share of the BPE vocabulary on raw
    unicode characters just to avoid UNKs.
    """
    same_code_point = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    already_mapped = set(same_code_point)
    shifted = [b for b in range(2 ** 8) if b not in already_mapped]

    # Insertion order matters to callers that list the values: the directly
    # mapped bytes come first, then the shifted ones.
    table = {b: chr(b) for b in same_code_point}
    for offset, b in enumerate(shifted):
        table[b] = chr(2 ** 8 + offset)
    return table
44
+
45
+
46
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: Sequence of symbols (symbols being variable-length strings),
            typically a tuple.

    Returns:
        A set of ``(left, right)`` tuples, one per pair of neighbouring
        symbols; empty when the word has fewer than two symbols.
    """
    # zip stops at the shorter input, so empty and one-symbol words yield no
    # pairs instead of raising IndexError on word[0].
    return set(zip(word, word[1:]))
56
+
57
+
58
def basic_clean(text):
    """Fix the text with ftfy, undo HTML escaping twice, and strip outer whitespace."""
    fixed = ftfy.fix_text(text)
    # Unescape twice so doubly escaped entities (e.g. "&amp;amp;") are resolved.
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
62
+
63
+
64
def whitespace_clean(text):
    """Collapse every run of whitespace into one space and strip both ends."""
    return re.sub(r'\s+', ' ', text).strip()
68
+
69
+
70
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer producing fixed-length token id tensors.

    Args:
        bpe_path: Path of the gzip-compressed merges file. The default is
            evaluated once, when the class is defined.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Context manager so the gzip handle is closed instead of leaked.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # Skip the first line and keep 49152-256-2 merges; with the 2*256 byte
        # symbols and the two special tokens the vocabulary has 49408 entries.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the learned merges to one token.

        Returns:
            The token's BPE symbols joined by single spaces; the result is
            memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        # Mark the end of the word on its last symbol.
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Merge the adjacent pair with the lowest (earliest learned) rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case `text` and return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map the token's UTF-8 bytes to their printable stand-ins first.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; each end-of-word marker becomes a space."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def __call__(self, texts, context_length=77):
        """Tokenize one string or a list of strings.

        Returns:
            A ``torch.long`` tensor of shape ``(len(texts), context_length)``,
            zero-padded on the right, or a 1-D tensor of length
            ``context_length`` when exactly one text is given.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<|startoftext|>"]
        eot_token = self.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            # NOTE(review): truncation can drop the end-of-text token of an
            # over-long input.
            tokens = tokens[:context_length]
            result[i, :len(tokens)] = torch.tensor(tokens)

        if len(result) == 1:
            return result[0]
        return result
utils.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import numpy as np
7
+ import os
8
+ import random
9
+ import shutil
10
+ import torch
11
+ import torch.distributed as dist
12
+ import torch.autograd as autograd
13
+
14
+ from PIL import ImageFilter
15
+
16
+
17
def get_model(model):
    """Return the underlying module of a (Distributed)DataParallel wrapper.

    Args:
        model: Any object; typically a ``torch.nn.Module``.

    Returns:
        ``model.module`` when ``model`` is a ``torch.nn.DataParallel`` or
        ``torch.nn.parallel.DistributedDataParallel`` instance, otherwise
        ``model`` itself.
    """
    wrapper_types = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    return model.module if isinstance(model, wrapper_types) else model
23
+
24
+
25
def setup_for_distributed(is_master):
    """Replace the builtin ``print`` so non-master processes stay silent.

    After this call, ``print(...)`` only produces output when ``is_master``
    is true or the caller passes ``force=True``. The ``force`` keyword is
    consumed by the wrapper and never forwarded to the real ``print``.

    Args:
        is_master: Whether the current process is allowed to print.
    """
    import builtins

    original_print = builtins.print

    def print(*args, force=False, **kwargs):
        if not (is_master or force):
            return
        original_print(*args, **kwargs)

    builtins.print = print
38
+
39
+
40
def is_dist_avail_and_initialized():
    """Report whether ``torch.distributed`` is available and initialized.

    Returns:
        ``True`` only when ``dist.is_available()`` and
        ``dist.is_initialized()`` are both true, ``False`` otherwise.
    """
    # Short-circuit: is_initialized() is only consulted when the package
    # is available at all.
    return dist.is_available() and dist.is_initialized()
46
+
47
+
48
def get_world_size():
    """Return the number of distributed processes.

    Returns:
        ``dist.get_world_size()`` when a process group is initialized,
        otherwise ``1``.
    """
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
52
+
53
+
54
def get_rank():
    """Return the rank of the current process.

    Returns:
        ``dist.get_rank()`` when a process group is initialized,
        otherwise ``0``.
    """
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
58
+
59
+
60
def is_main_process():
    """Return ``True`` when the current process has rank 0."""
    rank = get_rank()
    return rank == 0
62
+
63
+
64
def save_on_master(state, is_best, output_dir):
    """Write a checkpoint from the main process only.

    On the main process, ``state`` is saved to
    ``{output_dir}/checkpoint.pt``. When ``is_best`` is truthy that file is
    also copied to ``{output_dir}/checkpoint_best.pt``. Other processes do
    nothing.

    Args:
        state: Object handed to ``torch.save``.
        is_best: Whether to also store the checkpoint as the best one.
        output_dir: Directory prefix used to build both file paths.
    """
    if not is_main_process():
        return

    ckpt_path = f'{output_dir}/checkpoint.pt'
    best_path = f'{output_dir}/checkpoint_best.pt'
    torch.save(state, ckpt_path)
    if not is_best:
        return
    shutil.copyfile(ckpt_path, best_path)
71
+
72
+
73
def init_distributed_mode(args):
    """Configure ``args`` and initialize the process group from the environment.

    Detection order:
      * ``RANK`` and ``WORLD_SIZE`` both set: rank, world size and GPU index
        are read from ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``.
      * ``SLURM_PROCID`` set: rank is read from it and the GPU index is the
        rank modulo ``torch.cuda.device_count()``. ``args.world_size`` is
        not assigned in this branch and must already be present on ``args``.
      * otherwise: prints a notice, sets ``args.distributed = False`` and
        returns without touching CUDA or ``torch.distributed``.

    In the distributed case the function sets ``args.distributed`` and
    ``args.dist_backend``, selects the CUDA device, initializes the NCCL
    process group using ``args.dist_url``, waits on a barrier and silences
    ``print`` on every rank except 0.

    Args:
        args: Mutable namespace-like object; must provide ``dist_url`` when
            distributed mode is detected.
    """
    env = os.environ
    has_rank_vars = 'RANK' in env and 'WORLD_SIZE' in env
    has_slurm_var = 'SLURM_PROCID' in env

    if not (has_rank_vars or has_slurm_var):
        print('Not using distributed mode')
        args.distributed = False
        return

    if has_rank_vars:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    else:
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    banner = '| distributed init (rank {}): {}'.format(args.rank, args.dist_url)
    print(banner, flush=True)
    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    dist.barrier()
    setup_for_distributed(args.rank == 0)
96
+
97
+
98
def scaled_all_reduce(tensors, is_scale=True):
    """All-reduce every tensor in place and optionally average the result.

    All reductions are launched asynchronously and then awaited. With
    ``is_scale`` true, each reduced tensor is multiplied in place by
    ``1 / world_size``. ``dist.all_reduce`` is called without an explicit
    operator.

    Args:
        tensors: Collection of tensors; iterated more than once and
            modified in place.
        is_scale: Whether to scale the reduced values by the inverse world
            size.

    Returns:
        The same ``tensors`` object that was passed in.
    """
    world_size = get_world_size()
    if world_size == 1:
        # Single process: nothing to reduce.
        return tensors

    # Launch every reduction first so they can overlap, then wait on all.
    pending = [dist.all_reduce(item, async_op=True) for item in tensors]
    for handle in pending:
        handle.wait()

    if not is_scale:
        return tensors

    for item in tensors:
        item.mul_(1.0 / world_size)
    return tensors
121
+
122
+
123
def all_gather_batch(tensors):
    """Gather each tensor from every process and concatenate along dim 0.

    Args:
        tensors: Iterable of tensors to gather.

    Returns:
        ``tensors`` unchanged when the world size is 1. Otherwise a new list
        holding, for each input tensor, the concatenation along dimension 0
        of the copies gathered from all processes.
    """
    world_size = get_world_size()
    if world_size == 1:
        # Single process: nothing to gather.
        return tensors

    gathered = []
    for item in tensors:
        # Placeholder buffers that all_gather fills, one per process.
        buffers = [torch.ones_like(item) for _ in range(world_size)]
        dist.all_gather(buffers, item, async_op=False)
        gathered.append(torch.cat(buffers, dim=0))
    return gathered
147
+
148
+
149
class GatherLayer(autograd.Function):
    """
    Autograd-aware all_gather across workers.

    Unlike a plain ``torch.distributed.all_gather`` call, gradients flow back
    through this function to the local input.
    """

    @staticmethod
    def forward(ctx, x):
        """Gather ``x`` from every process and return the copies as a tuple."""
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        """All-reduce the stacked output gradients and return this rank's slice."""
        stacked = torch.stack(grads)
        dist.all_reduce(stacked)
        return stacked[dist.get_rank()]
166
+
167
+
168
def all_gather_batch_with_grad(tensors):
    """Gather tensors from all processes while keeping the autograd graph.

    Works like a gather-and-concatenate along dimension 0, but goes through
    ``GatherLayer`` so gradients can propagate back to the inputs.

    Args:
        tensors: Iterable of tensors to gather.

    Returns:
        ``tensors`` unchanged when the world size is 1. Otherwise a new list
        with one tensor per input: the gathered copies concatenated along
        dimension 0.
    """
    if get_world_size() == 1:
        # Single process: nothing to gather.
        return tensors

    gathered = []
    for item in tensors:
        per_process = GatherLayer.apply(item)
        gathered.append(torch.cat(per_process, dim=0))
    return gathered
188
+
189
+
190
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build a per-iteration schedule: optional linear warmup, then cosine decay.

    Args:
        base_value: Value at the end of warmup / start of the cosine phase.
        final_value: Value the cosine phase decays towards.
        epochs: Total number of epochs.
        niter_per_ep: Iterations per epoch.
        warmup_epochs: Number of epochs spent in linear warmup.
        start_warmup_value: Value at the first warmup iteration.

    Returns:
        A NumPy array of length ``epochs * niter_per_ep``.

    Raises:
        AssertionError: If the assembled schedule does not have exactly
            ``epochs * niter_per_ep`` entries.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_part = np.array([])

    steps = np.arange(total_iters - warmup_iters)
    # Half-cosine from base_value down towards final_value.
    cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))

    full_schedule = np.concatenate((warmup_part, cosine_part))
    assert len(full_schedule) == total_iters
    return full_schedule
202
+
203
+
204
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Each call blurs the input with a radius drawn uniformly from
    ``[sigma[0], sigma[1]]``.
    """

    def __init__(self, sigma=None):
        """Store the blur radius range.

        Args:
            sigma: Two-element sequence ``[low, high]``. Defaults to a fresh
                ``[.1, 2.]`` list for every instance.
        """
        # A list literal as the default would be one object shared by every
        # instance; build a new one per instance instead.
        self.sigma = [.1, 2.] if sigma is None else sigma

    def __call__(self, x):
        """Return ``x.filter(...)`` with a randomly sized Gaussian blur.

        Args:
            x: Object exposing ``filter``; presumably a PIL image (the
                filter comes from ``PIL.ImageFilter``).
        """
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))