ryparmar committed on
Commit
de9d997
·
1 Parent(s): 973254d

correct paths to raw images

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -22,8 +22,6 @@ from typing import Callable, Dict, List, Tuple
22
  from PIL.Image import Image
23
 
24
  print(__file__)
25
- import fashion_aggregator.fashion_aggregator as fa
26
-
27
 
28
  os.environ["CUDA_VISIBLE_DEVICES"] = "" # do not use GPU
29
 
@@ -41,21 +39,21 @@ EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "embeddings.pkl")
41
  RAW_PHOTOS_DIR = "artifacts/raw-photos"
42
 
43
  # Download image embeddings and raw photos
44
- wandb.login(key=os.getenv('wandb'))
45
- api = wandb.Api()
46
- artifact_embeddings = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
47
- artifact_embeddings.download(EMBEDDINGS_DIR)
48
- artifact_raw_photos = api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1")
49
- artifact_raw_photos.download("artifacts")
50
 
51
- with zipfile.ZipFile("artifacts/unimoda.zip", 'r') as zip_ref:
52
- zip_ref.extractall(RAW_PHOTOS_DIR)
53
 
54
 
55
  class TextEncoder:
56
  """Encodes the given text"""
57
 
58
- def __init__(self, model_path='M-CLIP/XLM-Roberta-Large-Vit-B-32'):
59
  self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
60
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
61
 
@@ -69,11 +67,11 @@ class TextEncoder:
69
  class ImageEnoder:
70
  """Encodes the given image"""
71
 
72
- def __init__(self, model_path='clip-ViT-B-32'):
73
  self.model = SentenceTransformer(model_path)
74
 
75
  @torch.no_grad()
76
- def encode(self, image: Image.Image) -> torch.Tensor:
77
  """Predict/infer image embedding for a given image."""
78
  image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
79
  return image_emb
@@ -81,24 +79,28 @@ class ImageEnoder:
81
 
82
  class Retriever:
83
  """Retrieves relevant images for a given text embedding."""
84
-
85
  def __init__(self, image_embeddings_path=None):
86
  self.text_encoder = TextEncoder()
87
  self.image_encoder = ImageEnoder()
88
 
89
- with open(image_embeddings_path, 'rb') as file:
90
- self.image_names, self.image_embeddings = pickle.load(file)
 
 
 
 
91
  print("Images:", len(self.image_names))
92
 
93
  @torch.no_grad()
94
- def predict(self, text_query: str, k: int=10) -> List[Any]:
95
  """Return top-k relevant items for a given embedding"""
96
  query_emb = self.text_encoder.encode(text_query)
97
  relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
98
  return relevant_images
99
 
100
  @torch.no_grad()
101
- def search_images(self, text_query: str, k: int=6) -> Dict[str, List[Any]]:
102
  """Return top-k relevant images for a given embedding"""
103
  images = self.predict(text_query, k)
104
  paths_and_scores = {"path": [], "score": []}
@@ -155,7 +157,7 @@ class PredictorBackend:
155
  self.url = url
156
  self._predict = self._predict_from_endpoint
157
  else:
158
- model = fa.Retriever()
159
  self._predict = model.predict
160
  self._search_images = model.search_images
161
 
 
22
  from PIL.Image import Image
23
 
24
  print(__file__)
 
 
25
 
26
  os.environ["CUDA_VISIBLE_DEVICES"] = "" # do not use GPU
27
 
 
39
  RAW_PHOTOS_DIR = "artifacts/raw-photos"
40
 
41
  # Download image embeddings and raw photos
42
+ # wandb.login(key=os.getenv("wandb"))  # read the API key from the environment; never hard-code it
43
+ # api = wandb.Api()
44
+ # artifact_embeddings = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
45
+ # artifact_embeddings.download(EMBEDDINGS_DIR)
46
+ # artifact_raw_photos = api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1")
47
+ # artifact_raw_photos.download("artifacts")
48
 
49
+ # with zipfile.ZipFile("artifacts/unimoda.zip", 'r') as zip_ref:
50
+ # zip_ref.extractall(RAW_PHOTOS_DIR)
51
 
52
 
53
  class TextEncoder:
54
  """Encodes the given text"""
55
 
56
+ def __init__(self, model_path="M-CLIP/XLM-Roberta-Large-Vit-B-32"):
57
  self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
58
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
59
 
 
67
  class ImageEnoder:
68
  """Encodes the given image"""
69
 
70
+ def __init__(self, model_path="clip-ViT-B-32"):
71
  self.model = SentenceTransformer(model_path)
72
 
73
  @torch.no_grad()
74
+ def encode(self, image: Image) -> torch.Tensor:
75
  """Predict/infer image embedding for a given image."""
76
  image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
77
  return image_emb
 
79
 
80
  class Retriever:
81
  """Retrieves relevant images for a given text embedding."""
82
+
83
  def __init__(self, image_embeddings_path=None):
84
  self.text_encoder = TextEncoder()
85
  self.image_encoder = ImageEnoder()
86
 
87
+ with open(image_embeddings_path, "rb") as file:
88
+ self.image_names, self.image_embeddings = pickle.load(file)
89
+ self.image_names = [
90
+ img_name.replace("fashion-aggregator/fashion_aggregator/data/photos/", "")
91
+ for img_name in self.image_names
92
+ ]
93
  print("Images:", len(self.image_names))
94
 
95
  @torch.no_grad()
96
+ def predict(self, text_query: str, k: int = 10) -> List[Any]:
97
  """Return top-k relevant items for a given embedding"""
98
  query_emb = self.text_encoder.encode(text_query)
99
  relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
100
  return relevant_images
101
 
102
  @torch.no_grad()
103
+ def search_images(self, text_query: str, k: int = 6) -> Dict[str, List[Any]]:
104
  """Return top-k relevant images for a given embedding"""
105
  images = self.predict(text_query, k)
106
  paths_and_scores = {"path": [], "score": []}
 
157
  self.url = url
158
  self._predict = self._predict_from_endpoint
159
  else:
160
+ model = Retriever(image_embeddings_path=EMBEDDINGS_FILE)
161
  self._predict = model.predict
162
  self._search_images = model.search_images
163