# clip_demo/clip_inferencing.py
import pickle

import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer

from clip_model import CLIPModel
from configuration import CFG

def load_model(model_path):
    # Rebuild the CLIP model and restore the trained weights on the configured device.
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()  # inference mode: disable dropout / batch-norm updates
    return model

def load_df():
    # Load the validation dataframe that maps captions to image filenames.
    with open("pickles/valid_df.pkl", 'rb') as file:
        valid_df = pickle.load(file)
    return valid_df

def load_image_embeddings():
    # Load the precomputed, projected image embeddings for the validation set.
    with open("pickles/image_embeddings.pkl", 'rb') as file:
        image_embeddings = pickle.load(file)
    return image_embeddings
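
def compute_image_embeddings(model, valid_loader):
    # Hypothetical helper: sketches how image_embeddings.pkl could have been
    # produced, assuming CLIPModel exposes image_encoder / image_projection
    # mirroring the text_encoder / text_projection path used in find_matches.
    all_embeddings = []
    with torch.no_grad():
        for batch in valid_loader:  # assumed dataloader over validation images
            features = model.image_encoder(batch["image"].to(CFG.device))
            all_embeddings.append(model.image_projection(features))
    return torch.cat(all_embeddings)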

def find_matches(model, image_embeddings, query, image_filenames, n=9):
    # Embed the text query with the same text encoder + projection head used in training.
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    # L2-normalize both modalities so the dot product equals cosine similarity.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # Each image appears five times in the dataframe (one row per caption), so
    # fetch n * 5 hits and keep every fifth index to get n distinct images.
    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]
    # Show the matches in a 3 x 3 grid (assumes n == 9).
    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"Images/{match}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
        ax.imshow(image)
        ax.axis("off")
    plt.show()

def inference():
    # Run retrieval for a sample query against the cached validation embeddings.
    valid_df = load_df()
    image_embeddings = load_image_embeddings()
    find_matches(load_model(model_path="model/best.pt"),
                 image_embeddings,
                 query="dogs on the grass",
                 image_filenames=valid_df['image'].values, n=9)
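
# Run the demo when executed as a script; assumes model/best.pt, the pickles/
# artifacts, and the Images/ directory are present locally.
if __name__ == "__main__":
    inference()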