gowitheflow committed · Commit 42f8a05 · Parent(s): 2bff784

Update README.md
README.md CHANGED
All the training sets involved in our progressive training scheme that we created can be found in the tags in the metadata. Please refer to the paper for the exact process.

## Inference

First install the package following our GitHub repo. Then define our `PixelLinguist` class as follows.
```python
import torch
from PIL import Image
from pixel import (
    AutoConfig,
    PangoCairoTextRenderer,
    PIXELForSequenceClassification,
    PIXELForRepresentation,
    PoolingMode,
    get_attention_mask,
    get_transforms,
    glue_strip_spaces,
    resize_model_embeddings,
)
from tqdm import tqdm

class PixelLinguist:
    def __init__(self, model_name, batch_size=16, max_seq_length=64,
                 device=None, keep_mlp=False):
        if device is not None:
            self.device = device
        else:
            self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.config = AutoConfig.from_pretrained(model_name, num_labels=0)
        self.batch_size = batch_size
        if keep_mlp:
            # Keep the MLP head on top of the encoder.
            self.model = PIXELForSequenceClassification.from_pretrained(
                model_name,
                config=self.config,
                pooling_mode=PoolingMode.from_string("mean"),
                add_layer_norm=True,
            ).to(self.device)
        else:
            # Default: representation model returning pooled embeddings.
            self.model = PIXELForRepresentation.from_pretrained(
                model_name,
                config=self.config,
                pooling_mode=PoolingMode.from_string("mean"),
                add_layer_norm=True,
            ).to(self.device)
        self.processor = PangoCairoTextRenderer.from_pretrained(model_name, rgb=False)
        self.processor.max_seq_length = max_seq_length
        resize_model_embeddings(self.model, self.processor.max_seq_length)
        self.transforms = get_transforms(
            do_resize=True,
            size=(self.processor.pixels_per_patch,
                  self.processor.pixels_per_patch * self.processor.max_seq_length),
        )

    def preprocess(self, texts):
        # Render each text to an image, then build pixel and attention-mask tensors.
        encodings = [self.processor(text=glue_strip_spaces(a)) for a in texts]
        pixel_values = torch.stack(
            [self.transforms(Image.fromarray(e.pixel_values)) for e in encodings]
        )
        attention_mask = torch.stack(
            [get_attention_mask(e.num_text_patches, seq_length=self.processor.max_seq_length)
             for e in encodings]
        )
        return {"pixel_values": pixel_values, "attention_mask": attention_mask}

    def encode(self, texts, **kwargs):
        all_outputs = []
        for i in tqdm(range(0, len(texts), self.batch_size)):
            batch_texts = texts[i:i + self.batch_size]
            inputs = self.preprocess(batch_texts)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs).logits.detach().cpu()
            all_outputs.append(outputs)
        return torch.cat(all_outputs, dim=0)
```
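Here `keep_mlp=True` loads the checkpoint through `PIXELForSequenceClassification`, keeping the MLP head on top of the encoder, while the default path uses `PIXELForRepresentation`, whose `.logits` output is the mean-pooled, layer-normalized sentence embedding used below.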
A minimal example to do inference and similarity computation:

```python
model = PixelLinguist(model_name)  # model_name: local path or hub id of this checkpoint
texts = ["I love you", "I like you"]
embeddings = model.encode(texts)
print(embeddings[0] @ embeddings[1].T)  # a dot product suffices: the model class normalizes embeddings automatically
# tensor(0.9217)
```
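Since `encode` returns normalized embeddings, the same dot-product comparison extends naturally to small-scale retrieval. A minimal sketch (the query and candidate texts below are illustrative placeholders):

```python
queries = ["I love you"]
candidates = ["I like you", "I hate you", "The weather is nice today"]

q_emb = model.encode(queries)      # shape (1, dim); already normalized
c_emb = model.encode(candidates)   # shape (3, dim)

scores = q_emb @ c_emb.T           # dot products equal cosine similarities here
best = scores[0].argmax().item()
print(candidates[best])            # the candidate most similar to the query
```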