correct paths to raw images

app.py CHANGED
@@ -22,8 +22,6 @@ from typing import Callable, Dict, List, Tuple
 from PIL.Image import Image
 
 print(__file__)
-import fashion_aggregator.fashion_aggregator as fa
-
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""  # do not use GPU
 
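This hunk drops the fashion_aggregator import together with its trailing blank line; the backend now constructs the Retriever defined later in this same file (see the final hunk). If the optional dependency were ever needed again, a guarded import is the usual pattern. A sketch only, not part of the commit:

try:
    import fashion_aggregator.fashion_aggregator as fa  # optional, may be absent in the Space
except ImportError:
    fa = None  # app.py defines its own TextEncoder, ImageEnoder, and Retriever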
@@ -41,21 +39,21 @@ EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "embeddings.pkl")
 RAW_PHOTOS_DIR = "artifacts/raw-photos"
 
 # Download image embeddings and raw photos
-wandb.login(key=os.getenv('wandb'))
-api = wandb.Api()
-artifact_embeddings = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
-artifact_embeddings.download(EMBEDDINGS_DIR)
-artifact_raw_photos = api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1")
-artifact_raw_photos.download("artifacts")
+# wandb.login(key="4b5a23a662b20fdd61f2aeb5032cf56fdce278a4")  # os.getenv('wandb')
+# api = wandb.Api()
+# artifact_embeddings = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
+# artifact_embeddings.download(EMBEDDINGS_DIR)
+# artifact_raw_photos = api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1")
+# artifact_raw_photos.download("artifacts")
 
-with zipfile.ZipFile("artifacts/unimoda.zip", 'r') as zip_ref:
-    zip_ref.extractall(RAW_PHOTOS_DIR)
+# with zipfile.ZipFile("artifacts/unimoda.zip", 'r') as zip_ref:
+#     zip_ref.extractall(RAW_PHOTOS_DIR)
 
 
 class TextEncoder:
     """Encodes the given text"""
 
-    def __init__(self, model_path=
+    def __init__(self, model_path="M-CLIP/XLM-Roberta-Large-Vit-B-32"):
         self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
 
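The commit stops downloading the W&B artifacts at startup and assumes they are already on disk. If the download path were kept, one option is to gate it on the presence of a key rather than hardcoding or commenting one. A sketch only, reusing the artifact names and the EMBEDDINGS_DIR/RAW_PHOTOS_DIR constants from this file; WANDB_API_KEY is an assumed convention, not something this commit uses:

import os
import zipfile

import wandb

# Only log in and pull artifacts when a key is actually configured.
if os.getenv("WANDB_API_KEY"):
    wandb.login(key=os.getenv("WANDB_API_KEY"))
    api = wandb.Api()
    # Same artifact names as in the code above.
    api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1").download(EMBEDDINGS_DIR)
    api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1").download("artifacts")
    with zipfile.ZipFile("artifacts/unimoda.zip", "r") as zip_ref:
        zip_ref.extractall(RAW_PHOTOS_DIR)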
@@ -69,11 +67,11 @@ class TextEncoder:
 class ImageEnoder:
     """Encodes the given image"""
 
-    def __init__(self, model_path=
+    def __init__(self, model_path="clip-ViT-B-32"):
         self.model = SentenceTransformer(model_path)
 
     @torch.no_grad()
-    def encode(self, image: Image
+    def encode(self, image: Image) -> torch.Tensor:
         """Predict/infer text embedding for a given query."""
         image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
         return image_emb
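For orientation, a minimal usage sketch of the ImageEnoder class above (class name spelled as in the source): SentenceTransformer's CLIP model accepts PIL images directly, so encode returns a 1 x d tensor for a single image. The file path here is hypothetical:

from PIL import Image as PILImage

encoder = ImageEnoder()
img = PILImage.open("artifacts/raw-photos/example.jpg")  # hypothetical file
emb = encoder.encode(img)
print(emb.shape)  # torch.Size([1, 512]) for clip-ViT-B-32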
@@ -81,24 +79,28 @@ class ImageEnoder:
 
 class Retriever:
     """Retrieves relevant images for a given text embedding."""
-
+
     def __init__(self, image_embeddings_path=None):
         self.text_encoder = TextEncoder()
         self.image_encoder = ImageEnoder()
 
-        with open(image_embeddings_path,
-            self.image_names, self.image_embeddings = pickle.load(file)
+        with open(image_embeddings_path, "rb") as file:
+            self.image_names, self.image_embeddings = pickle.load(file)
+        self.image_names = [
+            img_name.replace("fashion-aggregator/fashion_aggregator/data/photos/", "")
+            for img_name in self.image_names
+        ]
         print("Images:", len(self.image_names))
 
     @torch.no_grad()
-    def predict(self, text_query: str, k: int=10) -> List[Any]:
+    def predict(self, text_query: str, k: int = 10) -> List[Any]:
         """Return top-k relevant items for a given embedding"""
         query_emb = self.text_encoder.encode(text_query)
         relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
         return relevant_images
 
     @torch.no_grad()
-    def search_images(self, text_query: str, k: int=6) -> Dict[str, List[Any]]:
+    def search_images(self, text_query: str, k: int = 6) -> Dict[str, List[Any]]:
         """Return top-k relevant images for a given embedding"""
         images = self.predict(text_query, k)
         paths_and_scores = {"path": [], "score": []}
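The added list comprehension is the fix the commit title refers to: the pickled image names carry a repo-local prefix that does not exist inside the Space, so it is stripped before the names are resolved under RAW_PHOTOS_DIR. A quick illustration with a made-up file name:

name = "fashion-aggregator/fashion_aggregator/data/photos/12345.jpg"  # made-up example
print(name.replace("fashion-aggregator/fashion_aggregator/data/photos/", ""))
# -> "12345.jpg"

Note also that util.semantic_search returns, per query, a list of dicts with "corpus_id" and "score" keys, which is exactly what predict passes on to search_images.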
@@ -155,7 +157,7 @@ class PredictorBackend:
             self.url = url
             self._predict = self._predict_from_endpoint
         else:
-            model =
+            model = Retriever(image_embeddings_path=EMBEDDINGS_FILE)
             self._predict = model.predict
             self._search_images = model.search_images
 
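The fixed line wires up the local fallback: when no URL is given, PredictorBackend serves predictions from an in-process Retriever built over the local embeddings file. A rough usage sketch, assuming __init__ takes an optional url as the surrounding lines suggest (the rest of the class is outside this diff):

# Local mode: no URL, so the backend builds Retriever(image_embeddings_path=EMBEDDINGS_FILE)
# and serves predictions in-process.
backend = PredictorBackend(url=None)  # url parameter assumed from the diff context
# With a URL, it would instead route queries through _predict_from_endpoint.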