|
import os |
|
import torch |
|
from transformers import CLIPProcessor, CLIPModel |
|
import cv2 |
|
from PIL import Image |
|
import numpy as np |
|
|
|
|
|
|
|
class CLIPExtractor:
    """Wrap a pretrained CLIP model to produce image and text embeddings.

    Loads the model once in ``__init__`` and exposes three extraction
    entry points: from an in-memory OpenCV frame, from an image file,
    and from a text string. All of them return a 1-D numpy vector.
    """

    def __init__(self, model_name="openai/clip-vit-large-patch14", cache_dir=None):
        """Load the CLIP model/processor and move the model to GPU if available.

        Args:
            model_name: Hugging Face identifier of the CLIP checkpoint.
            cache_dir: Directory for downloaded weights. When omitted,
                "models" is used, falling back to "../models" if only the
                parent-level directory exists (e.g. when run from a
                subdirectory of the project).
        """
        if not cache_dir:
            cache_dir = "models"
            # Fall back to a sibling cache directory when running from a
            # subdirectory and no local "models" directory exists yet.
            if not os.path.exists(cache_dir) and os.path.exists("../models"):
                cache_dir = "../models"

        self.model = CLIPModel.from_pretrained(model_name, cache_dir=cache_dir)
        self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=cache_dir)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def _encode_images(self, images):
        """Run a batch of PIL images through CLIP's image encoder.

        Shared by `extract_image` and `extract_image_from_file`, which
        previously duplicated this preprocess -> forward -> detach pipeline.

        Args:
            images: List of RGB ``PIL.Image`` instances.

        Returns:
            ``numpy.ndarray`` of shape (len(images), embed_dim).
        """
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.get_image_features(**inputs)
        return outputs.cpu().numpy()

    def extract_image(self, frame):
        """Embed a single OpenCV frame.

        Args:
            frame: ``numpy.ndarray`` image in BGR channel order (as
                produced by ``cv2.imread`` / ``cv2.VideoCapture``).

        Returns:
            1-D ``numpy.ndarray`` image embedding.
        """
        # OpenCV delivers BGR; CLIP's processor expects RGB.
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return self._encode_images([image])[0]

    def extract_image_from_file(self, file_name):
        """Embed an image loaded from disk.

        Args:
            file_name: Path to the image file.

        Returns:
            1-D ``numpy.ndarray`` image embedding.

        Raises:
            FileNotFoundError: If ``file_name`` does not exist.
        """
        if not os.path.exists(file_name):
            raise FileNotFoundError(f"File {file_name} not found.")
        image = Image.open(file_name).convert("RGB")
        return self._encode_images([image])[0]

    def extract_text(self, text):
        """Embed a text string.

        Args:
            text: Non-empty string to encode.

        Returns:
            1-D ``numpy.ndarray`` text embedding.

        Raises:
            ValueError: If ``text`` is not a non-empty string.
        """
        if not isinstance(text, str) or not text:
            raise ValueError("Input text should be a non-empty string.")
        # CLIP's text encoder has a hard 77-token context window; truncate
        # longer inputs rather than erroring out.
        inputs = self.processor.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=77
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.get_text_features(**inputs)
        return outputs.cpu().numpy()[0]
|
|
|
|
|
if __name__ == "__main__":
    # Demo: embed a sample image and a caption, then report how similar
    # the two embeddings are.
    extractor = CLIPExtractor()

    img_vec = extractor.extract_image_from_file("images/狐狸.jpg")
    txt_vec = extractor.extract_text("A photo of fox")

    # Cosine similarity = dot product of the vectors scaled by their norms.
    denom = np.linalg.norm(img_vec) * np.linalg.norm(txt_vec)
    cosine_similarity = np.dot(img_vec, txt_vec) / denom
    print(cosine_similarity)