maviced committed on
Commit
2aa38ae
1 Parent(s): 436b414

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Thanks to Freddy Boulton (https://github.com/freddyaboulton) for helping with this.
"""


import pickle

import gradio as gr
from datasets import load_dataset
from transformers import AutoModel

# `LSH` and `Table` imports are necessary in order for the
# `lsh.pickle` file to load successfully.
from similarity_utils import LSH, BuildLSHTable, Table

# Fixed shuffle seed so candidate-image indices stay stable across restarts
# (the ids stored in the pickled LSH table index into the shuffled split).
seed = 42

# Only runs once when the script is first run.
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only ship/load `lsh.pickle` from a trusted source.
with open("lsh.pickle", "rb") as handle:
    loaded_lsh = pickle.load(handle)

# Load model for computing embeddings, then attach the pre-built LSH table
# so we don't have to rebuild it on every startup.
model_ckpt = "yangswei/snacks_classification"
model = AutoModel.from_pretrained(model_ckpt)
lsh_builder = BuildLSHTable(model)
lsh_builder.lsh = loaded_lsh

# Candidate images. Shuffled with the same seed used when the LSH table was
# built, so stored integer ids line up with rows of this split.
dataset = load_dataset("beans")
candidate_dataset = dataset["train"].shuffle(seed=seed)
def query(image, top_k):
    """Fetch up to ``top_k`` candidate images most similar to ``image``.

    Parameters
    ----------
    image : PIL.Image
        Query image supplied by the gradio ``Image`` component.
    top_k : int
        Maximum number of matches to return (slider value, 1-10).

    Returns
    -------
    list[tuple[str, str]]
        ``(filepath, caption)`` pairs for ``gr.Gallery`` — the gallery
        component needs file paths, so each match is saved to disk first.
    """
    results = lsh_builder.query(image)

    # `results` maps "<image_id>_<label>" keys to match scores — presumably
    # LSH bucket hit counts (TODO confirm in similarity_utils). Rank by score
    # descending and keep only the best `top_k`.
    ranked = sorted(results, key=results.get, reverse=True)[:top_k]

    gallery_items = []
    for i, entry in enumerate(ranked):
        # Split once; keys are encoded as "<image_id>_<label>".
        image_id, label = entry.split("_")[:2]
        candidate = candidate_dataset[int(image_id)]["image"]

        # Should be a list of string file paths for gr.Gallery to work,
        # so persist the PIL image to disk and reference it by name.
        filename = f"similar_{i}.png"
        candidate.save(filename)

        # The gallery component can be a list of tuples, where the first
        # element is a path to a file and the second element is an optional
        # caption for that image.
        gallery_items.append((filename, f"Label: {label}"))

    return gallery_items
title = "Fetch Similar Beans 🪴"
description = "This Space demos an image similarity system. You can refer to [this notebook](TODO) to know the details of the system. You can pick any image from the available samples below. On the right hand side, you'll find the similar images returned by the system. The example images have been named with their corresponding integer class labels for easier identification. The fetched images will also have their integer labels tagged so that you can validate the correctness of the results."

# You can set the type of gr.Image to be PIL, numpy or str (filepath)
# Not sure what the best for this demo is.
demo = gr.Interface(
    query,
    inputs=[gr.Image(type="pil"), gr.Slider(value=5, minimum=1, maximum=10, step=1)],
    outputs=gr.Gallery().style(grid=[3], height="auto"),
    title=title,
    description=description,
    # Filenames denote the integer labels. Know here: https://hf.co/datasets/beans
    examples=[["0.png", 5], ["1.png", 5], ["2.png", 5]],
)

demo.launch()