Use altndrr/cased model
- app.py +9 -6
- artifacts/models/databases/.gitkeep +0 -0
- artifacts/models/retrieval/indices.json +0 -3
- src/nn.py +0 -186
- src/retrieval.py +0 -30
- src/transforms.py +0 -506
app.py
CHANGED
@@ -2,8 +2,8 @@ from typing import Optional
 
 import gradio as gr
 import torch
-
-from
+from PIL import Image
+from transformers import AutoModel, CLIPProcessor
 
 PAPER_TITLE = "Vocabulary-free Image Classification"
 PAPER_DESCRIPTION = """
@@ -37,14 +37,17 @@ To assign a label to an image, we:
 """
 PAPER_URL = "https://arxiv.org/abs/2306.00917"
 
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(device)
+processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
 
 def vic(filename: str, alpha: Optional[float] = None):
-
-
+    images = processor(images=[Image.open(filename)], return_tensors="pt", padding=True)
+    outputs = model(images, alpha=alpha)
+    vocabulary = outputs["vocabularies"][0]
+    scores = outputs["scores"][0]
     confidences = dict(zip(vocabulary, scores))
 
     return confidences
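Note: the hunks above only show the new imports, the module-level model loading, and the body of vic(). Below is a minimal end-to-end sketch of the resulting app; the Gradio wiring (gr.Interface, gr.Image, gr.Label) and the launch call are assumptions, since they sit outside the changed lines, while the model and processor calls mirror the diff.

# Minimal sketch of the updated app; the Gradio interface wiring is assumed, not shown in the diff.
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, CLIPProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")


def vic(filename: str, alpha: Optional[float] = None):
    # preprocess the input image and let the Hub-hosted CaSED model build and score a vocabulary
    images = processor(images=[Image.open(filename)], return_tensors="pt", padding=True)
    outputs = model(images, alpha=alpha)
    vocabulary = outputs["vocabularies"][0]
    scores = outputs["scores"][0]
    confidences = dict(zip(vocabulary, scores))
    return confidences


# Assumed wiring: an image filepath in, a label with per-candidate confidences out.
demo = gr.Interface(fn=vic, inputs=gr.Image(type="filepath"), outputs=gr.Label(num_top_classes=5))

if __name__ == "__main__":
    demo.launch()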
artifacts/models/databases/.gitkeep
DELETED
File without changes
artifacts/models/retrieval/indices.json
DELETED
@@ -1,3 +0,0 @@
-{
-    "ViT-L-14_CC12M": "./artifacts/models/databases/cc12m/vit-l-14/"
-}
src/nn.py
DELETED
@@ -1,186 +0,0 @@
-import json
-import tarfile
-from pathlib import Path
-from typing import Optional
-
-import faiss
-import gdown
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPModel, CLIPProcessor
-
-from src.retrieval import ArrowMetadataProvider
-from src.transforms import TextCompose, default_vocabulary_transforms
-
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-RETRIEVAL_DATABASES = {
-    "cc12m": "https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t",
-}
-
-
-class CaSED(torch.nn.Module):
-    """Torch module for Category Search from External Databases (CaSED).
-
-    Args:
-        index_name (str): Name of the faiss index to use.
-        vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary.
-
-    Extra hparams:
-        alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
-        artifact_dir (str): Path to the directory where the databases are stored. Defaults to
-            "artifacts/".
-        retrieval_num_results (int): Number of results to return. Defaults to 10.
-    """
-
-    def __init__(
-        self,
-        index_name: str = "ViT-L-14_CC12M",
-        vocabulary_transforms: TextCompose = default_vocabulary_transforms(),
-        **kwargs,
-    ):
-        super().__init__()
-
-        # load CLIP
-        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
-        self.index_name = index_name
-        self.vocabulary_transforms = vocabulary_transforms
-        self.vision_encoder = model.vision_model
-        self.vision_proj = model.visual_projection
-        self.language_encoder = model.text_model
-        self.language_proj = model.text_projection
-        self.logit_scale = model.logit_scale.exp()
-        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
-
-        # set hparams
-        kwargs["alpha"] = kwargs.get("alpha", 0.5)
-        kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/")
-        kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10)
-        self.hparams = kwargs
-
-        # download databases
-        self.prepare_data()
-
-        # load faiss indices and metadata providers
-        indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
-        indices_fp = indices_list_dir / "indices.json"
-        self.indices = json.load(open(indices_fp))
-        self.resources = {}
-        for name, index_fp in self.indices.items():
-            text_index_fp = Path(index_fp) / "text.index"
-            metadata_fp = Path(index_fp) / "metadata/"
-
-            text_index = faiss.read_index(
-                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
-            )
-            metadata_provider = ArrowMetadataProvider(metadata_fp)
-
-            self.resources[name] = {
-                "device": DEVICE,
-                "model": "ViT-L-14",
-                "text_index": text_index,
-                "metadata_provider": metadata_provider,
-            }
-
-    def prepare_data(self):
-        """Download data if needed."""
-        databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases"
-
-        for name, url in RETRIEVAL_DATABASES.items():
-            database_path = Path(databases_path, name)
-            if database_path.exists():
-                continue
-
-            # download data
-            target_path = Path(databases_path, name + ".tar.gz")
-            try:
-                gdown.download(url, str(target_path), quiet=False)
-                tar = tarfile.open(target_path, "r:gz")
-                tar.extractall(target_path.parent)
-                tar.close()
-                target_path.unlink()
-            except FileNotFoundError:
-                print(f"Could not download {url}.")
-                print(f"Please download it manually and place it in {target_path.parent}.")
-
-    @torch.no_grad()
-    def query_index(self, sample_z: torch.Tensor) -> torch.Tensor:
-        # get the index
-        resources = self.resources[self.index_name]
-        text_index = resources["text_index"]
-        metadata_provider = resources["metadata_provider"]
-
-        # query the index
-        sample_z = sample_z.squeeze(0)
-        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
-        query_input = sample_z.cpu().detach().numpy().tolist()
-        query = np.expand_dims(np.array(query_input).astype("float32"), 0)
-
-        distances, idxs, _ = text_index.search_and_reconstruct(
-            query, self.hparams["retrieval_num_results"]
-        )
-        results = idxs[0]
-        nb_results = np.where(results == -1)[0]
-        nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
-        indices = results[:nb_results]
-        distances = distances[0][:nb_results]
-
-        if len(distances) == 0:
-            return []
-
-        # get the metadata
-        results = []
-        metadata = metadata_provider.get(indices[:20], ["caption"])
-        for key, (d, i) in enumerate(zip(distances, indices)):
-            output = {}
-            meta = None if key + 1 > len(metadata) else metadata[key]
-            if meta is not None:
-                output.update(meta)
-            output["id"] = i.item()
-            output["similarity"] = d.item()
-            results.append(output)
-
-        # get the captions only
-        vocabularies = [result["caption"] for result in results]
-
-        return vocabularies
-
-    @torch.no_grad()
-    def forward(self, image_fp: str, alpha: Optional[float] = None) -> torch.Tensor():
-        # forward the image
-        image = self.processor(images=Image.open(image_fp), return_tensors="pt")
-        image["pixel_values"] = image["pixel_values"].to(DEVICE)
-        image_z = self.vision_proj(self.vision_encoder(**image)[1])
-
-        # generate a single text embedding from the unfiltered vocabulary
-        vocabulary = self.query_index(image_z)
-        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
-        text["input_ids"] = text["input_ids"][:, :77].to(DEVICE)
-        text["attention_mask"] = text["attention_mask"][:, :77].to(DEVICE)
-        text_z = self.language_encoder(**text)[1]
-        text_z = self.language_proj(text_z)
-
-        # filter the vocabulary, embed it, and get its mean embedding
-        vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
-        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
-        text = {k: v.to(DEVICE) for k, v in text.items()}
-        vocabulary_z = self.language_encoder(**text)[1]
-        vocabulary_z = self.language_proj(vocabulary_z)
-        vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)
-
-        # get the image and text predictions
-        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
-        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
-        image_p = (torch.matmul(image_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
-        text_p = (torch.matmul(text_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
-
-        # average the image and text predictions
-        alpha = alpha or self.hparams["alpha"]
-        sample_p = alpha * image_p + (1 - alpha) * text_p
-
-        # get the scores
-        scores = sample_p[0].cpu().tolist()
-
-        return vocabulary, scores
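For reference, the deleted module above was the local inference path that this commit replaces with the altndrr/cased Hub model. A rough sketch of how it was driven, inferred from the removed code (the image path is a placeholder; the old app.py call site is truncated in the hunk above):

# Hypothetical usage of the removed src/nn.py module, inferred from CaSED.forward().
from src.nn import CaSED

cased = CaSED()  # downloads and unpacks the cc12m retrieval database on first use
vocabulary, scores = cased("example.jpg", alpha=0.5)  # "example.jpg" is a placeholder path
confidences = dict(zip(vocabulary, scores))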
src/retrieval.py
DELETED
@@ -1,30 +0,0 @@
-from pathlib import Path
-from typing import Optional
-
-import numpy as np
-import pyarrow as pa
-
-
-class ArrowMetadataProvider:
-    """The arrow metadata provider provides metadata from contiguous ids using arrow.
-
-    Code taken from: https://github.dev/rom1504/clip-retrieval
-    """
-
-    def __init__(self, arrow_folder: Path):
-        arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
-        self.table = pa.concat_tables(
-            [
-                pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
-                for arrow_file in arrow_files
-            ]
-        )
-
-    def get(self, ids: np.ndarray, cols: Optional[list] = None):
-        """Implement the get method from the arrow metadata provide, get metadata from ids."""
-        if cols is None:
-            cols = self.table.schema.names
-        else:
-            cols = list(set(self.table.schema.names) & set(cols))
-        t = pa.concat_tables([self.table[i:j] for i, j in zip(ids, ids + 1)])
-        return t.select(cols).to_pandas().to_dict("records")
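As context, query_index() in the deleted src/nn.py used this provider to map faiss result ids back to captions. A small sketch under the old artifact layout from indices.json (the ids below are illustrative; in practice they came from the faiss search):

# Illustrative lookup; paths follow the deleted indices.json entry for ViT-L-14_CC12M.
from pathlib import Path

import numpy as np

from src.retrieval import ArrowMetadataProvider

metadata_dir = Path("./artifacts/models/databases/cc12m/vit-l-14/metadata/")
provider = ArrowMetadataProvider(metadata_dir)
rows = provider.get(np.array([0, 1, 2]), ["caption"])  # -> [{"caption": "..."}, ...]
captions = [row["caption"] for row in rows]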
src/transforms.py
DELETED
@@ -1,506 +0,0 @@
-import re
-from abc import ABC, abstractmethod
-from typing import Any, Optional, Union, cast
-
-import inflect
-import nltk
-import numpy as np
-import PIL.Image
-import torch
-import torchvision.transforms as T
-import torchvision.transforms.functional as F
-from flair.data import Sentence
-from flair.models import SequenceTagger
-
-__all__ = [
-    "DynamicResize",
-    "DropFileExtensions",
-    "DropNonAlpha",
-    "DropShortWords",
-    "DropSpecialCharacters",
-    "DropTokens",
-    "DropURLs",
-    "DropWords",
-    "FilterPOS",
-    "FrequencyMinWordCount",
-    "FrequencyTopK",
-    "ReplaceSeparators",
-    "ToRGBTensor",
-    "ToLowercase",
-    "ToSingular",
-]
-
-
-class BaseTextTransform(ABC):
-    """Base class for string transforms."""
-
-    @abstractmethod
-    def __call__(self, text: str):
-        raise NotImplementedError
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-
-class DynamicResize(T.Resize):
-    """Resize the input PIL Image to the given size.
-
-    Extends the torchvision Resize transform to dynamically evaluate the second dimension of the
-    output size based on the aspect ratio of the first input image.
-    """
-
-    def forward(self, img):
-        if isinstance(self.size, int):
-            _, h, w = F.get_dimensions(img)
-            aspect_ratio = w / h
-            side = self.size
-
-            if aspect_ratio < 1.0:
-                self.size = int(side / aspect_ratio), side
-            else:
-                self.size = side, int(side * aspect_ratio)
-
-        return super().forward(img)
-
-
-class DropFileExtensions(BaseTextTransform):
-    """Remove file extensions from the input text."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove file extensions from.
-        """
-        text = re.sub(r"\.\w+", "", text)
-
-        return text
-
-
-class DropNonAlpha(BaseTextTransform):
-    """Remove non-alpha words from the input text."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove non-alpha words from.
-        """
-        text = re.sub(r"[^a-zA-Z\s]", "", text)
-
-        return text
-
-
-class DropShortWords(BaseTextTransform):
-    """Remove short words from the input text.
-
-    Args:
-        min_length (int): Minimum length of words to keep.
-    """
-
-    def __init__(self, min_length) -> None:
-        super().__init__()
-        self.min_length = min_length
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove short words from.
-        """
-        text = " ".join([word for word in text.split() if len(word) >= self.min_length])
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(min_length={self.min_length})"
-
-
-class DropSpecialCharacters(BaseTextTransform):
-    """Remove special characters from the input text.
-
-    Special characters are defined as any character that is not a word character, whitespace,
-    hyphen, period, apostrophe, or ampersand.
-    """
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove special characters from.
-        """
-        text = re.sub(r"[^\w\s\-\.\'\&]", "", text)
-
-        return text
-
-
-class DropTokens(BaseTextTransform):
-    """Remove tokens from the input text.
-
-    Tokens are defined as strings enclosed in angle brackets, e.g. <token>.
-    """
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove tokens from.
-        """
-        text = re.sub(r"<[^>]+>", "", text)
-
-        return text
-
-
-class DropURLs(BaseTextTransform):
-    """Remove URLs from the input text."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove URLs from.
-        """
-        text = re.sub(r"http\S+", "", text)
-
-        return text
-
-
-class DropWords(BaseTextTransform):
-    """Remove words from the input text.
-
-    It is case-insensitive and supports singular and plural forms of the words.
-    """
-
-    def __init__(self, words: list[str]) -> None:
-        super().__init__()
-        self.words = words
-        self.pattern = r"\b(?:{})\b".format("|".join(words))
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove words from.
-        """
-        text = re.sub(self.pattern, "", text, flags=re.IGNORECASE)
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(pattern={self.pattern})"
-
-
-class FilterPOS(BaseTextTransform):
-    """Filter words by POS tags.
-
-    Args:
-        tags (list): List of POS tags to remove.
-        engine (str): POS tagger to use. Must be one of "nltk" or "flair". Defaults to "nltk".
-        keep_compound_nouns (bool): Whether to keep composed words. Defaults to True.
-    """
-
-    def __init__(self, tags: list, engine: str = "nltk", keep_compound_nouns: bool = True) -> None:
-        super().__init__()
-        self.tags = tags
-        self.engine = engine
-        self.keep_compound_nouns = keep_compound_nouns
-
-        if engine == "nltk":
-            nltk.download("averaged_perceptron_tagger", quiet=True)
-            nltk.download("punkt", quiet=True)
-            self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
-        elif engine == "flair":
-            self.tagger = SequenceTagger.load("flair/pos-english-fast").predict
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove words with specific POS tags from.
-        """
-        if self.engine == "nltk":
-            word_tags = self.tagger(text)
-            text = " ".join([word for word, tag in word_tags if tag not in self.tags])
-        elif self.engine == "flair":
-            sentence = Sentence(text)
-            self.tagger(sentence)
-            text = " ".join([token.text for token in sentence.tokens if token.tag in self.tags])
-
-        if self.keep_compound_nouns:
-            compound_nouns = []
-
-            if self.engine == "nltk":
-                for i in range(len(word_tags) - 1):
-                    if word_tags[i][1] == "NN" and word_tags[i + 1][1] == "NN":
-                        # if they are the same word, skip
-                        if word_tags[i][0] == word_tags[i + 1][0]:
-                            continue
-
-                        compound_noun = word_tags[i][0] + "_" + word_tags[i + 1][0]
-                        compound_nouns.append(compound_noun)
-            elif self.engine == "flair":
-                for i in range(len(sentence.tokens) - 1):
-                    if sentence.tokens[i].tag == "NN" and sentence.tokens[i + 1].tag == "NN":
-                        # if they are the same word, skip
-                        if sentence.tokens[i].text == sentence.tokens[i + 1].text:
-                            continue
-
-                        compound_noun = sentence.tokens[i].text + "_" + sentence.tokens[i + 1].text
-                        compound_nouns.append(compound_noun)
-
-            text = " ".join([text, " ".join(compound_nouns)])
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(tags={self.tags}, engine={self.engine})"
-
-
-class FrequencyMinWordCount(BaseTextTransform):
-    """Keep only words that occur more than a minimum number of times in the input text.
-
-    If the threshold is too strong and no words pass the threshold, the threshold is reduced to
-    the most frequent word.
-
-    Args:
-        min_count (int): Minimum number of occurrences of a word to keep.
-    """
-
-    def __init__(self, min_count) -> None:
-        super().__init__()
-        self.min_count = min_count
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove infrequent words from.
-        """
-        if self.min_count <= 1:
-            return text
-
-        words = text.split()
-        word_counts = {word: words.count(word) for word in words}
-
-        # if nothing passes the threshold, reduce the threshold to the most frequent word
-        max_word_count = max(word_counts.values() or [0])
-        min_count = max_word_count if self.min_count > max_word_count else self.min_count
-
-        text = " ".join([word for word in words if word_counts[word] >= min_count])
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(min_count={self.min_count})"
-
-
-class FrequencyTopK(BaseTextTransform):
-    """Keep only the top k most frequent words in the input text.
-
-    In case of a tie, all words with the same count as the last word are kept.
-
-    Args:
-        top_k (int): Number of top words to keep.
-    """
-
-    def __init__(self, top_k: int) -> None:
-        super().__init__()
-        self.top_k = top_k
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove infrequent words from.
-        """
-        if self.top_k < 1:
-            return text
-
-        words = text.split()
-        word_counts = {word: words.count(word) for word in words}
-        top_words = sorted(word_counts, key=word_counts.get, reverse=True)
-
-        # in case of a tie, keep all words with the same count
-        top_words = top_words[: self.top_k]
-        top_words = [word for word in top_words if word_counts[word] == word_counts[top_words[-1]]]
-
-        text = " ".join([word for word in words if word in top_words])
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(top_k={self.top_k})"
-
-
-class ReplaceSeparators(BaseTextTransform):
-    """Replace underscores and dashes with spaces."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to replace separators in.
-        """
-        text = re.sub(r"[_\-]", " ", text)
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-
-class RemoveDuplicates(BaseTextTransform):
-    """Remove duplicate words from the input text."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove duplicate words from.
-        """
-        text = " ".join(list(set(text.split())))
-
-        return text
-
-
-class TextCompose:
-    """Compose several transforms together.
-
-    It differs from the torchvision.transforms.Compose class in that it applies the transforms to
-    a string instead of a PIL Image or Tensor. In addition, it automatically join the list of
-    input strings into a single string and splits the output string into a list of words.
-
-    Args:
-        transforms (list): List of transforms to compose.
-    """
-
-    def __init__(self, transforms: list[BaseTextTransform]) -> None:
-        self.transforms = transforms
-
-    def __call__(self, text: Union[str, list[str]]) -> Any:
-        if isinstance(text, list):
-            text = " ".join(text)
-
-        for t in self.transforms:
-            text = t(text)
-        return text.split()
-
-    def __repr__(self) -> str:
-        format_string = self.__class__.__name__ + "("
-        for t in self.transforms:
-            format_string += "\n"
-            format_string += f"    {t}"
-        format_string += "\n)"
-        return format_string
-
-
-class ToRGBTensor(T.ToTensor):
-    """Convert a `PIL Image` or `numpy.ndarray` to tensor.
-
-    Compared with the torchvision `ToTensor` transform, it converts images with a single channel to
-    RGB images. In addition, the conversion to tensor is done only if the input is not already a
-    tensor.
-    """
-
-    def __call__(self, pic: Union[PIL.Image.Image, np.ndarray, torch.Tensor]):
-        """
-        Args:
-            pic (PIL Image | numpy.ndarray | torch.Tensor): Image to be converted to tensor.
-        """
-        img = pic if isinstance(pic, torch.Tensor) else F.to_tensor(pic)
-        img = cast(torch.Tensor, img)
-
-        if img.shape[0] == 1:
-            img = img.repeat(3, 1, 1)
-
-        return img
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-
-class ToLowercase(BaseTextTransform):
-    """Convert text to lowercase."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to convert to lowercase.
-        """
-        text = text.lower()
-
-        return text
-
-
-class ToSingular(BaseTextTransform):
-    """Convert plural words to singular form."""
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.transform = inflect.engine().singular_noun
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to convert to singular form.
-        """
-        words = text.split()
-        for i, word in enumerate(words):
-            if not word.endswith("s"):
-                continue
-
-            if word[-2:] in ["ss", "us", "is"]:
-                continue
-
-            if word[-3:] in ["ies", "oes"]:
-                continue
-
-            words[i] = self.transform(word) or word
-
-        text = " ".join(words)
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-
-def default_preprocess(size: Optional[int] = None) -> T.Compose:
-    """Preprocess input images with preprocessing transforms.
-
-    Args:
-        size (int): Size to resize image to.
-    """
-    transforms = []
-    if size is not None:
-        transforms.append(DynamicResize(size, interpolation=T.InterpolationMode.BICUBIC))
-    transforms.append(ToRGBTensor())
-    transforms = T.Compose(transforms)
-
-    return transforms
-
-
-def default_vocabulary_transforms() -> TextCompose:
-    """Preprocess input text with preprocessing transforms."""
-    words_to_drop = [
-        "image",
-        "photo",
-        "picture",
-        "thumbnail",
-        "logo",
-        "symbol",
-        "clipart",
-        "portrait",
-        "painting",
-        "illustration",
-        "icon",
-        "profile",
-    ]
-    pos_tags = ["NN", "NNS", "NNP", "NNPS", "JJ", "JJR", "JJS", "VBG", "VBN"]
-
-    transforms = []
-    transforms.append(DropTokens())
-    transforms.append(DropURLs())
-    transforms.append(DropSpecialCharacters())
-    transforms.append(DropFileExtensions())
-    transforms.append(ReplaceSeparators())
-    transforms.append(DropShortWords(min_length=3))
-    transforms.append(DropNonAlpha())
-    transforms.append(ToLowercase())
-    transforms.append(ToSingular())
-    transforms.append(DropWords(words=words_to_drop))
-    transforms.append(FrequencyMinWordCount(min_count=2))
-    transforms.append(FilterPOS(tags=pos_tags, engine="flair", keep_compound_nouns=False))
-    transforms.append(RemoveDuplicates())
-
-    transforms = TextCompose(transforms)
-
-    return transforms
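As context, the deleted pipeline above was what turned raw retrieved captions into candidate class names before this commit; with this change that step is handled by the altndrr/cased remote code, which returns the vocabulary directly. A brief sketch of how the removed pipeline was applied (the captions are made up for illustration):

# Illustrative use of the removed vocabulary pipeline; example captions are invented.
from src.transforms import default_vocabulary_transforms

vocabulary_transforms = default_vocabulary_transforms()
captions = [
    "a photo of a golden retriever puppy",
    "golden retriever dogs playing outside",
]
candidates = vocabulary_transforms(captions)  # roughly ["golden", "retriever"] after filtering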