|
import torch |
|
import torch.nn.functional as F |
|
from transformers import DistilBertTokenizer |
|
from tqdm.autonotebook import tqdm |
|
import pickle |
|
|
|
from clip_model import CLIPModel |
|
from configuration import CFG |
|
|
|
import matplotlib.pyplot as plt |
|
import cv2 |
|
|
|
def load_model(model_path):
    """Instantiate a CLIPModel, restore trained weights, and set eval mode.

    Args:
        model_path: path to a saved ``state_dict`` checkpoint.

    Returns:
        The CLIPModel on ``CFG.device``, ready for inference.
    """
    state_dict = torch.load(model_path, map_location=CFG.device)
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(state_dict)
    # Disable dropout / batch-norm updates for inference.
    model.eval()
    return model
|
|
|
def load_df():
    """Read and return the validation dataframe pickled at pickles/valid_df.pkl."""
    with open("pickles/valid_df.pkl", 'rb') as handle:
        return pickle.load(handle)
|
|
|
def load_image_embeddings():
    """Read and return the precomputed image embeddings pickled at
    pickles/image_embeddings.pkl."""
    with open("pickles/image_embeddings.pkl", 'rb') as handle:
        return pickle.load(handle)
|
|
|
def find_matches(model, image_embeddings, query, image_filenames, n=9):
    """Retrieve and display the top images matching a free-text query.

    Encodes the query with the DistilBERT tokenizer, projects it into the
    shared CLIP embedding space, ranks images by cosine similarity, and
    shows the best matches in a 3x3 matplotlib grid.

    Args:
        model: trained CLIPModel exposing ``text_encoder`` and
            ``text_projection``.
        image_embeddings: tensor of precomputed image embeddings whose rows
            align with ``image_filenames``.
        query: text query to search with.
        image_filenames: sequence of file names, resolved under "Images/".
        n: number of matches to display (the fixed 3x3 grid shows up to 9).
    """
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    # L2-normalize both sides so the dot product equals cosine similarity.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # Take n*5 candidates and keep every 5th — presumably the embedding table
    # holds 5 duplicate rows per image (one per caption), so this de-duplicates
    # the results; TODO confirm against how image_embeddings was built.
    _, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]

    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"Images/{match}")
        if image is None:
            # cv2.imread returns None for missing/unreadable files; skip the
            # cell instead of crashing inside cvtColor.
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)
        ax.axis("off")

    plt.show()
|
|
|
def inference(query):
    """Run a text query end-to-end: load model, embeddings, and the
    validation dataframe, then plot the top-9 matching images.

    Args:
        query: text query to search the validation images with.
    """
    model = load_model(model_path="model/best.pt")
    valid_df = load_df()
    find_matches(
        model,
        load_image_embeddings(),
        query=query,
        image_filenames=valid_df['image'].values,
        n=9,
    )
|
|