sashavor commited on
Commit
9b28e54
·
1 Parent(s): 1ad85b1
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -6,29 +6,27 @@ import gradio as gr
6
  from datasets import load_dataset
7
  from transformers import AutoModel
8
 
9
- # `LSH` and `Table` imports are necessary in order for the
10
- # `lsh.pickle` file to load successfully.
11
- from similarity_utils import LSH, BuildLSHTable, Table
12
 
13
  seed = 42
14
 
15
  # Only runs once when the script is first run.
16
- with open("lsh.pickle", "rb") as handle:
17
- loaded_lsh = pickle.load(handle)
18
 
19
  # Load model for computing embeddings.
20
- model_ckpt = "abhishek/autotrain-butterflies-new-17716425"
21
- model = AutoModel.from_pretrained(model_ckpt)
22
- lsh_builder = BuildLSHTable(model)
23
- lsh_builder.lsh = loaded_lsh
24
 
25
  # Candidate images.
26
  dataset = load_dataset("huggan/inat_butterflies_top10k")
27
- candidate_dataset = dataset["train"].shuffle(seed=seed)
28
 
29
 
30
  def query(image, top_k):
31
- results = lsh_builder.query(image)
 
 
 
32
 
33
  # Should be a list of string file paths for gr.Gallery to work
34
  images = []
 
6
  from datasets import load_dataset
7
  from transformers import AutoModel
8
 
 
 
 
9
 
10
  seed = 42
11
 
12
  # Only runs once when the script is first run.
13
+ with open("index.pickle", "rb") as handle:
14
+ index = pickle.load(handle)
15
 
16
  # Load model for computing embeddings.
17
+ feature_extractor = AutoFeatureExtractor.from_pretrained("abhishek/autotrain-butterflies-new-17716425")
18
+ model = AutoModel.from_pretrained("abhishek/autotrain-butterflies-new-17716425")
 
 
19
 
20
  # Candidate images.
21
  dataset = load_dataset("huggan/inat_butterflies_top10k")
22
+ candidate_dataset = dataset["train"]
23
 
24
 
25
  def query(image, top_k):
26
+ inputs = feature_extractor(image, return_tensors="pt")
27
+ model_output = model(**inputs)
28
+ embedding = model_output.pooler_output
29
+ results = index.query(embedding)
30
 
31
  # Should be a list of string file paths for gr.Gallery to work
32
  images = []