gmastrapas
commited on
Commit
·
d220929
1
Parent(s):
d956937
fix: bug in custom_st.py
Browse files- config_sentence_transformers.json +3 -3
- custom_st.py +43 -43
config_sentence_transformers.json
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
{
|
2 |
"__version__": {
|
3 |
-
"sentence_transformers": "3.
|
4 |
-
"transformers": "4.
|
5 |
-
"pytorch": "2.
|
6 |
},
|
7 |
"prompts": {},
|
8 |
"default_prompt_name": null,
|
|
|
1 |
{
|
2 |
"__version__": {
|
3 |
+
"sentence_transformers": "3.3.0",
|
4 |
+
"transformers": "4.46.2",
|
5 |
+
"pytorch": "2.2.2"
|
6 |
},
|
7 |
"prompts": {},
|
8 |
"default_prompt_name": null,
|
custom_st.py
CHANGED
@@ -34,8 +34,8 @@ class Transformer(nn.Module):
|
|
34 |
self.model = AutoModel.from_pretrained(
|
35 |
model_name_or_path, config=config, **model_kwargs
|
36 |
)
|
37 |
-
if max_seq_length is not None and
|
38 |
-
tokenizer_kwargs[
|
39 |
|
40 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
41 |
tokenizer_name_or_path or model_name_or_path,
|
@@ -49,9 +49,9 @@ class Transformer(nn.Module):
|
|
49 |
# No max_seq_length set. Try to infer from model
|
50 |
if max_seq_length is None:
|
51 |
if (
|
52 |
-
hasattr(self.model,
|
53 |
-
and hasattr(self.model.config,
|
54 |
-
and hasattr(self.tokenizer,
|
55 |
):
|
56 |
max_seq_length = min(
|
57 |
self.model.config.max_position_embeddings,
|
@@ -63,7 +63,7 @@ class Transformer(nn.Module):
|
|
63 |
|
64 |
@staticmethod
|
65 |
def _decode_data_image(data_image_str: str) -> Image.Image:
|
66 |
-
header, data = data_image_str.split(
|
67 |
image_data = base64.b64decode(data)
|
68 |
return Image.open(BytesIO(image_data))
|
69 |
|
@@ -79,62 +79,62 @@ class Transformer(nn.Module):
|
|
79 |
_image_or_text_descriptors = []
|
80 |
for sample in texts:
|
81 |
if isinstance(sample, str):
|
82 |
-
if sample.startswith(
|
83 |
response = requests.get(sample)
|
84 |
-
_images.append(Image.open(BytesIO(response.content)).convert(
|
85 |
_image_or_text_descriptors.append(0)
|
86 |
-
elif sample.startswith(
|
87 |
-
_images.append(self._decode_data_image(sample).convert(
|
88 |
_image_or_text_descriptors.append(0)
|
89 |
else:
|
90 |
try:
|
91 |
-
_images.append(Image.open(sample).convert(
|
92 |
_image_or_text_descriptors.append(0)
|
93 |
except Exception as e:
|
94 |
_ = str(e)
|
95 |
_texts.append(sample)
|
96 |
_image_or_text_descriptors.append(1)
|
97 |
elif isinstance(sample, Image.Image):
|
98 |
-
_images.append(sample.convert(
|
99 |
_image_or_text_descriptors.append(0)
|
100 |
|
101 |
encoding = {}
|
102 |
if len(_texts):
|
103 |
-
encoding[
|
104 |
-
|
105 |
padding=padding,
|
106 |
-
truncation=
|
107 |
-
return_tensors=
|
108 |
max_length=self.max_seq_length,
|
109 |
).input_ids
|
110 |
|
111 |
if len(_images):
|
112 |
-
encoding[
|
113 |
-
_images, return_tensors=
|
114 |
).pixel_values
|
115 |
|
116 |
-
encoding[
|
117 |
return encoding
|
118 |
|
119 |
def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
120 |
image_embeddings = []
|
121 |
text_embeddings = []
|
122 |
|
123 |
-
if
|
124 |
-
image_embeddings = self.model.get_image_features(features[
|
125 |
-
if
|
126 |
-
text_embeddings = self.model.get_text_features(features[
|
127 |
|
128 |
sentence_embedding = []
|
129 |
image_features = iter(image_embeddings)
|
130 |
text_features = iter(text_embeddings)
|
131 |
-
for _, _input_type in enumerate(features[
|
132 |
if _input_type == 0:
|
133 |
sentence_embedding.append(next(image_features))
|
134 |
else:
|
135 |
sentence_embedding.append(next(text_features))
|
136 |
|
137 |
-
features[
|
138 |
return features
|
139 |
|
140 |
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
@@ -143,16 +143,16 @@ class Transformer(nn.Module):
|
|
143 |
self.image_processor.save_pretrained(output_path)
|
144 |
|
145 |
@staticmethod
|
146 |
-
def load(input_path: str) ->
|
147 |
# Old classes used other config names than 'sentence_bert_config.json'
|
148 |
for config_name in [
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
]:
|
157 |
sbert_config_path = os.path.join(input_path, config_name)
|
158 |
if os.path.exists(sbert_config_path):
|
@@ -162,19 +162,19 @@ class Transformer(nn.Module):
|
|
162 |
config = json.load(fIn)
|
163 |
|
164 |
# Don't allow configs to set trust_remote_code
|
165 |
-
if
|
166 |
-
config[
|
167 |
-
if
|
168 |
-
config[
|
169 |
if (
|
170 |
-
|
171 |
-
and
|
172 |
):
|
173 |
-
config[
|
174 |
if (
|
175 |
-
|
176 |
-
and
|
177 |
):
|
178 |
-
config[
|
179 |
|
180 |
return Transformer(model_name_or_path=input_path, **config)
|
|
|
34 |
self.model = AutoModel.from_pretrained(
|
35 |
model_name_or_path, config=config, **model_kwargs
|
36 |
)
|
37 |
+
if max_seq_length is not None and 'model_max_length' not in tokenizer_kwargs:
|
38 |
+
tokenizer_kwargs['model_max_length'] = max_seq_length
|
39 |
|
40 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
41 |
tokenizer_name_or_path or model_name_or_path,
|
|
|
49 |
# No max_seq_length set. Try to infer from model
|
50 |
if max_seq_length is None:
|
51 |
if (
|
52 |
+
hasattr(self.model, 'config')
|
53 |
+
and hasattr(self.model.config, 'max_position_embeddings')
|
54 |
+
and hasattr(self.tokenizer, 'model_max_length')
|
55 |
):
|
56 |
max_seq_length = min(
|
57 |
self.model.config.max_position_embeddings,
|
|
|
63 |
|
64 |
@staticmethod
|
65 |
def _decode_data_image(data_image_str: str) -> Image.Image:
|
66 |
+
header, data = data_image_str.split(',', 1)
|
67 |
image_data = base64.b64decode(data)
|
68 |
return Image.open(BytesIO(image_data))
|
69 |
|
|
|
79 |
_image_or_text_descriptors = []
|
80 |
for sample in texts:
|
81 |
if isinstance(sample, str):
|
82 |
+
if sample.startswith('http'):
|
83 |
response = requests.get(sample)
|
84 |
+
_images.append(Image.open(BytesIO(response.content)).convert('RGB'))
|
85 |
_image_or_text_descriptors.append(0)
|
86 |
+
elif sample.startswith('data:image/'):
|
87 |
+
_images.append(self._decode_data_image(sample).convert('RGB'))
|
88 |
_image_or_text_descriptors.append(0)
|
89 |
else:
|
90 |
try:
|
91 |
+
_images.append(Image.open(sample).convert('RGB'))
|
92 |
_image_or_text_descriptors.append(0)
|
93 |
except Exception as e:
|
94 |
_ = str(e)
|
95 |
_texts.append(sample)
|
96 |
_image_or_text_descriptors.append(1)
|
97 |
elif isinstance(sample, Image.Image):
|
98 |
+
_images.append(sample.convert('RGB'))
|
99 |
_image_or_text_descriptors.append(0)
|
100 |
|
101 |
encoding = {}
|
102 |
if len(_texts):
|
103 |
+
encoding['input_ids'] = self.tokenizer(
|
104 |
+
_texts,
|
105 |
padding=padding,
|
106 |
+
truncation='longest_first',
|
107 |
+
return_tensors='pt',
|
108 |
max_length=self.max_seq_length,
|
109 |
).input_ids
|
110 |
|
111 |
if len(_images):
|
112 |
+
encoding['pixel_values'] = self.image_processor(
|
113 |
+
_images, return_tensors='pt'
|
114 |
).pixel_values
|
115 |
|
116 |
+
encoding['image_text_info'] = _image_or_text_descriptors
|
117 |
return encoding
|
118 |
|
119 |
def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
120 |
image_embeddings = []
|
121 |
text_embeddings = []
|
122 |
|
123 |
+
if 'pixel_values' in features:
|
124 |
+
image_embeddings = self.model.get_image_features(features['pixel_values'])
|
125 |
+
if 'input_ids' in features:
|
126 |
+
text_embeddings = self.model.get_text_features(features['input_ids'])
|
127 |
|
128 |
sentence_embedding = []
|
129 |
image_features = iter(image_embeddings)
|
130 |
text_features = iter(text_embeddings)
|
131 |
+
for _, _input_type in enumerate(features['image_text_info']):
|
132 |
if _input_type == 0:
|
133 |
sentence_embedding.append(next(image_features))
|
134 |
else:
|
135 |
sentence_embedding.append(next(text_features))
|
136 |
|
137 |
+
features['sentence_embedding'] = torch.stack(sentence_embedding).float()
|
138 |
return features
|
139 |
|
140 |
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
|
|
143 |
self.image_processor.save_pretrained(output_path)
|
144 |
|
145 |
@staticmethod
|
146 |
+
def load(input_path: str) -> 'Transformer':
|
147 |
# Old classes used other config names than 'sentence_bert_config.json'
|
148 |
for config_name in [
|
149 |
+
'sentence_bert_config.json',
|
150 |
+
'sentence_roberta_config.json',
|
151 |
+
'sentence_distilbert_config.json',
|
152 |
+
'sentence_camembert_config.json',
|
153 |
+
'sentence_albert_config.json',
|
154 |
+
'sentence_xlm-roberta_config.json',
|
155 |
+
'sentence_xlnet_config.json',
|
156 |
]:
|
157 |
sbert_config_path = os.path.join(input_path, config_name)
|
158 |
if os.path.exists(sbert_config_path):
|
|
|
162 |
config = json.load(fIn)
|
163 |
|
164 |
# Don't allow configs to set trust_remote_code
|
165 |
+
if 'config_kwargs' in config and 'trust_remote_code' in config['config_kwargs']:
|
166 |
+
config['config_kwargs'].pop('trust_remote_code')
|
167 |
+
if 'model_kwargs' in config and 'trust_remote_code' in config['model_kwargs']:
|
168 |
+
config['model_kwargs'].pop('trust_remote_code')
|
169 |
if (
|
170 |
+
'tokenizer_kwargs' in config
|
171 |
+
and 'trust_remote_code' in config['tokenizer_kwargs']
|
172 |
):
|
173 |
+
config['tokenizer_kwargs'].pop('trust_remote_code')
|
174 |
if (
|
175 |
+
'image_processor_kwargs' in config
|
176 |
+
and 'trust_remote_code' in config['image_processor_kwargs']
|
177 |
):
|
178 |
+
config['image_processor_kwargs'].pop('trust_remote_code')
|
179 |
|
180 |
return Transformer(model_name_or_path=input_path, **config)
|