Files changed (9)
  1. .gitignore +0 -70
  2. README.md +4 -21
  3. configuration_clip.py +6 -22
  4. eva_model.py +27 -30
  5. hf_model.py +102 -169
  6. modeling_clip.py +160 -241
  7. processing_clip.py +1 -0
  8. rope_embeddings.py +9 -4
  9. transform.py +179 -95
.gitignore DELETED
@@ -1,70 +0,0 @@
1
- # Project specific
2
- __init__.py
3
- pyproject.toml
4
-
5
- # Byte-compiled / optimized / DLL files
6
- __pycache__/
7
- *.py[cod]
8
- *$py.class
9
-
10
- # C extensions
11
- *.so
12
-
13
- # Distribution / packaging
14
- .Python
15
- build/
16
- develop-eggs/
17
- dist/
18
- downloads/
19
- eggs/
20
- .eggs/
21
- lib/
22
- lib64/
23
- parts/
24
- sdist/
25
- var/
26
- wheels/
27
- pip-wheel-metadata/
28
- share/python-wheels/
29
- *.egg-info/
30
- .installed.cfg
31
- *.egg
32
- MANIFEST
33
-
34
- # Unit test / coverage reports
35
- htmlcov/
36
- .tox/
37
- .nox/
38
- .coverage
39
- .coverage.*
40
- .cache
41
- nosetests.xml
42
- coverage.xml
43
- *.cover
44
- *.py,cover
45
- .hypothesis/
46
- .pytest_cache/
47
-
48
- # Jupyter Notebook
49
- .ipynb_checkpoints
50
-
51
- # IPython
52
- profile_default/
53
- ipython_config.py
54
-
55
- # Environments
56
- .env
57
- .venv
58
- env/
59
- venv/
60
- ENV/
61
- env.bak/
62
- venv.bak/
63
-
64
- # mypy
65
- .mypy_cache/
66
- .dmypy.json
67
- dmypy.json
68
-
69
- # PyCharm
70
- .idea/
README.md CHANGED
@@ -1,27 +1,10 @@
1
- ---
2
- tags:
3
- - transformers
4
- - xlm-roberta
5
- - eva02
6
- - clip
7
- library_name: transformers
8
- license: cc-by-nc-4.0
9
- ---
10
-
11
  # Jina CLIP
12
 
13
- Core implementation of Jina CLIP. The model uses:
14
- * the [EVA 02](https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip) architecture for the vision tower
15
- * the [Jina XLM RoBERTa with Flash Attention](https://huggingface.co/jinaai/xlm-roberta-flash-implementation) model as a text tower
16
-
17
- ## Models that use this implementation
18
-
19
- - [jinaai/jina-clip-v2](https://huggingface.co/jinaai/jina-clip-v2)
20
- - [jinaai/jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1)
21
-
22
- ## Requirements
23
 
24
- To use the Jina CLIP source code, the following packages are required:
25
  * `torch`
26
  * `timm`
27
  * `transformers`
1
  # Jina CLIP
2
 
3
+ The Jina CLIP implementation is hosted in this repository. The model uses:
4
+ * the EVA 02 architecture for the vision tower
5
+ * the Jina BERT with Flash Attention model as a text tower
6
 
7
+ To use the Jina CLIP model, the following packages are required:
8
  * `torch`
9
  * `timm`
10
  * `transformers`
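
For reference, a minimal usage sketch of the resulting model (the checkpoint id is an assumption taken from the Jina CLIP releases that build on this code; `trust_remote_code=True` is needed because the implementation lives in the model repository):

```python
from transformers import AutoModel

# Sketch only: checkpoint id assumed from the released Jina CLIP models.
model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)

# encode_text / encode_image are defined in modeling_clip.py below
text_emb = model.encode_text(['a photo of a cat'])
image_emb = model.encode_image(['cat.jpg'])
```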
configuration_clip.py CHANGED
@@ -8,7 +8,6 @@ import os
8
  from copy import deepcopy
9
  from typing import Any, Dict, List, Optional, Union
10
 
11
- import torch
12
  from transformers import PretrainedConfig, logging
13
 
14
  logger = logging.get_logger(__name__)
@@ -25,8 +24,6 @@ class JinaCLIPTextConfig(PretrainedConfig):
25
  embed_dim: int = 768,
26
  hf_model_name_or_path: str = 'jinaai/jina-bert-flash-implementation',
27
  hf_model_config_kwargs: Optional[Dict[str, Any]] = None,
28
- default_instruction_task: Optional[str] = None,
29
- default_lora_task: Optional[str] = None,
30
  pooler_type: Optional[str] = None,
31
  proj_type: Optional[str] = None,
32
  proj_bias: bool = False,
@@ -37,8 +34,6 @@ class JinaCLIPTextConfig(PretrainedConfig):
37
  self.embed_dim = embed_dim
38
  self.hf_model_name_or_path = hf_model_name_or_path
39
  self.hf_model_config_kwargs = hf_model_config_kwargs or {}
40
- self.default_instruction_task = default_instruction_task
41
- self.default_lora_task = default_lora_task
42
  self.pooler_type = pooler_type
43
  self.proj_type = proj_type
44
  self.proj_bias = proj_bias
@@ -52,9 +47,11 @@ class JinaCLIPTextConfig(PretrainedConfig):
52
  configdict, kwargs = cls.get_config_dict(
53
  pretrained_model_name_or_path, **kwargs
54
  )
 
55
  # get the text config dict if we are loading from JinaCLIPConfig
56
  if configdict.get('model_type') == 'jina_clip':
57
  configdict = configdict['text_config']
 
58
  if (
59
  'model_type' in configdict
60
  and hasattr(cls, 'model_type')
@@ -65,6 +62,7 @@ class JinaCLIPTextConfig(PretrainedConfig):
65
  f'instantiate a model of type {cls.model_type}. This is not supported '
66
  'for all configurations of models and can yield errors.'
67
  )
 
68
  return cls.from_dict(configdict, **kwargs)
69
 
70
 
@@ -127,9 +125,11 @@ class JinaCLIPVisionConfig(PretrainedConfig):
127
  configdict, kwargs = cls.get_config_dict(
128
  pretrained_model_name_or_path, **kwargs
129
  )
 
130
  # get the vision config dict if we are loading from JinaCLIPConfig
131
  if configdict.get('model_type') == 'jina_clip':
132
  configdict = configdict['vision_config']
 
133
  if (
134
  'model_type' in configdict
135
  and hasattr(cls, 'model_type')
@@ -140,6 +140,7 @@ class JinaCLIPVisionConfig(PretrainedConfig):
140
  f'instantiate a model of type {cls.model_type}. This is not supported '
141
  'for all configurations of models and can yield errors.'
142
  )
 
143
  return cls.from_dict(configdict, **kwargs)
144
 
145
 
@@ -158,7 +159,6 @@ class JinaCLIPConfig(PretrainedConfig):
158
  use_vision_xformers: Optional[bool] = None,
159
  matryoshka_dimensions: Optional[List[int]] = None,
160
  truncate_dim: Optional[int] = None,
161
- torch_dtype: Optional[Union[str, torch.dtype]] = None,
162
  **kwargs,
163
  ):
164
  # If `_config_dict` exist, we use them for the backward compatibility.
@@ -286,22 +286,6 @@ class JinaCLIPConfig(PretrainedConfig):
286
  'projections with `add_projections=True`.'
287
  )
288
 
289
- if (
290
- torch_dtype
291
- and hasattr(torch, torch_dtype)
292
- and type(getattr(torch, torch_dtype)) is torch.dtype
293
- ):
294
- self.torch_dtype = getattr(torch, torch_dtype)
295
- else:
296
- self.torch_dtype = torch_dtype
297
-
298
- use_text_flash_attn = (
299
- self.use_text_flash_attn if self.use_text_flash_attn is not None
300
- else self.text_config.hf_model_config_kwargs.get('use_flash_attn', False)
301
- )
302
- if not use_text_flash_attn or not torch.cuda.is_available():
303
- self.torch_dtype = torch.float32
304
-
305
  @classmethod
306
  def from_text_vision_configs(
307
  cls,
 
8
  from copy import deepcopy
9
  from typing import Any, Dict, List, Optional, Union
10
 
 
11
  from transformers import PretrainedConfig, logging
12
 
13
  logger = logging.get_logger(__name__)
 
24
  embed_dim: int = 768,
25
  hf_model_name_or_path: str = 'jinaai/jina-bert-flash-implementation',
26
  hf_model_config_kwargs: Optional[Dict[str, Any]] = None,
 
 
27
  pooler_type: Optional[str] = None,
28
  proj_type: Optional[str] = None,
29
  proj_bias: bool = False,
 
34
  self.embed_dim = embed_dim
35
  self.hf_model_name_or_path = hf_model_name_or_path
36
  self.hf_model_config_kwargs = hf_model_config_kwargs or {}
 
 
37
  self.pooler_type = pooler_type
38
  self.proj_type = proj_type
39
  self.proj_bias = proj_bias
 
47
  configdict, kwargs = cls.get_config_dict(
48
  pretrained_model_name_or_path, **kwargs
49
  )
50
+
51
  # get the text config dict if we are loading from JinaCLIPConfig
52
  if configdict.get('model_type') == 'jina_clip':
53
  configdict = configdict['text_config']
54
+
55
  if (
56
  'model_type' in configdict
57
  and hasattr(cls, 'model_type')
 
62
  f'instantiate a model of type {cls.model_type}. This is not supported '
63
  'for all configurations of models and can yield errors.'
64
  )
65
+
66
  return cls.from_dict(configdict, **kwargs)
67
 
68
 
 
125
  configdict, kwargs = cls.get_config_dict(
126
  pretrained_model_name_or_path, **kwargs
127
  )
128
+
129
  # get the vision config dict if we are loading from JinaCLIPConfig
130
  if configdict.get('model_type') == 'jina_clip':
131
  configdict = configdict['vision_config']
132
+
133
  if (
134
  'model_type' in configdict
135
  and hasattr(cls, 'model_type')
 
140
  f'instantiate a model of type {cls.model_type}. This is not supported '
141
  'for all configurations of models and can yield errors.'
142
  )
143
+
144
  return cls.from_dict(configdict, **kwargs)
145
 
146
 
 
159
  use_vision_xformers: Optional[bool] = None,
160
  matryoshka_dimensions: Optional[List[int]] = None,
161
  truncate_dim: Optional[int] = None,
 
162
  **kwargs,
163
  ):
164
  # If `_config_dict` exist, we use them for the backward compatibility.
 
286
  'projections with `add_projections=True`.'
287
  )
288
 
289
  @classmethod
290
  def from_text_vision_configs(
291
  cls,
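
To make the `from_pretrained` hunks above easier to follow: when the loaded config dict describes the composite `jina_clip` model, the relevant sub-config is extracted before `from_dict` is called. A minimal standalone illustration of that dispatch (field values are placeholders):

```python
# Placeholder dict mimicking a serialized composite JinaCLIPConfig.
configdict = {
    'model_type': 'jina_clip',
    'text_config': {'embed_dim': 768, 'pooler_type': 'mean_pooler'},
    'vision_config': {'embed_dim': 1024},
}

# Same dispatch as JinaCLIPTextConfig.from_pretrained: pull out the text
# sub-config when loading from the combined jina_clip config.
if configdict.get('model_type') == 'jina_clip':
    configdict = configdict['text_config']

print(configdict)  # {'embed_dim': 768, 'pooler_type': 'mean_pooler'}
```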
eva_model.py CHANGED
@@ -5,19 +5,16 @@
5
 
6
  import math
7
  import os
8
- import warnings
9
  from functools import partial
10
 
11
  import torch
12
  import torch.nn as nn
13
- import torch.nn.functional as f
14
 
15
  try:
16
- warnings.filterwarnings('ignore', category=FutureWarning, module='timm')
17
- from timm.models.layers import drop_path as timm_drop_path
18
- from timm.models.layers import to_2tuple, trunc_normal_
19
  except ImportError or ModuleNotFoundError:
20
- from timm.layers import drop_path as timm_drop_path, to_2tuple, trunc_normal_
21
 
22
  from .rope_embeddings import VisionRotaryEmbeddingFast
23
 
@@ -84,7 +81,7 @@ class DropPath(nn.Module):
84
  self.drop_prob = drop_prob
85
 
86
  def forward(self, x):
87
- return timm_drop_path(x, self.drop_prob, self.training)
88
 
89
  def extra_repr(self) -> str:
90
  return 'p={}'.format(self.drop_prob)
@@ -247,17 +244,17 @@ class Attention(nn.Module):
247
  self.rope = rope
248
 
249
  def forward(self, x, rel_pos_bias=None, attn_mask=None):
250
- b, n, _ = x.shape
251
  if self.subln:
252
- q = f.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
253
- k = f.linear(input=x, weight=self.k_proj.weight, bias=None)
254
- v = f.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
255
 
256
- q = q.reshape(b, n, self.num_heads, -1).permute(
257
  0, 2, 1, 3
258
  ) # B, num_heads, N, C
259
- k = k.reshape(b, n, self.num_heads, -1).permute(0, 2, 1, 3)
260
- v = v.reshape(b, n, self.num_heads, -1).permute(0, 2, 1, 3)
261
  else:
262
  qkv_bias = None
263
  if self.q_bias is not None:
@@ -269,8 +266,8 @@ class Attention(nn.Module):
269
  )
270
  )
271
 
272
- qkv = f.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
273
- qkv = qkv.reshape(b, n, 3, self.num_heads, -1).permute(
274
  2, 0, 3, 1, 4
275
  ) # 3, B, num_heads, N, C
276
  q, k, v = qkv[0], qkv[1], qkv[2]
@@ -301,7 +298,7 @@ class Attention(nn.Module):
301
  p=self.xattn_drop,
302
  scale=self.scale,
303
  )
304
- x = x.reshape(b, n, -1)
305
  x = self.inner_attn_ln(x)
306
  x = self.proj(x)
307
  x = self.proj_drop(x)
@@ -332,7 +329,7 @@ class Attention(nn.Module):
332
  attn = attn.softmax(dim=-1)
333
  attn = self.attn_drop(attn)
334
 
335
- x = (attn @ v).transpose(1, 2).reshape(b, n, -1)
336
  x = self.inner_attn_ln(x)
337
  x = self.proj(x)
338
  x = self.proj_drop(x)
@@ -464,12 +461,12 @@ class PatchEmbed(nn.Module):
464
  in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
465
  )
466
 
467
- def forward(self, x, **_):
468
  target_dtype = self.proj.weight.dtype
469
- _, __, h, w = x.shape
470
  # FIXME look at relaxing size constraints
471
- assert h == self.img_size[0] and w == self.img_size[1], (
472
- f"Input image size ({h}*{w}) doesn't match model "
473
  f'({self.img_size[0]}*{self.img_size[1]}).'
474
  )
475
  x = self.proj(x.to(dtype=target_dtype)).flatten(2).transpose(1, 2)
@@ -562,8 +559,9 @@ class EVAVisionTransformer(nn.Module):
562
  super().__init__()
563
  self.image_size = img_size
564
  self.num_classes = num_classes
565
- # num_features for consistency with other models
566
- self.num_features = self.embed_dim = embed_dim
 
567
 
568
  self.patch_embed = PatchEmbed(
569
  img_size=img_size,
@@ -668,8 +666,8 @@ class EVAVisionTransformer(nn.Module):
668
  self.grad_checkpointing = grad_checkpointing
669
 
670
  def fix_init_weight(self):
671
- def rescale(param, _layer_id):
672
- param.div_(math.sqrt(2.0 * _layer_id))
673
 
674
  for layer_id, layer in enumerate(self.blocks):
675
  rescale(layer.attn.proj.weight.data, layer_id + 1)
@@ -681,8 +679,7 @@ class EVAVisionTransformer(nn.Module):
681
  def get_cast_dtype(self) -> torch.dtype:
682
  return self.blocks[0].mlp.fc2.weight.dtype
683
 
684
- @staticmethod
685
- def _init_weights(m):
686
  if isinstance(m, nn.Linear):
687
  trunc_normal_(m.weight, std=0.02)
688
  if m.bias is not None:
@@ -694,7 +691,7 @@ class EVAVisionTransformer(nn.Module):
694
  def get_num_layers(self):
695
  return len(self.blocks)
696
 
697
- def lock(self, unlocked_groups=0, *_, **__):
698
  assert (
699
  unlocked_groups == 0
700
  ), 'partial locking not currently supported for this model'
@@ -712,7 +709,7 @@ class EVAVisionTransformer(nn.Module):
712
  def get_classifier(self):
713
  return self.head
714
 
715
- def reset_classifier(self, num_classes, *_, **__):
716
  self.num_classes = num_classes
717
  self.head = (
718
  nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
5
 
6
  import math
7
  import os
 
8
  from functools import partial
9
 
10
  import torch
11
  import torch.nn as nn
12
+ import torch.nn.functional as F
13
 
14
  try:
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
 
 
16
  except ImportError or ModuleNotFoundError:
17
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
18
 
19
  from .rope_embeddings import VisionRotaryEmbeddingFast
20
 
 
81
  self.drop_prob = drop_prob
82
 
83
  def forward(self, x):
84
+ return drop_path(x, self.drop_prob, self.training)
85
 
86
  def extra_repr(self) -> str:
87
  return 'p={}'.format(self.drop_prob)
 
244
  self.rope = rope
245
 
246
  def forward(self, x, rel_pos_bias=None, attn_mask=None):
247
+ B, N, C = x.shape
248
  if self.subln:
249
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
250
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
251
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
252
 
253
+ q = q.reshape(B, N, self.num_heads, -1).permute(
254
  0, 2, 1, 3
255
  ) # B, num_heads, N, C
256
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
257
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
258
  else:
259
  qkv_bias = None
260
  if self.q_bias is not None:
 
266
  )
267
  )
268
 
269
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
270
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
271
  2, 0, 3, 1, 4
272
  ) # 3, B, num_heads, N, C
273
  q, k, v = qkv[0], qkv[1], qkv[2]
 
298
  p=self.xattn_drop,
299
  scale=self.scale,
300
  )
301
+ x = x.reshape(B, N, -1)
302
  x = self.inner_attn_ln(x)
303
  x = self.proj(x)
304
  x = self.proj_drop(x)
 
329
  attn = attn.softmax(dim=-1)
330
  attn = self.attn_drop(attn)
331
 
332
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
333
  x = self.inner_attn_ln(x)
334
  x = self.proj(x)
335
  x = self.proj_drop(x)
 
461
  in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
462
  )
463
 
464
+ def forward(self, x, **kwargs):
465
  target_dtype = self.proj.weight.dtype
466
+ B, C, H, W = x.shape
467
  # FIXME look at relaxing size constraints
468
+ assert H == self.img_size[0] and W == self.img_size[1], (
469
+ f"Input image size ({H}*{W}) doesn't match model "
470
  f'({self.img_size[0]}*{self.img_size[1]}).'
471
  )
472
  x = self.proj(x.to(dtype=target_dtype)).flatten(2).transpose(1, 2)
 
559
  super().__init__()
560
  self.image_size = img_size
561
  self.num_classes = num_classes
562
+ self.num_features = (
563
+ self.embed_dim
564
+ ) = embed_dim # num_features for consistency with other models
565
 
566
  self.patch_embed = PatchEmbed(
567
  img_size=img_size,
 
666
  self.grad_checkpointing = grad_checkpointing
667
 
668
  def fix_init_weight(self):
669
+ def rescale(param, layer_id):
670
+ param.div_(math.sqrt(2.0 * layer_id))
671
 
672
  for layer_id, layer in enumerate(self.blocks):
673
  rescale(layer.attn.proj.weight.data, layer_id + 1)
 
679
  def get_cast_dtype(self) -> torch.dtype:
680
  return self.blocks[0].mlp.fc2.weight.dtype
681
 
682
+ def _init_weights(self, m):
 
683
  if isinstance(m, nn.Linear):
684
  trunc_normal_(m.weight, std=0.02)
685
  if m.bias is not None:
 
691
  def get_num_layers(self):
692
  return len(self.blocks)
693
 
694
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
695
  assert (
696
  unlocked_groups == 0
697
  ), 'partial locking not currently supported for this model'
 
709
  def get_classifier(self):
710
  return self.head
711
 
712
+ def reset_classifier(self, num_classes, global_pool=''):
713
  self.num_classes = num_classes
714
  self.head = (
715
  nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
hf_model.py CHANGED
@@ -1,6 +1,5 @@
1
  import re
2
- import warnings
3
- from typing import Dict, Optional, Union
4
 
5
  import torch
6
  import torch.nn as nn
@@ -11,6 +10,10 @@ from transformers.modeling_outputs import (
11
  BaseModelOutputWithPoolingAndCrossAttentions,
12
  )
13
 
 
 
 
 
14
  _HF_ARCH_DICT = {
15
  # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
16
  'roberta': {
@@ -38,6 +41,22 @@ _HF_ARCH_DICT = {
38
  },
39
  'pooler': 'mean_pooler',
40
  },
41
  # https://huggingface.co/docs/transformers/model_doc/bert
42
  'bert': {
43
  'config_names': {
@@ -49,8 +68,24 @@ _HF_ARCH_DICT = {
49
  },
50
  'pooler': 'cls_pooler',
51
  },
52
  }
53
 
 
 
 
 
 
54
  _POOLERS = {}
55
 
56
 
@@ -66,6 +101,8 @@ def register_pooler(cls):
66
 
67
  @register_pooler
68
  class MeanPooler(nn.Module):
 
 
69
  @staticmethod
70
  def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
71
  masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
@@ -74,6 +111,10 @@ class MeanPooler(nn.Module):
74
 
75
  @register_pooler
76
  class MaxPooler(nn.Module):
 
 
 
 
77
  @staticmethod
78
  def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
79
  masked_output = x.last_hidden_state.masked_fill(
@@ -84,7 +125,11 @@ class MaxPooler(nn.Module):
84
 
85
  @register_pooler
86
  class ClsPooler(nn.Module):
87
- def __init__(self, use_pooler_output: bool = True):
 
 
 
 
88
  super().__init__()
89
  self.cls_token_position = 0
90
  self.use_pooler_output = use_pooler_output
@@ -102,9 +147,15 @@ class ClsPooler(nn.Module):
102
  and (x.pooler_output is not None)
103
  ):
104
  return x.pooler_output
 
105
  return x.last_hidden_state[:, self.cls_token_position, :]
106
 
107
 
 
 
 
 
 
108
  class HFTextEncoder(nn.Module):
109
  output_tokens: torch.jit.Final[bool]
110
 
@@ -120,60 +171,56 @@ class HFTextEncoder(nn.Module):
120
  output_tokens: bool = False,
121
  trust_remote_code: bool = False,
122
  revision: Optional[str] = None,
123
- code_revision: Optional[str] = None,
124
- default_instruction_task: Optional[str] = None,
125
- default_lora_task: Optional[str] = None,
126
  model_config_kwargs: Optional[Dict] = None,
127
  ):
128
  super().__init__()
129
  self.output_tokens = output_tokens
130
  self.output_dim = output_dim
131
 
 
 
132
  model_config_kwargs = model_config_kwargs or {}
133
 
134
  if config is None:
135
- if pretrained:
136
- self.transformer = AutoModel.from_pretrained(
137
- model_name_or_path,
138
- trust_remote_code=trust_remote_code,
139
- revision=revision,
140
- add_pooling_layer=False,
141
- code_revision=code_revision,
142
- **model_config_kwargs,
143
- )
144
- self.config = self.transformer.config
145
- else:
146
- self.config = AutoConfig.from_pretrained(
147
- model_name_or_path,
148
- trust_remote_code=trust_remote_code,
149
- code_revision=code_revision,
150
- )
151
- self.config.update(model_config_kwargs)
152
- self.transformer = AutoModel.from_config(
153
- self.config,
154
- trust_remote_code=trust_remote_code,
155
- add_pooling_layer=False,
156
- code_revision=code_revision,
157
- )
158
  if (
159
  hasattr(self.config, 'is_encoder_decoder')
160
  and self.config.is_encoder_decoder
161
  ):
 
162
  self.transformer = self.transformer.encoder
163
-
 
 
 
 
 
 
164
  else:
165
  self.config = config
166
  self.config.update(model_config_kwargs)
167
- self.transformer = AutoModel.from_config(
168
- self.config,
169
- trust_remote_code=trust_remote_code,
170
- revision=revision,
171
- code_revision=code_revision,
172
- )
 
173
  self.vocab_size = getattr(self.config, 'vocab_size', 0)
174
  self.context_length = getattr(self.config, 'max_position_embeddings', 0)
175
 
176
- pooler_type = pooler_type or _HF_ARCH_DICT[self.config.model_type]['pooler']
177
  self.pooler = _POOLERS[pooler_type]()
178
 
179
  d_model = getattr(
@@ -181,7 +228,7 @@ class HFTextEncoder(nn.Module):
181
  )
182
  if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
183
  self.proj = nn.Identity()
184
- elif (d_model != output_dim) or proj_type == 'linear':
185
  self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
186
  elif proj_type == 'mlp':
187
  hidden_size = (d_model + output_dim) // 2
@@ -191,149 +238,27 @@ class HFTextEncoder(nn.Module):
191
  nn.Linear(hidden_size, output_dim, bias=proj_bias),
192
  )
193
 
194
- self._task_instructions = {}
195
- self._lora_adaptation_map = {}
196
- self._supports_task_instructions = False
197
- self._supports_lora = False
198
- if (
199
- hasattr(self.transformer, '_adaptation_map')
200
- and len(self.transformer._adaptation_map) > 0
201
- ):
202
- self._lora_adaptation_map = self.transformer._adaptation_map
203
- self._supports_lora = True
204
- if (
205
- hasattr(self.transformer, '_task_instructions')
206
- and len(self.transformer._task_instructions) > 0
207
- ):
208
- self._task_instructions = self.transformer._task_instructions
209
- self._supports_task_instructions = True
210
-
211
- self._default_instruction_task = None
212
- self._default_lora_task = None
213
- self._default_instruction = None
214
- self._default_loraid = None
215
-
216
- if default_instruction_task is not None:
217
- self._default_instruction_task = default_instruction_task
218
- self._default_instruction = self.get_instruction_from_task(
219
- default_instruction_task
220
- )
221
- if default_lora_task is not None:
222
- self._default_lora_task = default_lora_task
223
- self._default_loraid = self.get_loraid_from_task(default_lora_task)
224
-
225
- @property
226
- def supports_task_instructions(self) -> bool:
227
- return self._supports_task_instructions
228
-
229
- @property
230
- def supports_lora(self) -> bool:
231
- return self._supports_lora
232
-
233
- @property
234
- def task_instructions(self) -> Dict[str, str]:
235
- return self._task_instructions
236
-
237
- @property
238
- def lora_adaptation_map(self) -> Dict[str, int]:
239
- return self._lora_adaptation_map
240
-
241
- @property
242
- def default_instruction(self) -> Optional[str]:
243
- return self._default_instruction
244
-
245
- @property
246
- def default_loraid(self) -> Optional[int]:
247
- return self._default_loraid
248
-
249
- def get_instruction_from_task(self, task: Optional[str]) -> Optional[str]:
250
- if self._supports_task_instructions:
251
- if task is None:
252
- return self._default_instruction
253
- if task not in self._task_instructions:
254
- raise ValueError(
255
- f'Unsupported task \'{task}\'. Choose one of the following: '
256
- f'{", ".join(self._task_instructions)} or set to None to disable '
257
- f'task instructions completely'
258
- )
259
- return self._task_instructions[task]
260
- else:
261
- if task is not None:
262
- warnings.warn(
263
- 'Model does not support task instructions, ignoring instruction '
264
- f"task '{task}'"
265
- )
266
- return None
267
-
268
- def get_loraid_from_task(self, task: Optional[str]) -> Optional[int]:
269
- if self._supports_lora:
270
- if task is None:
271
- return self._default_loraid
272
- if task not in self._lora_adaptation_map:
273
- raise ValueError(
274
- f'Unsupported task \'{task}\'. Choose one of the following: '
275
- f'{", ".join(self._task_instructions)} or set to None to disable '
276
- f'the LoRA adapters completely'
277
- )
278
- return self._lora_adaptation_map[task]
279
- else:
280
- if task is not None:
281
- warnings.warn(
282
- f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
283
- )
284
- return None
285
-
286
- @staticmethod
287
- def get_adapter_mask_from_loraid(
288
- batch_size: int, loraid: int, device: Union[str, torch.device]
289
- ):
290
- return torch.full((batch_size,), loraid, dtype=torch.int32, device=device)
291
-
292
- @torch.jit.ignore
293
- def set_grad_checkpointing(self, _=True):
294
- self.transformer.gradient_checkpointing_enable()
295
-
296
- def init_parameters(self):
297
- pass
298
-
299
- def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
300
- if adapter_mask is None:
301
- default_loraid = self.default_loraid
302
- if default_loraid is not None:
303
- adapter_mask = self.get_adapter_mask_from_loraid(
304
- x.shape[0], default_loraid, x.device
305
- )
306
- else:
307
- if not self.supports_lora:
308
- warnings.warn(
309
- 'Model does not support LoRA adapters, setting adapter_mask to None'
310
- )
311
- adapter_mask = None
312
-
313
- attention_mask = (x != self.config.pad_token_id).long()
314
- lora_kwargs = {}
315
- if adapter_mask is not None:
316
- lora_kwargs['adapter_mask'] = adapter_mask
317
-
318
- out = self.transformer(
319
- input_ids=x, attention_mask=attention_mask, **lora_kwargs
320
- )
321
- pooled_out = self.pooler(out, attention_mask)
322
  projected = self.proj(pooled_out)
323
- seqlen = out.last_hidden_state.shape[1]
 
324
  tokens = (
325
  out.last_hidden_state[
326
- :, torch.arange(seqlen) != self.pooler.cls_token_position, :
327
  ]
328
  if isinstance(self.pooler, ClsPooler)
329
  else out.last_hidden_state
330
  )
 
331
  if self.output_tokens:
332
  return projected, tokens
333
  return projected
334
 
335
  def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
336
- if not unlocked_layers:
337
  for n, p in self.transformer.named_parameters():
338
  p.requires_grad = (
339
  (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
@@ -362,3 +287,11 @@ class HFTextEncoder(nn.Module):
362
  p.requires_grad = (
363
  (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
364
  )
1
  import re
2
+ from typing import Dict, Optional, Tuple
 
3
 
4
  import torch
5
  import torch.nn as nn
 
10
  BaseModelOutputWithPoolingAndCrossAttentions,
11
  )
12
 
13
+ """
14
+ HF architecture mapping
15
+ """
16
+
17
  _HF_ARCH_DICT = {
18
  # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
19
  'roberta': {
 
41
  },
42
  'pooler': 'mean_pooler',
43
  },
44
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
45
+ 'mt5': {
46
+ 'config_names': {
47
+ # unlimited seqlen
48
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
49
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
50
+ 'context_length': '',
51
+ 'vocab_size': 'vocab_size',
52
+ 'width': 'd_model',
53
+ 'heads': 'num_heads',
54
+ 'layers': 'num_layers',
55
+ 'layer_attr': 'block',
56
+ 'token_embeddings_attr': 'embed_tokens',
57
+ },
58
+ 'pooler': 'mean_pooler',
59
+ },
60
  # https://huggingface.co/docs/transformers/model_doc/bert
61
  'bert': {
62
  'config_names': {
 
68
  },
69
  'pooler': 'cls_pooler',
70
  },
71
+ # https://huggingface.co/docs/transformers/model_doc/m2m_100
72
+ 'm2m_100': {
73
+ 'config_names': {
74
+ 'context_length': 'max_position_embeddings',
75
+ 'vocab_size': 'vocab_size',
76
+ 'width': 'd_model',
77
+ 'heads': 'encoder_attention_heads',
78
+ 'layers': 'encoder_layers',
79
+ },
80
+ 'pooler': 'cls_pooler',
81
+ },
82
  }
83
 
84
+
85
+ """
86
+ Pooling functions
87
+ """
88
+
89
  _POOLERS = {}
90
 
91
 
 
101
 
102
  @register_pooler
103
  class MeanPooler(nn.Module):
104
+ """Mean pooling"""
105
+
106
  @staticmethod
107
  def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
108
  masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
 
111
 
112
  @register_pooler
113
  class MaxPooler(nn.Module):
114
+ """
115
+ Max pooling
116
+ """
117
+
118
  @staticmethod
119
  def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
120
  masked_output = x.last_hidden_state.masked_fill(
 
125
 
126
  @register_pooler
127
  class ClsPooler(nn.Module):
128
+ """
129
+ CLS token pooling
130
+ """
131
+
132
+ def __init__(self, use_pooler_output=True):
133
  super().__init__()
134
  self.cls_token_position = 0
135
  self.use_pooler_output = use_pooler_output
 
147
  and (x.pooler_output is not None)
148
  ):
149
  return x.pooler_output
150
+
151
  return x.last_hidden_state[:, self.cls_token_position, :]
152
 
153
 
154
+ """
155
+ HF text model
156
+ """
157
+
158
+
159
  class HFTextEncoder(nn.Module):
160
  output_tokens: torch.jit.Final[bool]
161
 
 
171
  output_tokens: bool = False,
172
  trust_remote_code: bool = False,
173
  revision: Optional[str] = None,
 
 
 
174
  model_config_kwargs: Optional[Dict] = None,
175
  ):
176
  super().__init__()
177
  self.output_tokens = output_tokens
178
  self.output_dim = output_dim
179
 
180
+ # TODO: find better way to get this information
181
+ uses_transformer_pooler = pooler_type == 'cls_pooler'
182
  model_config_kwargs = model_config_kwargs or {}
183
 
184
  if config is None:
185
+ self.config = AutoConfig.from_pretrained(
186
+ model_name_or_path,
187
+ trust_remote_code=trust_remote_code,
188
+ code_revision=revision,
189
+ )
190
+ self.config.update(model_config_kwargs)
191
+ create_func, model_args = (
192
+ (AutoModel.from_pretrained, model_name_or_path)
193
+ if pretrained
194
+ else (AutoModel.from_config, self.config)
195
+ )
196
+ # TODO: do all model configs have this attribute?
197
+ # PretrainedConfig does so yes??
198
  if (
199
  hasattr(self.config, 'is_encoder_decoder')
200
  and self.config.is_encoder_decoder
201
  ):
202
+ self.transformer = create_func(model_args)
203
  self.transformer = self.transformer.encoder
204
+ else:
205
+ self.transformer = create_func(
206
+ model_args,
207
+ trust_remote_code=trust_remote_code,
208
+ add_pooling_layer=uses_transformer_pooler,
209
+ code_revision=revision,
210
+ )
211
  else:
212
  self.config = config
213
  self.config.update(model_config_kwargs)
214
+ self.transformer = AutoModel.from_config(self.config)
215
+
216
+ if pooler_type is None: # get default arch pooler
217
+ pooler_type = _HF_ARCH_DICT[self.config.model_type]['pooler']
218
+
219
+ # FIXME downstream users of OpenCLIP models use these attr,
220
+ # need to verify valid across all models
221
  self.vocab_size = getattr(self.config, 'vocab_size', 0)
222
  self.context_length = getattr(self.config, 'max_position_embeddings', 0)
223
 
 
224
  self.pooler = _POOLERS[pooler_type]()
225
 
226
  d_model = getattr(
 
228
  )
229
  if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
230
  self.proj = nn.Identity()
231
+ elif proj_type == 'linear':
232
  self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
233
  elif proj_type == 'mlp':
234
  hidden_size = (d_model + output_dim) // 2
 
238
  nn.Linear(hidden_size, output_dim, bias=proj_bias),
239
  )
240
 
241
+ def forward(self, x: torch.Tensor):
242
+ attn_mask = (x != self.config.pad_token_id).long()
243
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
244
+ pooled_out = self.pooler(out, attn_mask)
245
  projected = self.proj(pooled_out)
246
+
247
+ seq_len = out.last_hidden_state.shape[1]
248
  tokens = (
249
  out.last_hidden_state[
250
+ :, torch.arange(seq_len) != self.pooler.cls_token_position, :
251
  ]
252
  if isinstance(self.pooler, ClsPooler)
253
  else out.last_hidden_state
254
  )
255
+
256
  if self.output_tokens:
257
  return projected, tokens
258
  return projected
259
 
260
  def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
261
+ if not unlocked_layers: # full freezing
262
  for n, p in self.transformer.named_parameters():
263
  p.requires_grad = (
264
  (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
 
287
  p.requires_grad = (
288
  (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
289
  )
290
+
291
+ @torch.jit.ignore
292
+ def set_grad_checkpointing(self, _=True):
293
+ self.transformer.gradient_checkpointing_enable()
294
+
295
+ def init_parameters(self):
296
+ pass
297
+
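
The pooler docstrings added above are terse, so here is a standalone sketch of what MeanPooler-style pooling computes (not the class itself): a per-sequence average of `last_hidden_state` over non-padding positions.

```python
import torch

last_hidden_state = torch.randn(2, 5, 8)          # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])  # 1 = token, 0 = padding

# Mask out padding, sum over the sequence axis, divide by the token count.
masked = last_hidden_state * attention_mask.unsqueeze(-1)
pooled = masked.sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
print(pooled.shape)  # torch.Size([2, 8])
```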
modeling_clip.py CHANGED
@@ -5,8 +5,6 @@
5
  # and adjusted for Jina CLIP
6
 
7
  import base64
8
- import importlib.util
9
- import warnings
10
  from functools import partial
11
  from io import BytesIO
12
  from typing import List, Optional, Tuple, Union
@@ -16,7 +14,6 @@ import requests
16
  import torch
17
  import torch.nn.functional as f
18
  import torch.utils.checkpoint
19
- from PIL import Image
20
  from torch import nn
21
  from transformers import (
22
  AutoImageProcessor,
@@ -38,12 +35,13 @@ try:
38
 
39
  has_tqdm = True
40
  except ImportError:
41
- trange = None
42
  has_tqdm = False
43
 
44
  from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
45
  from .eva_model import EVAVisionTransformer
46
  from .hf_model import HFTextEncoder
 
 
47
  from .rope_embeddings import VisionRotaryEmbeddingFast # noqa: F401
48
  from .transform import ( # noqa: F401
49
  OPENAI_DATASET_MEAN,
@@ -70,8 +68,6 @@ def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
70
  return HFTextEncoder(
71
  model_name_or_path=config.hf_model_name_or_path,
72
  output_dim=config.embed_dim,
73
- default_instruction_task=config.default_instruction_task,
74
- default_lora_task=config.default_lora_task,
75
  pooler_type=config.pooler_type,
76
  proj_type=config.proj_type,
77
  proj_bias=config.proj_bias,
@@ -119,80 +115,6 @@ def _build_vision_tower(config: JinaCLIPVisionConfig) -> EVAVisionTransformer:
119
  )
120
 
121
 
122
- def _resolve_attention_libs(config: JinaCLIPConfig):
123
- use_text_flash_attn = (
124
- config.use_text_flash_attn
125
- if config.use_text_flash_attn is not None
126
- else config.text_config.hf_model_config_kwargs.get('use_flash_attn', True)
127
- )
128
- use_vision_xformers = (
129
- config.use_vision_xformers
130
- if config.use_vision_xformers is not None
131
- else config.vision_config.x_attention
132
- )
133
-
134
- def _resolve_use_text_flash_attn() -> bool:
135
- if use_text_flash_attn:
136
- if not torch.cuda.is_available():
137
- warnings.warn('Flash attention requires CUDA, disabling')
138
- return False
139
- if importlib.util.find_spec('flash_attn') is None:
140
- warnings.warn(
141
- 'Flash attention is not installed. Check '
142
- 'https://github.com/Dao-AILab/flash-attention?'
143
- 'tab=readme-ov-file#installation-and-features '
144
- 'for installation instructions, disabling'
145
- )
146
- return False
147
- major, minor, *_ = torch.version.cuda.split('.')
148
- major, minor = int(major), int(minor)
149
- if major < 11 or (major == 11 and minor < 7):
150
- warnings.warn(
151
- 'Flash attention requires CUDA>=11.7. Found version '
152
- f'{major}.{minor}, disabling'
153
- )
154
- return False
155
- capability = torch.cuda.get_device_capability()
156
- major, *_ = capability
157
- major = int(major)
158
- if major < 8:
159
- device_name = torch.cuda.get_device_properties(0).name
160
- warnings.warn(
161
- 'Flash attention requires device capability>=8.0 (NVIDIA Ampere, '
162
- f'Hopper or ADA). Found device {device_name} with capability '
163
- f'{capability}, disabling'
164
- )
165
- return False
166
- return True
167
- return False
168
-
169
- def _resolve_use_vision_xformers() -> bool:
170
- if use_vision_xformers:
171
- if not torch.cuda.is_available():
172
- warnings.warn('xFormers requires CUDA, disabling')
173
- return False
174
- if importlib.util.find_spec('xformers') is None:
175
- warnings.warn(
176
- 'xFormers is not installed. Check '
177
- 'https://github.com/facebookresearch/xformers?'
178
- 'tab=readme-ov-file#installing-xformers for installation '
179
- 'instructions, disabling'
180
- )
181
- return False
182
- return True
183
- return False
184
-
185
- _use_text_flash_attn = _resolve_use_text_flash_attn()
186
- _use_vision_xformers = _resolve_use_vision_xformers()
187
-
188
- config.use_text_flash_attn = _use_text_flash_attn
189
- config.use_vision_xformers = _use_vision_xformers
190
- config.text_config.hf_model_config_kwargs['use_flash_attn'] = _use_text_flash_attn
191
- config.vision_config.x_attention = _use_vision_xformers
192
-
193
- return config
194
-
195
-
196
  class JinaCLIPPreTrainedModel(PreTrainedModel):
197
  """
198
  An abstract class to handle weights initialization and a simple interface for
@@ -222,12 +144,6 @@ class JinaCLIPPreTrainedModel(PreTrainedModel):
222
  if isinstance(module, nn.Linear) and module.bias is not None:
223
  module.bias.data.zero_()
224
 
225
- @classmethod
226
- def from_pretrained(cls, *args, **kwargs):
227
- if 'torch_dtype' not in kwargs:
228
- kwargs['torch_dtype'] = 'auto'
229
- return super().from_pretrained(*args, **kwargs)
230
-
231
 
232
  class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
233
  config_class = JinaCLIPTextConfig
@@ -300,19 +216,25 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
300
  f'JinaCLIPVisionConfig but is of type {type(config.vision_config)}.'
301
  )
302
 
303
- config = _resolve_attention_libs(config)
304
  text_config = config.text_config
305
  vision_config = config.vision_config
306
 
 
 
 
 
 
307
  self.add_projections = config.add_projections
308
  self.projection_dim = config.projection_dim
309
  self.text_embed_dim = text_config.embed_dim
310
  self.vision_embed_dim = vision_config.embed_dim
 
311
  self.text_model = _build_text_tower(text_config)
312
  self.vision_model = _build_vision_tower(vision_config)
313
  self.logit_scale = nn.Parameter(
314
  torch.tensor(self.config.logit_scale_init_value)
315
  )
 
316
  if self.add_projections:
317
  self.visual_projection = nn.Linear(
318
  self.vision_embed_dim, self.projection_dim, bias=False
@@ -329,7 +251,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
329
  self.post_init()
330
 
331
  def get_tokenizer(self):
332
- if self.tokenizer is None:
333
  self.tokenizer = AutoTokenizer.from_pretrained(
334
  self.config._name_or_path, trust_remote_code=True
335
  )
@@ -364,24 +286,24 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
364
  )
365
  return self.visual_projection(self.vision_model(x=x))
366
 
367
- def _truncate_embeddings(self, embeddings: torch.Tensor, truncate_dim: int):
368
  if not self.config.matryoshka_dimensions:
369
  logger.warning(
370
- 'Model is not trained using Matryoshka Representation Learning, '
371
372
  )
373
- return embeddings[:, :truncate_dim]
374
-
375
- @staticmethod
376
- def _decode_image_data(image_data_str: str) -> Image:
377
- header, data = image_data_str.split(',', 1)
378
- image_data = base64.b64decode(data)
379
- return Image.open(BytesIO(image_data))
380
 
381
  @torch.inference_mode()
382
- def encode_image(
383
  self,
384
- images: Union[str, List[Union[str, 'Image.Image']]],
385
  batch_size: int = 32,
386
  show_progress_bar: Optional[bool] = None,
387
  convert_to_numpy: bool = True,
@@ -389,129 +311,122 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
389
  device: Optional[torch.device] = None,
390
  normalize_embeddings: bool = True,
391
  truncate_dim: Optional[int] = None,
 
392
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
393
  """
394
- Computes image embeddings
395
-
396
- Args:
397
- images(`str` or `List[Union[str, Image.Image]]`):
398
- Image paths, URLs, PIL images, or data:image/ strings to be encoded
399
- batch_size(`int`, *optional*, defaults to 32):
400
- Batch size for the computation
401
- show_progress_bar(`bool`, *optional*, defaults to None):
402
- Show a progress bar when encoding images. If set to None, progress bar
403
- is only shown when `logger.level == logging.INFO` or
404
- `logger.level == logging.DEBUG`
405
- convert_to_numpy(`bool`, *optional*, defaults to True):
406
- If true, the output is a list of numpy vectors. Else, it is a list of
407
- pytorch tensors
408
- convert_to_tensor(`bool`, *optional*, defaults to False):
409
- If true, you get one large tensor as return. Overwrites any setting
410
- from convert_to_numpy
411
- device(`torch.device`, *optional*, defaults to None):
412
- Which torch.device to use for the computation
413
- normalize_embeddings(`bool`, *optional*, defaults to True):
414
- If set to true, returned vectors will have length 1. In that case,
415
- the faster dot-product (util.dot_score) instead of cosine similarity
416
- can be used
417
- truncate_dim(`int`, *optional*, defaults to None):
418
- The dimension to truncate sentence embeddings to. If set to `None`
419
- no truncation is performed
420
-
421
- Returns:
422
- By default, a list of tensors is returned. If convert_to_tensor, a stacked
423
- tensor is returned. If convert_to_numpy, a numpy matrix is returned
424
  """
425
-
426
- _is_training = self.training
427
  self.eval()
428
-
429
- self.preprocess = self.get_preprocess()
430
  all_embeddings = []
431
 
 
 
432
  if show_progress_bar is None:
433
  show_progress_bar = (
434
  logger.getEffectiveLevel() == logging.INFO
435
  or logger.getEffectiveLevel() == logging.DEBUG
436
  )
 
437
  if convert_to_tensor:
438
  convert_to_numpy = False
439
 
440
- _input_was_single_img = False
441
- if isinstance(images, str) or not hasattr(images, '__len__'):
442
- images = [images]
443
- _input_was_single_img = True
444
 
445
  if device is not None:
446
  self.to(device)
447
 
448
- _permutation = np.argsort([-len(str(i)) for i in images])
449
- _inverse_permutation = np.argsort(_permutation)
450
- images = [images[idx] for idx in _permutation]
 
 
 
 
451
 
452
  if has_tqdm:
453
  range_iter = trange(
454
  0,
455
- len(images),
456
  batch_size,
457
  desc='Encoding',
458
  disable=not show_progress_bar,
459
  )
460
  else:
461
- range_iter = range(0, len(images), batch_size)
462
 
463
  truncate_dim = truncate_dim or self.config.truncate_dim
464
-
465
  for i in range_iter:
466
- _processed_images = []
467
- for img in images[i: i + batch_size]:
468
- if isinstance(img, str):
469
- if img.startswith('http'):
470
- response = requests.get(img)
471
- image = Image.open(BytesIO(response.content)).convert('RGB')
472
- elif img.startswith('data:image/'):
473
- image = self._decode_image_data(img).convert('RGB')
474
- else:
475
- image = Image.open(img).convert('RGB')
476
- elif isinstance(img, Image.Image):
477
- image = img.convert('RGB')
478
- else:
479
- raise ValueError('Unsupported image format')
480
- _processed_images.append(image)
481
 
482
- pixelvals = self.preprocess(_processed_images)
483
- pixelvals = pixelvals.to(self.device)
484
- embeddings = self.get_image_features(pixelvals)
485
 
486
  if truncate_dim:
487
- embeddings = self._truncate_embeddings(embeddings, truncate_dim)
488
  if normalize_embeddings:
489
- embeddings = f.normalize(embeddings, p=2, dim=1)
490
  if convert_to_numpy:
491
  embeddings = embeddings.cpu()
492
-
493
  all_embeddings.extend(embeddings)
494
 
495
- all_embeddings = [all_embeddings[idx] for idx in _inverse_permutation]
496
 
497
  if convert_to_tensor:
498
  all_embeddings = torch.stack(all_embeddings)
499
  elif convert_to_numpy:
500
- all_embeddings = np.asarray(
501
- [emb.to(torch.float32).numpy() for emb in all_embeddings]
502
- )
503
 
504
- if _input_was_single_img:
505
  all_embeddings = all_embeddings[0]
506
 
507
- self.train(_is_training)
508
  return all_embeddings
509
 
 
 
 
 
 
510
  @torch.inference_mode()
511
- def encode_text(
512
  self,
513
- sentences: Union[str, List[str]],
514
- task: Optional[str] = None,
515
  batch_size: int = 32,
516
  show_progress_bar: Optional[bool] = None,
517
  convert_to_numpy: bool = True,
@@ -519,119 +434,123 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
519
  device: Optional[torch.device] = None,
520
  normalize_embeddings: bool = True,
521
  truncate_dim: Optional[int] = None,
522
- **tokenizer_kwargs,
523
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
524
  """
525
- Computes text embeddings
526
-
527
  Args:
528
- sentences(`str` or `List[str]`):
529
- Sentence or sentences to be encoded
530
- task(`str`, *optional*, defaults to `None`):
531
- Specifies the task for which the encoding is intended. If a `task` is
532
- provided, a task-specific instruction is added to the beginning of each
533
- sentence. If `task` is not provided, no instructions are added.
534
  batch_size(`int`, *optional*, defaults to 32):
535
  Batch size for the computation
536
  show_progress_bar(`bool`, *optional*, defaults to None):
537
- Show a progress bar when encoding sentences. If set to None, progress
538
- bar is only shown when `logger.level == logging.INFO` or
539
- `logger.level == logging.DEBUG`
540
  convert_to_numpy(`bool`, *optional*, defaults to True):
541
- If true, the output is a list of numpy vectors. Else, it is a list of
542
- pytorch tensors
543
  convert_to_tensor(`bool`, *optional*, defaults to False):
544
- If true, you get one large tensor as return. Overwrites any setting
545
- from convert_to_numpy
546
  device(`torch.device`, *optional*, defaults to None):
547
  Which torch.device to use for the computation
548
- normalize_embeddings(`bool`, *optional*, defaults to True):
549
  If set to true, returned vectors will have length 1. In that case,
550
  the faster dot-product (util.dot_score) instead of cosine similarity
551
- can be used
552
  truncate_dim(`int`, *optional*, defaults to None):
553
- The dimension to truncate sentence embeddings to. If set to `None`
554
- no truncation is performed
555
- tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
556
- Keyword arguments for the tokenizer
557
  Returns:
558
- By default, a list of tensors is returned. If convert_to_tensor, a stacked
559
- tensor is returned. If convert_to_numpy, a numpy matrix is returned.
 
560
  """
561
- _is_training = self.training
 
562
  self.eval()
563
-
 
564
  all_embeddings = []
565
- self.tokenizer = self.get_tokenizer()
566
-
567
  if show_progress_bar is None:
568
  show_progress_bar = (
569
  logger.getEffectiveLevel() == logging.INFO
570
  or logger.getEffectiveLevel() == logging.DEBUG
571
  )
 
572
  if convert_to_tensor:
573
  convert_to_numpy = False
574
-
575
- _input_was_string = False
576
- if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
577
- sentences = [sentences]
578
- _input_was_string = True
579
-
580
  if device is not None:
581
  self.to(device)
582
-
583
- _permutation = np.argsort([-len(i) for i in sentences])
584
- _inverse_permutation = np.argsort(_permutation)
585
- sentences = [sentences[idx] for idx in _permutation]
586
-
587
- tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
588
- tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
589
- tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
590
-
591
  if has_tqdm:
592
  range_iter = trange(
593
  0,
594
- len(sentences),
595
  batch_size,
596
  desc='Encoding',
597
  disable=not show_progress_bar,
598
  )
599
  else:
600
- range_iter = range(0, len(sentences), batch_size)
601
-
602
- truncate_dim = truncate_dim or self.config.truncate_dim
603
 
604
- instruction = self.text_model.get_instruction_from_task(task)
605
- if instruction:
606
- sentences = [instruction + sentence for sentence in sentences]
607
 
 
608
  for i in range_iter:
609
- tokens = self.tokenizer(
610
- sentences[i: i + batch_size],
611
- return_tensors='pt',
612
- **tokenizer_kwargs,
613
- ).to(self.device)
614
- embeddings = self.get_text_features(input_ids=tokens)
 
615
  if truncate_dim:
616
- embeddings = self._truncate_embeddings(embeddings, truncate_dim)
617
  if normalize_embeddings:
618
- embeddings = f.normalize(embeddings, p=2, dim=1)
619
  if convert_to_numpy:
620
  embeddings = embeddings.cpu()
621
  all_embeddings.extend(embeddings)
622
-
623
- all_embeddings = [all_embeddings[idx] for idx in _inverse_permutation]
624
-
625
  if convert_to_tensor:
626
  all_embeddings = torch.stack(all_embeddings)
627
  elif convert_to_numpy:
628
- all_embeddings = np.asarray(
629
- [emb.to(torch.float32).numpy() for emb in all_embeddings]
630
- )
631
- if _input_was_string:
632
  all_embeddings = all_embeddings[0]
633
-
634
- self.train(_is_training)
635
  return all_embeddings
636
 
637
  def forward(
 
5
  # and adjusted for Jina CLIP
6
 
7
  import base64
 
 
8
  from functools import partial
9
  from io import BytesIO
10
  from typing import List, Optional, Tuple, Union
 
14
  import torch
15
  import torch.nn.functional as f
16
  import torch.utils.checkpoint
 
17
  from torch import nn
18
  from transformers import (
19
  AutoImageProcessor,
 
35
 
36
  has_tqdm = True
37
  except ImportError:
 
38
  has_tqdm = False
39
 
40
  from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
41
  from .eva_model import EVAVisionTransformer
42
  from .hf_model import HFTextEncoder
43
+
44
+ # needed for HF to correctly import in cache
45
  from .rope_embeddings import VisionRotaryEmbeddingFast # noqa: F401
46
  from .transform import ( # noqa: F401
47
  OPENAI_DATASET_MEAN,
 
68
  return HFTextEncoder(
69
  model_name_or_path=config.hf_model_name_or_path,
70
  output_dim=config.embed_dim,
 
 
71
  pooler_type=config.pooler_type,
72
  proj_type=config.proj_type,
73
  proj_bias=config.proj_bias,
 
115
  )
116
 
117
 
118
  class JinaCLIPPreTrainedModel(PreTrainedModel):
119
  """
120
  An abstract class to handle weights initialization and a simple interface for
 
144
  if isinstance(module, nn.Linear) and module.bias is not None:
145
  module.bias.data.zero_()
146
 
 
 
 
 
 
 
147
 
148
  class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
149
  config_class = JinaCLIPTextConfig
 
216
  f'JinaCLIPVisionConfig but is of type {type(config.vision_config)}.'
217
  )
218
 
 
219
  text_config = config.text_config
220
  vision_config = config.vision_config
221
 
222
+ if config.use_text_flash_attn is not None:
223
+ text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
224
+ if config.use_vision_xformers is not None:
225
+ vision_config.x_attention = config.use_vision_xformers
226
+
227
  self.add_projections = config.add_projections
228
  self.projection_dim = config.projection_dim
229
  self.text_embed_dim = text_config.embed_dim
230
  self.vision_embed_dim = vision_config.embed_dim
231
+
232
  self.text_model = _build_text_tower(text_config)
233
  self.vision_model = _build_vision_tower(vision_config)
234
  self.logit_scale = nn.Parameter(
235
  torch.tensor(self.config.logit_scale_init_value)
236
  )
237
+
238
  if self.add_projections:
239
  self.visual_projection = nn.Linear(
240
  self.vision_embed_dim, self.projection_dim, bias=False
 
251
  self.post_init()
252
 
253
  def get_tokenizer(self):
254
+ if not self.tokenizer:
255
  self.tokenizer = AutoTokenizer.from_pretrained(
256
  self.config._name_or_path, trust_remote_code=True
257
  )
 
286
  )
287
  return self.visual_projection(self.vision_model(x=x))
288
 
289
+ def truncate_embeddings(self, embeddings, truncate_dim):
290
  if not self.config.matryoshka_dimensions:
291
  logger.warning(
292
+ "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
293
+ )
294
+ return embeddings
295
+ elif truncate_dim in self.config.matryoshka_dimensions:
296
+ return embeddings[:, :truncate_dim]
297
+ else:
298
+ raise ValueError(
299
+ f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
300
+ f"Supported dimensions are {self.config.matryoshka_dimensions}."
301
  )
302
 
303
  @torch.inference_mode()
304
+ def encode_text(
305
  self,
306
+ sentences: Union[str, List[str]],
307
  batch_size: int = 32,
308
  show_progress_bar: Optional[bool] = None,
309
  convert_to_numpy: bool = True,
 
311
  device: Optional[torch.device] = None,
312
  normalize_embeddings: bool = True,
313
  truncate_dim: Optional[int] = None,
314
+ **tokenizer_kwargs,
315
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
316
  """
317
+ Computes sentence embeddings
318
+ Args:
319
+ sentences(`str` or `List[str]`):
320
+ Sentence or sentences to be encoded
321
+ batch_size(`int`, *optional*, defaults to 32):
322
+ Batch size for the computation
323
+ show_progress_bar(`bool`, *optional*, defaults to None):
324
+ Show a progress bar when encoding sentences.
325
+ If set to None, progress bar is only shown when
326
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
327
+ convert_to_numpy(`bool`, *optional*, defaults to True):
328
+ If true, the output is a list of numpy vectors.
329
+ Else, it is a list of pytorch tensors.
330
+ convert_to_tensor(`bool`, *optional*, defaults to False):
331
+ If true, you get one large tensor as return.
332
+ Overwrites any setting from convert_to_numpy
333
+ device(`torch.device`, *optional*, defaults to None):
334
+ Which torch.device to use for the computation
335
+ normalize_embeddings(`bool`, *optional*, defaults to False):
336
+ If set to true, returned vectors will have length 1. In that case,
337
+ the faster dot-product (util.dot_score) instead of cosine similarity
338
+ can be used.
339
+ truncate_dim(`int`, *optional*, defaults to None):
340
+ The dimension to truncate sentence embeddings to. `None` does no truncation.
341
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
342
+ Keyword arguments for the tokenizer
343
+ Returns:
344
+ By default, a list of tensors is returned.
345
+ If convert_to_tensor, a stacked tensor is returned.
346
+ If convert_to_numpy, a numpy matrix is returned.
347
  """
348
+ is_training = self.training
 
349
  self.eval()
 
 
350
  all_embeddings = []
351
 
352
+ self.tokenizer = self.get_tokenizer()
353
+
354
  if show_progress_bar is None:
355
  show_progress_bar = (
356
  logger.getEffectiveLevel() == logging.INFO
357
  or logger.getEffectiveLevel() == logging.DEBUG
358
  )
359
+
360
  if convert_to_tensor:
361
  convert_to_numpy = False
362
 
363
+ input_was_string = False
364
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
365
+ sentences = [sentences]
366
+ input_was_string = True
367
 
368
  if device is not None:
369
  self.to(device)
370
 
371
+ permutation = np.argsort([-len(i) for i in sentences])
372
+ inverse_permutation = np.argsort(permutation)
373
+ sentences = [sentences[idx] for idx in permutation]
374
+
375
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
376
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
377
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
378
 
379
  if has_tqdm:
380
  range_iter = trange(
381
  0,
382
+ len(sentences),
383
  batch_size,
384
  desc='Encoding',
385
  disable=not show_progress_bar,
386
  )
387
  else:
388
+ range_iter = range(0, len(sentences), batch_size)
389
 
390
  truncate_dim = truncate_dim or self.config.truncate_dim
 
391
  for i in range_iter:
392
+ encoded_input = self.tokenizer(
393
+ sentences[i : i + batch_size],
394
+ return_tensors='pt',
395
+ **tokenizer_kwargs,
396
+ ).to(self.device)
397
 
398
+ embeddings = self.get_text_features(input_ids=encoded_input)
 
 
399
 
400
  if truncate_dim:
401
+ embeddings = self.truncate_embeddings(embeddings, truncate_dim)
402
  if normalize_embeddings:
403
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
404
  if convert_to_numpy:
405
  embeddings = embeddings.cpu()
 
406
  all_embeddings.extend(embeddings)
407
 
408
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
409
 
410
  if convert_to_tensor:
411
  all_embeddings = torch.stack(all_embeddings)
412
  elif convert_to_numpy:
413
+ all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
 
 
414
 
415
+ if input_was_string:
416
  all_embeddings = all_embeddings[0]
417
 
418
+ self.train(is_training)
419
  return all_embeddings
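For orientation, a minimal usage sketch of the encode_text path added above (not part of this diff). The checkpoint name is an assumption; any checkpoint that bundles this modeling code and is loaded with trust_remote_code should behave the same, and the output width depends on the loaded projection.

# Hypothetical usage sketch, assuming a Jina CLIP checkpoint that ships this code.
from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
text_vectors = model.encode_text(
    ['a photo of a cat', 'a photo of a dog'],
    batch_size=2,
    convert_to_numpy=True,
    normalize_embeddings=True,
)
print(text_vectors.shape)  # (2, embed_dim) for the loaded checkpoint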
420
 
421
+ def decode_data_image(data_image_str):
422
+ header, data = data_image_str.split(',', 1)
423
+ image_data = base64.b64decode(data)
424
+ return Image.open(BytesIO(image_data))
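The helper above splits on the first comma and base64-decodes the payload, so any `data:image/...;base64,` string works. A quick round-trip sketch, assuming the helper resolves as a plain function at the call site (it takes no self argument):

# Build a data URI from a tiny PIL image, then decode it back with the helper.
import base64
from io import BytesIO
from PIL import Image

buf = BytesIO()
Image.new('RGB', (8, 8), color='red').save(buf, format='PNG')
uri = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode()
print(decode_data_image(uri).size)  # (8, 8)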
425
+
426
  @torch.inference_mode()
427
+ def encode_image(
428
  self,
429
+ images: Union[str, List[Union[str, "Image.Image"]]],
 
430
  batch_size: int = 32,
431
  show_progress_bar: Optional[bool] = None,
432
  convert_to_numpy: bool = True,
 
434
  device: Optional[torch.device] = None,
435
  normalize_embeddings: bool = True,
436
  truncate_dim: Optional[int] = None,
 
437
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
438
  """
439
+ Computes image embeddings.
440
+
441
  Args:
442
+ images(`str` or `List[Union[str, Image.Image]]`):
443
+ image paths, URLs, PIL images, or data:image/ strings to be encoded
 
 
 
 
444
  batch_size(`int`, *optional*, defaults to 32):
445
  Batch size for the computation
446
  show_progress_bar(`bool`, *optional*, defaults to None):
447
+ Show a progress bar when encoding images.
448
+ If set to None, progress bar is only shown when
449
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
450
  convert_to_numpy(`bool`, *optional*, defaults to True):
451
+ If true, the output is a list of numpy vectors.
452
+ Else, it is a list of pytorch tensors.
453
  convert_to_tensor(`bool`, *optional*, defaults to False):
454
+ If true, a single stacked tensor is returned.
455
+ Overrides any setting of convert_to_numpy.
456
  device(`torch.device`, *optional*, defaults to None):
457
  Which torch.device to use for the computation
458
+ normalize_embeddings(`bool`, *optional*, defaults to True):
459
  If set to true, returned vectors will have length 1. In that case,
460
  the faster dot-product (util.dot_score) instead of cosine similarity
461
+ can be used.
462
  truncate_dim(`int`, *optional*, defaults to None):
463
+ The dimension to truncate image embeddings to. `None` does no truncation.
 
 
 
464
  Returns:
465
+ By default, a list of tensors is returned.
466
+ If convert_to_tensor, a stacked tensor is returned.
467
+ If convert_to_numpy, a numpy matrix is returned.
468
  """
469
+
470
+ is_training = self.training
471
  self.eval()
472
+
473
+ self.preprocess = self.get_preprocess()
474
  all_embeddings = []
475
+
 
476
  if show_progress_bar is None:
477
  show_progress_bar = (
478
  logger.getEffectiveLevel() == logging.INFO
479
  or logger.getEffectiveLevel() == logging.DEBUG
480
  )
481
+
482
  if convert_to_tensor:
483
  convert_to_numpy = False
484
+
485
+ input_was_single_img = False
486
+ if isinstance(images, str) or not hasattr(images, '__len__'):
487
+ images = [images]
488
+ input_was_single_img = True
489
+
490
  if device is not None:
491
  self.to(device)
492
+
493
+ permutation = np.argsort([-len(str(i)) for i in images])
494
+ inverse_permutation = np.argsort(permutation)
495
+ images = [images[idx] for idx in permutation]
496
+
 
 
 
 
497
  if has_tqdm:
498
  range_iter = trange(
499
  0,
500
+ len(images),
501
  batch_size,
502
  desc='Encoding',
503
  disable=not show_progress_bar,
504
  )
505
  else:
506
+ range_iter = range(0, len(images), batch_size)
 
 
507
 
508
+ from PIL import Image
 
 
509
 
510
+ truncate_dim = truncate_dim or self.config.truncate_dim
511
  for i in range_iter:
512
+ batch_images = images[i:i+batch_size]
513
+ processed_inputs = []
514
+
515
+ for img in batch_images:
516
+ if isinstance(img, str):
517
+ if img.startswith('http'):
518
+ response = requests.get(img)
519
+ image = Image.open(BytesIO(response.content)).convert('RGB')
520
+ elif img.startswith('data:image/'):
521
+ image = decode_data_image(img).convert('RGB')
522
+ else:
523
+ image = Image.open(img).convert('RGB')
524
+ elif isinstance(img, Image.Image):
525
+ image = img.convert('RGB')
526
+ else:
527
+ raise ValueError('Unsupported image input type; expected a path, URL, data URI, or PIL.Image')
528
+
529
+ processed_inputs.append(image)
530
+
531
+ processed_inputs = self.preprocess(processed_inputs)
532
+ processed_inputs = processed_inputs.to(self.device)
533
+ embeddings = self.get_image_features(processed_inputs)
534
+
535
  if truncate_dim:
536
+ embeddings = self.truncate_embeddings(embeddings, truncate_dim)
537
  if normalize_embeddings:
538
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
539
  if convert_to_numpy:
540
  embeddings = embeddings.cpu()
541
  all_embeddings.extend(embeddings)
542
+
543
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
544
+
545
  if convert_to_tensor:
546
  all_embeddings = torch.stack(all_embeddings)
547
  elif convert_to_numpy:
548
+ all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
549
+
550
+ if input_was_single_img:
 
551
  all_embeddings = all_embeddings[0]
552
+
553
+ self.train(is_training)
554
  return all_embeddings
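Continuing the encode_text sketch above (same hypothetical model object), the matching encode_image call accepts local paths, URLs, data URIs, and PIL images in one list; the path and URL below are placeholders:

# Placeholder inputs; reuses `model` and `text_vectors` from the earlier sketch.
image_vectors = model.encode_image(
    ['photos/cat.jpg', 'https://example.com/dog.png'],
    convert_to_numpy=True,
    normalize_embeddings=True,
)
# With unit-length vectors, cosine similarity is just a dot product.
print(image_vectors @ text_vectors.T)  # (2, 2) similarity matrix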
555
 
556
  def forward(
processing_clip.py CHANGED
@@ -72,6 +72,7 @@ class JinaCLIPImageProcessor(BaseImageProcessor):
72
  return output
73
 
74
  def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
 
75
  _transform_needs_rebuild = False
76
  for k, v in kwargs.items():
77
  if k in self._valid_processor_keys:
 
72
  return output
73
 
74
  def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
75
+
76
  _transform_needs_rebuild = False
77
  for k, v in kwargs.items():
78
  if k in self._valid_processor_keys:
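The `_transform_needs_rebuild` flag and the key filter above suggest a per-call override hook: keywords matching known processor settings presumably update the instance and trigger a rebuild of the cached pipeline. A self-contained sketch of that general pattern; the class and attribute names are illustrative, not the actual JinaCLIPImageProcessor internals:

# Illustrative override-and-rebuild pattern; not the real processor code.
class OverrideDemo:
    _valid_processor_keys = {'size', 'interpolation'}

    def __init__(self):
        self.size = 224
        self.interpolation = 'bicubic'
        self._transform = None  # cached pipeline built from the current settings

    def _build(self):
        return f'resize({self.size}, {self.interpolation})'  # stand-in for Compose(...)

    def preprocess(self, images, **kwargs):
        needs_rebuild = False
        for k, v in kwargs.items():
            if k in self._valid_processor_keys and getattr(self, k) != v:
                setattr(self, k, v)   # per-call override of a stored setting
                needs_rebuild = True
        if self._transform is None or needs_rebuild:
            self._transform = self._build()
        return [(self._transform, image) for image in images]

print(OverrideDemo().preprocess(['img'], size=384))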
rope_embeddings.py CHANGED
@@ -3,6 +3,7 @@
3
  # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
  # --------------------------------------------------------
5
 
 
6
  from math import pi
7
 
8
  import torch
@@ -74,8 +75,10 @@ class VisionRotaryEmbedding(nn.Module):
74
 
75
  freqs = broadcast((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
76
 
77
- self.register_buffer('freqs_cos', freqs.cos(), persistent=False)
78
- self.register_buffer('freqs_sin', freqs.sin(), persistent=False)
 
 
79
 
80
  def forward(self, t, start_index=0):
81
  rot_dim = self.freqs_cos.shape[-1]
@@ -134,8 +137,10 @@ class VisionRotaryEmbeddingFast(nn.Module):
134
 
135
  self.patch_dropout = patch_dropout
136
 
137
- self.register_buffer('freqs_cos', freqs_cos, persistent=False)
138
- self.register_buffer('freqs_sin', freqs_sin, persistent=False)
 
 
139
 
140
  def forward(self, t, patch_indices_keep=None):
141
  if patch_indices_keep is not None:
 
3
  # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
  # --------------------------------------------------------
5
 
6
+ import logging
7
  from math import pi
8
 
9
  import torch
 
75
 
76
  freqs = broadcast((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
77
 
78
+ self.register_buffer('freqs_cos', freqs.cos())
79
+ self.register_buffer('freqs_sin', freqs.sin())
80
+
81
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
82
 
83
  def forward(self, t, start_index=0):
84
  rot_dim = self.freqs_cos.shape[-1]
 
137
 
138
  self.patch_dropout = patch_dropout
139
 
140
+ self.register_buffer('freqs_cos', freqs_cos)
141
+ self.register_buffer('freqs_sin', freqs_sin)
142
+
143
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
144
 
145
  def forward(self, t, patch_indices_keep=None):
146
  if patch_indices_keep is not None:
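For context on the `persistent=False` flag the old lines used: non-persistent buffers stay on the module (and follow `.to()` / `.cuda()`) but are excluded from `state_dict()`, so registering the rope tables persistently, as the new lines do, makes them part of saved checkpoints. A minimal illustration:

# Minimal sketch of the persistent flag on register_buffer.
import torch
from torch import nn

class BufferDemo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('saved', torch.zeros(2))                       # in state_dict
        self.register_buffer('volatile', torch.zeros(2), persistent=False)  # skipped

print(sorted(BufferDemo().state_dict()))  # ['saved']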
transform.py CHANGED
@@ -1,10 +1,11 @@
 
1
  import random
2
  import warnings
3
  from dataclasses import asdict, dataclass
4
  from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
5
 
6
  import torch
7
- import torchvision.transforms.functional as f
8
  from torchvision.transforms import (
9
  CenterCrop,
10
  ColorJitter,
@@ -22,93 +23,88 @@ OPENAI_DATASET_MEAN = tuple(OPENAI_CLIP_MEAN)
22
  OPENAI_DATASET_STD = tuple(OPENAI_CLIP_STD)
23
 
24
 
25
- def _setup_size(size, error_msg):
26
- if isinstance(size, int):
27
- return size, size
28
- if isinstance(size, Sequence) and len(size) == 1:
29
- return size[0], size[0]
30
- if len(size) != 2:
31
- raise ValueError(error_msg)
32
- return size
 
33
 
 
 
34
 
35
- def _center_crop_or_pad(
36
- img: torch.Tensor,
37
- output_size: Union[int, Tuple[int, ...], List[int]],
38
- fill: Union[int, Tuple[int]] = 0,
39
- ) -> torch.Tensor:
40
- """
41
- Center crops and/or pads the given image. If the image is torch Tensor, it is
42
- expected to have [..., H, W] shape, where ... means an arbitrary number of leading
43
- dimensions. If image size is smaller than output size along any edge, image is
44
- padded with 0 and then center cropped.
45
- """
46
- if isinstance(output_size, int):
47
- output_size = (output_size, output_size)
48
- elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
49
- output_size = (output_size[0], output_size[0])
50
 
51
- _, image_height, image_width = f.get_dimensions(img)
52
- crop_height, crop_width = output_size
 
53
 
54
- if crop_width > image_width or crop_height > image_height:
55
- padding_ltrb = [
56
- (crop_width - image_width) // 2 if crop_width > image_width else 0,
57
- (crop_height - image_height) // 2 if crop_height > image_height else 0,
58
- (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
59
- (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
60
- ]
61
- img = f.pad(img, padding_ltrb, fill=fill)
62
- _, image_height, image_width = f.get_dimensions(img)
63
- if crop_width == image_width and crop_height == image_height:
64
- return img
65
 
66
- crop_top = int(round((image_height - crop_height) / 2.0))
67
- crop_left = int(round((image_width - crop_width) / 2.0))
68
- return f.crop(img, crop_top, crop_left, crop_height, crop_width)
69
 
70
 
71
- class _CenterCropOrPad(torch.nn.Module):
72
- """Crops the given image at the center.
73
- If the image is torch Tensor, it is expected
74
- to have [..., H, W] shape, where ... means an arbitrary number of leading
75
- dimensions. If image size is smaller than output size along any edge, image is
76
- padded with 0 and then center cropped.
77
-
78
- Args:
79
- size (sequence or int): Desired output size of the crop. If size is an
80
- int instead of sequence like (h, w), a square crop (size, size) is
81
- made. If provided a sequence of length 1, it will be interpreted as
82
- (size[0], size[0]).
83
  """
 
 
 
 
 
 
 
 
 
 
84
 
85
- def __init__(self, size, fill=0):
86
- super().__init__()
87
- self.size = _setup_size(
88
- size, error_msg='Please provide only two dimensions (h, w) for size.'
89
- )
90
- self.fill = fill
91
 
92
- def forward(self, img):
93
- """
94
- Args:
95
- img (PIL Image or Tensor): Image to be cropped.
96
 
97
- Returns:
98
- PIL Image or Tensor: Cropped image.
99
- """
100
- return _center_crop_or_pad(img, self.size, fill=self.fill)
101
 
102
- def __repr__(self) -> str:
103
- return f'{self.__class__.__name__}(size={self.size})'
 
 
 
 
 
 
 
 
104
 
 
 
 
105
 
106
- def _convert_to_rgb(image):
107
- return image.convert('RGB')
108
 
 
 
 
109
 
110
- class _ResizeKeepRatio:
111
- """Resize while keeping ratio. Copied from timm"""
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def __init__(
114
  self,
@@ -163,9 +159,8 @@ class _ResizeKeepRatio:
163
  ratio_factor[0] / aspect_factor,
164
  ratio_factor[1] * aspect_factor,
165
  )
166
- return [
167
- round(x * factor / ratio) for x, factor in zip(source_size, ratio_factor)
168
- ]
169
 
170
  def __call__(self, img):
171
  """
@@ -185,7 +180,7 @@ class _ResizeKeepRatio:
185
  self.random_aspect_prob,
186
  self.random_aspect_range,
187
  )
188
- img = f.resize(img, size, self.interpolation)
189
  return img
190
 
191
  def __repr__(self):
@@ -195,8 +190,92 @@ class _ResizeKeepRatio:
195
  return format_string
196
 
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  class _ColorJitter(object):
199
- """Apply color jitter to the PIL image with a specified probability"""
 
 
200
 
201
  def __init__(self, brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0, p=0.8):
202
  assert 0.0 <= p <= 1.0
@@ -213,7 +292,9 @@ class _ColorJitter(object):
213
 
214
 
215
  class _GrayScale(object):
216
- """Apply gray scale to the PIL image with a specified probability"""
 
 
217
 
218
  def __init__(self, p=0.2):
219
  assert 0.0 <= p <= 1.0
@@ -227,20 +308,6 @@ class _GrayScale(object):
227
  return img
228
 
229
 
230
- @dataclass
231
- class AugmentationCfg:
232
- scale: Tuple[float, float] = (0.9, 1.0)
233
- ratio: Optional[Tuple[float, float]] = None
234
- color_jitter: Optional[
235
- Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]
236
- ] = None
237
- re_prob: Optional[float] = None
238
- re_count: Optional[int] = None
239
- use_timm: bool = False
240
- color_jitter_prob: float = None
241
- gray_scale_prob: float = None
242
-
243
-
244
  def image_transform(
245
  image_size: Union[int, Tuple[int, int]],
246
  is_train: bool,
@@ -340,10 +407,10 @@ def image_transform(
340
  else:
341
  if resize_mode == 'longest':
342
  transforms = [
343
- _ResizeKeepRatio(
344
  image_size, interpolation=interpolation_mode, longest=1
345
  ),
346
- _CenterCropOrPad(image_size, fill=fill_color),
347
  ]
348
  elif resize_mode == 'squash':
349
  if isinstance(image_size, int):
@@ -361,7 +428,7 @@ def image_transform(
361
  transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
362
  else:
363
  # resize shortest edge to matching target dim for non-square target
364
- transforms = [_ResizeKeepRatio(image_size)]
365
  transforms += [CenterCrop(image_size)]
366
 
367
  transforms.extend(
@@ -372,3 +439,20 @@ def image_transform(
372
  ]
373
  )
374
  return Compose(transforms)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
  import random
3
  import warnings
4
  from dataclasses import asdict, dataclass
5
  from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6
 
7
  import torch
8
+ import torchvision.transforms.functional as F
9
  from torchvision.transforms import (
10
  CenterCrop,
11
  ColorJitter,
 
23
  OPENAI_DATASET_STD = tuple(OPENAI_CLIP_STD)
24
 
25
 
26
+ @dataclass
27
+ class PreprocessCfg:
28
+ size: Union[int, Tuple[int, int]] = 224
29
+ mode: str = 'RGB'
30
+ mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
31
+ std: Tuple[float, ...] = OPENAI_DATASET_STD
32
+ interpolation: str = 'bicubic'
33
+ resize_mode: str = 'shortest'
34
+ fill_color: int = 0
35
 
36
+ def __post_init__(self):
37
+ assert self.mode in ('RGB',)
38
 
39
+ @property
40
+ def num_channels(self):
41
+ return 3
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ @property
44
+ def input_size(self):
45
+ return (self.num_channels,) + (self.size, self.size)
46
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
 
 
49
 
50
 
51
+ def merge_preprocess_dict(
52
+ base: Union[PreprocessCfg, Dict],
53
+ overlay: Dict,
54
+ ):
55
+ """Merge overlay key-value pairs on top of base preprocess cfg or dict.
56
+ Input dicts are filtered based on PreprocessCfg fields.
 
 
 
 
 
 
57
  """
58
+ if isinstance(base, PreprocessCfg):
59
+ base_clean = asdict(base)
60
+ else:
61
+ base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
62
+ if overlay:
63
+ overlay_clean = {
64
+ k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None
65
+ }
66
+ base_clean.update(overlay_clean)
67
+ return base_clean
68
 
 
 
 
 
 
 
69
 
70
+ def merge_preprocess_kwargs(base: Union[PreprocessCfg, Dict], **kwargs):
71
+ return merge_preprocess_dict(base, kwargs)
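A quick behavioural sketch of the two merge helpers added above: unknown keys and `None` overrides are dropped, and known keys win over the base config. Values are illustrative.

# Assumes PreprocessCfg and merge_preprocess_kwargs from this module are in scope.
base = PreprocessCfg(size=224, interpolation='bicubic')
merged = merge_preprocess_kwargs(base, size=384, interpolation=None, unknown_key=1)
print(merged['size'], merged['interpolation'])  # 384 bicubic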
 
 
72
 
 
 
 
 
73
 
74
+ @dataclass
75
+ class AugmentationCfg:
76
+ scale: Tuple[float, float] = (0.9, 1.0)
77
+ ratio: Optional[Tuple[float, float]] = None
78
+ color_jitter: Optional[
79
+ Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]
80
+ ] = None
81
+ re_prob: Optional[float] = None
82
+ re_count: Optional[int] = None
83
+ use_timm: bool = False
84
 
85
+ # params for simclr_jitter_gray
86
+ color_jitter_prob: float = None
87
+ gray_scale_prob: float = None
88
 
 
 
89
 
90
+ def _setup_size(size, error_msg):
91
+ if isinstance(size, numbers.Number):
92
+ return int(size), int(size)
93
 
94
+ if isinstance(size, Sequence) and len(size) == 1:
95
+ return size[0], size[0]
96
+
97
+ if len(size) != 2:
98
+ raise ValueError(error_msg)
99
+
100
+ return size
101
+
102
+
103
+ class ResizeKeepRatio:
104
+ """Resize and Keep Ratio
105
+
106
+ Copied from `timm`.
107
+ """
108
 
109
  def __init__(
110
  self,
 
159
  ratio_factor[0] / aspect_factor,
160
  ratio_factor[1] * aspect_factor,
161
  )
162
+ size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
163
+ return size
 
164
 
165
  def __call__(self, img):
166
  """
 
180
  self.random_aspect_prob,
181
  self.random_aspect_range,
182
  )
183
+ img = F.resize(img, size, self.interpolation)
184
  return img
185
 
186
  def __repr__(self):
 
190
  return format_string
191
 
192
 
193
+ def center_crop_or_pad(
194
+ img: torch.Tensor, output_size: List[int], fill=0
195
+ ) -> torch.Tensor:
196
+ """Center crops and/or pads the given image.
197
+ If the image is torch Tensor, it is expected
198
+ to have [..., H, W] shape, where ... means an arbitrary number of leading
199
+ dimensions. If image size is smaller than output size along any edge, image is
200
+ padded with 0 and then center cropped.
201
+
202
+ Args:
203
+ img (PIL Image or Tensor): Image to be cropped.
204
+ output_size (sequence or int): (height, width) of the crop box. If int or
205
+ sequence with single int, it is used for both directions.
206
+ fill (int, Tuple[int]): Padding color
207
+
208
+ Returns:
209
+ PIL Image or Tensor: Cropped image.
210
+ """
211
+ if isinstance(output_size, numbers.Number):
212
+ output_size = (int(output_size), int(output_size))
213
+ elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
214
+ output_size = (output_size[0], output_size[0])
215
+
216
+ _, image_height, image_width = F.get_dimensions(img)
217
+ crop_height, crop_width = output_size
218
+
219
+ if crop_width > image_width or crop_height > image_height:
220
+ padding_ltrb = [
221
+ (crop_width - image_width) // 2 if crop_width > image_width else 0,
222
+ (crop_height - image_height) // 2 if crop_height > image_height else 0,
223
+ (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
224
+ (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
225
+ ]
226
+ img = F.pad(img, padding_ltrb, fill=fill)
227
+ _, image_height, image_width = F.get_dimensions(img)
228
+ if crop_width == image_width and crop_height == image_height:
229
+ return img
230
+
231
+ crop_top = int(round((image_height - crop_height) / 2.0))
232
+ crop_left = int(round((image_width - crop_width) / 2.0))
233
+ return F.crop(img, crop_top, crop_left, crop_height, crop_width)
234
+
235
+
236
+ class CenterCropOrPad(torch.nn.Module):
237
+ """Crops the given image at the center.
238
+ If the image is torch Tensor, it is expected
239
+ to have [..., H, W] shape, where ... means an arbitrary number of leading
240
+ dimensions. If image size is smaller than output size along any edge, image is
241
+ padded with 0 and then center cropped.
242
+
243
+ Args:
244
+ size (sequence or int): Desired output size of the crop. If size is an
245
+ int instead of sequence like (h, w), a square crop (size, size) is
246
+ made. If provided a sequence of length 1, it will be interpreted as
247
+ (size[0], size[0]).
248
+ """
249
+
250
+ def __init__(self, size, fill=0):
251
+ super().__init__()
252
+ self.size = _setup_size(
253
+ size, error_msg='Please provide only two dimensions (h, w) for size.'
254
+ )
255
+ self.fill = fill
256
+
257
+ def forward(self, img):
258
+ """
259
+ Args:
260
+ img (PIL Image or Tensor): Image to be cropped.
261
+
262
+ Returns:
263
+ PIL Image or Tensor: Cropped image.
264
+ """
265
+ return center_crop_or_pad(img, self.size, fill=self.fill)
266
+
267
+ def __repr__(self) -> str:
268
+ return f'{self.__class__.__name__}(size={self.size})'
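A small sketch of the pad-then-crop behaviour on a tensor whose spatial size is below the target; the corner value shows the zero padding.

# A 3x20x20 tensor is zero-padded up to 32x32, then center cropped (a no-op here).
import torch

small = torch.ones(3, 20, 20)
out = CenterCropOrPad(32)(small)
print(out.shape)           # torch.Size([3, 32, 32])
print(out[:, 0, 0].sum())  # tensor(0.) -> top-left corner comes from the padding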
269
+
270
+
271
+ def _convert_to_rgb(image):
272
+ return image.convert('RGB')
273
+
274
+
275
  class _ColorJitter(object):
276
+ """
277
+ Apply Color Jitter to the PIL image with a specified probability.
278
+ """
279
 
280
  def __init__(self, brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0, p=0.8):
281
  assert 0.0 <= p <= 1.0
 
292
 
293
 
294
  class _GrayScale(object):
295
+ """
296
+ Apply Gray Scale to the PIL image with a specified probability.
297
+ """
298
 
299
  def __init__(self, p=0.2):
300
  assert 0.0 <= p <= 1.0
 
308
  return img
309
 
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  def image_transform(
312
  image_size: Union[int, Tuple[int, int]],
313
  is_train: bool,
 
407
  else:
408
  if resize_mode == 'longest':
409
  transforms = [
410
+ ResizeKeepRatio(
411
  image_size, interpolation=interpolation_mode, longest=1
412
  ),
413
+ CenterCropOrPad(image_size, fill=fill_color),
414
  ]
415
  elif resize_mode == 'squash':
416
  if isinstance(image_size, int):
 
428
  transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
429
  else:
430
  # resize shortest edge to matching target dim for non-square target
431
+ transforms = [ResizeKeepRatio(image_size)]
432
  transforms += [CenterCrop(image_size)]
433
 
434
  transforms.extend(
 
439
  ]
440
  )
441
  return Compose(transforms)
442
+
443
+
444
+ def image_transform_v2(
445
+ cfg: PreprocessCfg,
446
+ is_train: bool,
447
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
448
+ ):
449
+ return image_transform(
450
+ image_size=cfg.size,
451
+ is_train=is_train,
452
+ mean=cfg.mean,
453
+ std=cfg.std,
454
+ interpolation=cfg.interpolation,
455
+ resize_mode=cfg.resize_mode,
456
+ fill_color=cfg.fill_color,
457
+ aug_cfg=aug_cfg,
458
+ )
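Finally, an end-to-end sketch of the new config-driven entry point; the input image is a placeholder and the remaining settings come from the `PreprocessCfg` defaults.

# Build an eval-time pipeline from PreprocessCfg and run a dummy PIL image through it.
from PIL import Image

cfg = PreprocessCfg(size=224, resize_mode='shortest')
eval_transform = image_transform_v2(cfg, is_train=False)
pixels = eval_transform(Image.new('RGB', (320, 240)))
print(pixels.shape)  # expected: torch.Size([3, 224, 224])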