Spaces:
Runtime error
Runtime error
import gradio as gr | |
from gradio.flagging import FlaggingCallback | |
from gradio.components import IOComponent | |
from gradio_client import utils as client_utils | |
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer | |
from sentence_transformers import util | |
import pickle | |
from PIL import Image | |
import os | |
import logging | |
import csv | |
import datetime | |
from pathlib import Path | |
from typing import List, Any | |
class SaveRelevanceCallback(FlaggingCallback):
    """Flagging callback that appends one CSV row per save to ``<flagging_dir>/log.csv``.

    Each row holds the serialized value of every component registered in ``setup``,
    in registration order, followed by a timestamp. A header row is written when the
    log file is first created.
    """

    # Header row for a new log file: one column per component passed to setup(),
    # in the same order, plus the timestamp column that flag() appends.
    HEADERS = ("query", "image directory", "relevance", "username", "timestamp")
    # Position of the gallery component in the list given to setup(); its
    # deserialized value is treated as a path and only its last part is logged.
    GALLERY_INDEX = 1
    # Name of the CSV file created inside the flagging directory.
    LOG_FILENAME = "log.csv"

    def __init__(self):
        # Explicit "not set up yet" state; setup() fills these in. flag() checks
        # flagging_dir so misuse fails with a clear message.
        self.components: List[IOComponent] = []
        self.flagging_dir: str | Path | None = None

    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
        """Register the components to log and create the flagging directory.

        Must be called once before flag().

        Args:
            components (list[IOComponent]): Components whose values are passed to
                flag(), in the same order as the CSV columns.
            flagging_dir (str | Path): Directory that will hold ``log.csv`` and any
                files written by component deserialization; created if missing.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)
        logging.info(f"[SaveRelevance]: Flagging directory set to {flagging_dir}")

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """Append one row built from ``flag_data`` to the CSV log.

        Called every time the save button is pressed.

        Args:
            flag_data: Values of the components registered in setup(), in the same order.
            flag_option: Unused; accepted for FlaggingCallback compatibility.
            flag_index: Unused; accepted for FlaggingCallback compatibility.
            username: Unused; accepted for FlaggingCallback compatibility.

        Returns:
            int: Total number of data rows (excluding the header) in the log file.

        Raises:
            RuntimeError: If setup() has not been called yet.
        """
        if self.flagging_dir is None:
            raise RuntimeError("SaveRelevanceCallback.setup() must be called before flag().")
        logging.info("[SaveRelevance]: Flagging data...")
        log_filepath = Path(self.flagging_dir) / self.LOG_FILENAME
        is_new = not log_filepath.exists()

        flag_data = list(flag_data)
        if len(flag_data) != len(self.components):
            # zip() below silently truncates, so the row would not line up with HEADERS.
            logging.warning(
                f"[SaveRelevance]: Got {len(flag_data)} values for "
                f"{len(self.components)} components; the row will not match the header."
            )

        csv_data = [
            self._serialize_sample(idx, component, sample)
            for idx, (component, sample) in enumerate(zip(self.components, flag_data))
        ]
        csv_data.append(str(datetime.datetime.now()))

        with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            if is_new:
                writer.writerow(gr.utils.sanitize_list_for_csv(list(self.HEADERS)))
            writer.writerow(gr.utils.sanitize_list_for_csv(csv_data))

        line_count = self._count_samples(log_filepath)
        logging.info(f"[SaveRelevance]: Saved a total of {line_count} samples to {log_filepath}")
        return line_count

    def _serialize_sample(self, idx: int, component: IOComponent, sample: Any) -> Any:
        """Convert one flagged component value into the value stored in its CSV cell."""
        if gr.utils.is_update(sample):
            # Update dicts are logged verbatim as their string representation.
            return str(sample)
        if sample is None:
            return ""
        # Files produced by deserialization go into a per-component sub-directory.
        save_dir = Path(self.flagging_dir) / client_utils.strip_invalid_filename_characters(
            getattr(component, "label", None) or f"component {idx}"
        )
        new_data = component.deserialize(sample, save_dir=save_dir)
        if new_data and idx == self.GALLERY_INDEX:
            # TODO: use a more robust image identifier. This is the directory that
            # holds every image in the gallery, not the single image being rated.
            new_data = Path(new_data).name
        return new_data

    @staticmethod
    def _count_samples(log_filepath: Path) -> int:
        """Return the number of data rows (excluding the header) in the CSV log."""
        # newline="" lets csv.reader handle quoted fields containing line breaks.
        with open(log_filepath, "r", newline="", encoding="utf-8") as csvfile:
            return sum(1 for _ in csv.reader(csvfile)) - 1
# Configure logging first: if any earlier call reached the root logger, logging would
# apply its implicit default configuration and this basicConfig call would be a no-op.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')

## Define model
# One checkpoint for model, processor and tokenizer. Presumably it must match the model
# that produced the precomputed image embeddings below (TODO confirm).
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL_NAME)

# Rows for gr.Examples, in input order: [query, top_k, username].
examples = [
    ["Dog in the beach", 2, 'ghost'],
    ["Paris during night.", 1, 'ghost'],
    ["A cute kangaroo", 5, 'ghost'],
    ["Dois cachorros", 2, 'ghost'],
    ["un homme marchant sur le parc", 3, 'ghost'],
    ["et høyt fjell", 2, 'ghost'],
]

# Open the precomputed embeddings: a pickled (img_names, img_emb) pair, where img_names
# are file names under photos/ and img_emb is the corpus searched by search_text.
# NOTE: pickle.load can execute arbitrary code; only ever load this bundled, trusted file.
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
with open(emb_filename, 'rb') as fIn:
    img_names, img_emb = pickle.load(fIn)
# helper functions
def search_text(query, top_k=1):
    """Search the precomputed image embeddings for images matching a text query.

    Args:
        query (str): Free-text description of the image to look for.
        top_k (int, optional): Number of images to return. Defaults to 1.
            The value is converted with ``int()``, so a float from a UI slider is accepted.

    Returns:
        list[PIL.Image.Image]: Up to ``top_k`` images opened from ``photos/``, in the
        order returned by ``util.semantic_search`` (highest cosine similarity first).
    """
    import torch  # local import: only needed to disable autograd during inference

    # torch.topk (used inside semantic_search) requires an integer k.
    top_k = int(top_k)
    logging.info(f"[SearchText]: Searching for {query} with top_k={top_k}...")

    # First, encode the query with the CLIP text encoder. Retrieval never
    # back-propagates, so skip building the autograd graph.
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    with torch.no_grad():
        query_emb = model.get_text_features(**inputs)

    # util.semantic_search computes the cosine similarity between the query embedding
    # and every image embedding. It returns one hit list per query; there is exactly one query.
    hits = util.semantic_search(query_emb, img_emb, top_k=top_k)[0]

    images = [
        Image.open(os.path.join("photos/", img_names[hit["corpus_id"]]))
        for hit in hits
    ]
    logging.info(f"[SearchText]: Found {len(images)} images.")
    return images
# def select_image(evt: gr.SelectData):
#     """ Returns the index of the selected image
#     Args:
#         evt (SelectData): the event we are listening to
#     Returns:
#         int: index of the selected image
#     """
#     logging.info(f"[SelectImage]: Selected image {evt.index}.")
#     return evt.index
callback = SaveRelevanceCallback()

with gr.Blocks() as demo:
    # Page header: what the demo does and how to use it.
    gr.Markdown(
        """
# Text to Image using CLIP Model 📸
My version of the Gradio Demo for the CLIP model with the option to select relevance level of each image. \n
This demo is based on assessment for the 🤗 Huggingface course 2.

- To use it, simply write which image you are looking for. See the examples section below for more details.
- After you submit your query, you will see a gallery of images that are related to your query.
- You can select the relevance of each image by using the dropdown menu.

---

To-do:
- [ ] Add a way to save multiple image-relevance pairs at once.
- [ ] Improve image identification in the csv file.
"""
    )
    with gr.Row():
        with gr.Column():
            query = gr.Textbox(lines=4,
                               label="Write what you are looking for in an image...",
                               placeholder="Text Here...")
            # Range starts at 1 with default 1. With the old minimum of 0, an unset
            # slider defaulted to 0 and the search returned an empty gallery.
            top_k = gr.Slider(1, 5, value=1, step=1, label="Top K relevant images to show")
            username = gr.Textbox(lines=1, label="Input your unique username 👻 ", placeholder="Text username here...")
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[3], height="auto")
            relevance = gr.Dropdown([str(i) for i in range(6)], multiselect=False,
                                    label="How relevant is this image to your input text?")
    with gr.Row():
        with gr.Column():
            submit_btn = gr.Button("Submit")
        with gr.Column():
            save_btn = gr.Button("Save after you select the relevance of each image")
    gr.Markdown("## Here are some examples you can use:")
    gr.Examples(examples, [query, top_k, username])

    # Register the logged components (in CSV column order) and the output directory.
    callback.setup([query, gallery, relevance, username], "flagged")

    # Run the search when the user submits a query.
    submit_btn.click(search_text, [query, top_k], [gallery])

    # TODO: record one relevance value per gallery image (e.g. via gallery.select
    # events) instead of a single relevance value per save.

    # Log the current query, gallery, relevance and username. preprocess=False passes
    # the raw component values, which the callback deserializes itself.
    save_btn.click(lambda *args: callback.flag(args), [query, gallery, relevance, username], preprocess=False)

    gr.Markdown(
        """
You can find more information about this demo on my ✨ github repository [marcelcastrobr](https://github.com/marcelcastrobr/huggingface_course2)
"""
    )

if __name__ == "__main__":
    demo.launch(debug=True)