"""Gradio demo: CLIP text-to-image search with per-image relevance logging.

The script loads a CLIP model and a pickle of precomputed image embeddings,
lets the user search images by text, and appends the chosen relevance of the
shown gallery to a CSV log through ``SaveRelevanceCallback``.
"""

import csv
import datetime
import logging
import os
import pickle
from pathlib import Path
from typing import List, Any

import gradio as gr
import torch
from gradio.components import IOComponent
from gradio.flagging import FlaggingCallback
from gradio_client import utils as client_utils
from PIL import Image
from sentence_transformers import util
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer


class SaveRelevanceCallback(FlaggingCallback):
    """Callback that saves the image relevance state to a csv file.

    Each call to :meth:`flag` appends one row to ``<flagging_dir>/log.csv``
    holding the deserialized value of every registered component followed by
    a timestamp.
    """

    def __init__(self):
        pass

    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
        """Remember the flagged components and create the flagging directory.

        Args:
            components: Components that will provide the flagged data, in the
                order their values are passed to :meth:`flag`.
            flagging_dir: Directory where the log file is stored; it is
                created if it does not exist yet.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)
        logging.info(f"[SaveRelevance]: Flagging directory set to {flagging_dir}")

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """Append one row with the current component values to the CSV log.

        Args:
            flag_data: One raw value per component registered in
                :meth:`setup`, in the same order.
            flag_option: Accepted for interface compatibility; not used.
            flag_index: Accepted for interface compatibility; not used.
            username: Accepted for interface compatibility; not used. The
                username written to the log comes from ``flag_data``.

        Returns:
            The number of data rows in the log file after this call (the
            header row is not counted).
        """
        logging.info("[SaveRelevance]: Flagging data...")
        flagging_dir = self.flagging_dir
        log_filepath = Path(flagging_dir) / "log.csv"
        is_new = not log_filepath.exists()
        headers = ["query", "image directory", "relevance", "username", "timestamp"]

        csv_data = []
        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
            save_dir = Path(
                flagging_dir
            ) / client_utils.strip_invalid_filename_characters(
                getattr(component, "label", None) or f"component {idx}"
            )
            if gr.utils.is_update(sample):
                csv_data.append(str(sample))
                continue

            new_data = (
                component.deserialize(sample, save_dir=save_dir)
                if sample is not None
                else ""
            )
            # Index 1 is the second registered component; the column is
            # labelled "image directory" in the headers above.
            if new_data and idx == 1:
                # TO-DO: change this to a more robust way of getting the image
                # name/identifier. This doesn't work - the directory contains
                # all the images in gallery.
                # Path(...).name keeps the last path component regardless of
                # the platform's separator (the old split('/') did not).
                new_data = Path(new_data).name
            csv_data.append(new_data)
        csv_data.append(str(datetime.datetime.now()))

        with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            if is_new:
                writer.writerow(gr.utils.sanitize_list_for_csv(headers))
            writer.writerow(gr.utils.sanitize_list_for_csv(csv_data))

        # Count through csv.reader rather than raw lines so that fields
        # containing newlines (multi-line queries) count as a single row.
        with open(log_filepath, "r", newline="", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile)) - 1
        logging.info(f"[SaveRelevance]: Saved a total of {line_count} samples to {log_filepath}")
        return line_count


## Define model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Each example row is [query, top_k, username].
examples = [
    ["Dog in the beach", 2, 'ghost'],
    ["Paris during night.", 1, 'ghost'],
    ["A cute kangaroo", 5, 'ghost'],
    ["Dois cachorros", 2, 'ghost'],
    ["un homme marchant sur le parc", 3, 'ghost'],
    ["et høyt fjell", 2, 'ghost'],
]

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')

# Open the precomputed embeddings
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
# NOTE(security): pickle.load executes arbitrary code from the file; only use
# an embeddings file from a trusted source.
with open(emb_filename, 'rb') as fIn:
    img_names, img_emb = pickle.load(fIn)
# print(f'img_emb: {print(img_emb)}')
# print(f'img_names: {print(img_names)}')


# helper functions
def search_text(query, top_k=1):
    """Search for images matching a text query.

    Args:
        query (str): Text describing the image to look for.
        top_k (int, optional): Number of images to return. Defaults to 1.

    Returns:
        list: Opened PIL images for the best-ranked hits, best first. Empty
        when ``top_k`` is less than 1.
    """
    logging.info(f"[SearchText]: Searching for {query} with top_k={top_k}...")

    # The UI slider may deliver its value as a float; normalise to an int.
    top_k = int(top_k)
    if top_k < 1:
        # Nothing was requested: skip the model call entirely.
        logging.info(f"[SearchText]: Found {0} images.")
        return []

    # First, we encode the query. Truncate to the tokenizer's maximum length
    # so an over-long query cannot overflow the text encoder.
    inputs = tokenizer([query], padding=True, truncation=True, return_tensors="pt")
    # Inference only: no need to record gradients.
    with torch.no_grad():
        query_emb = model.get_text_features(**inputs)

    # Then, we use the util.semantic_search function, which computes the
    # cosine-similarity between the query embedding and all image embeddings.
    # It then returns the top_k highest ranked images, which we output.
    hits = util.semantic_search(query_emb, img_emb, top_k=top_k)[0]

    image = [
        Image.open(os.path.join("photos/", img_names[hit['corpus_id']]))
        for hit in hits
    ]

    logging.info(f"[SearchText]: Found {len(image)} images.")
    return image


# def select_image(evt: gr.SelectData):
#     """ Returns the index of the selected image
#     Argrs:
#         evt (SelectData): the event we are listening to
#     Returns:
#         int: index of the selected image
#     """
#     logging.info(f"[SelectImage]: Selected image {evt.index}.")
#     return evt.index

callback = SaveRelevanceCallback()

with gr.Blocks() as demo:
    # create display
    gr.Markdown(
        """
        # Text to Image using CLIP Model 📸

        My version of the Gradio Demo fo CLIP model with the option to select relevance level of each image. \n
        This demo is based on assessment for the 🤗 Huggingface course 2.

        - To use it, simply write which image you are looking for. See the examples section below for more details.
        - After you submit your query, you will see a gallery of images that are related to your query.
        - You can select the relevance of each image by using the dropdown menu.

        ---

        To-do:
        - [ ] Add a way to save multiple image-relevance pairs at once.
        - [ ] Improve image identification in the csv file.
        """
    )
    with gr.Row():
        with gr.Column():
            query = gr.Textbox(lines=4,
                               label="Write what you are looking for in an image...",
                               placeholder="Text Here...")
            top_k = gr.Slider(0, 5, step=1, label="Top K relevant images to show")
            username = gr.Textbox(lines=1,
                                  label="Input your unique username 👻 ",
                                  placeholder="Text username here...")
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[3], height="auto")
            relevance = gr.Dropdown([str(i) for i in range(6)],
                                    multiselect=False,
                                    label="How relevent is this image to your input text?")
    with gr.Row():
        with gr.Column():
            submit_btn = gr.Button("Submit")
        with gr.Column():
            save_btn = gr.Button("Save after you select the relevance of each image")

    gr.Markdown("## Here are some examples you can use:")
    gr.Examples(examples, [query, top_k, username])

    # The component order here fixes the CSV column order written by flag().
    callback.setup([query, gallery, relevance, username], "flagged")

    # when user input query and top_k
    submit_btn.click(search_text, [query, top_k], [gallery])

    # image_relevance_state = gr.State(value={})
    # selected_index = gr.Number(value=0, visible=False, precision=0)

    # when user select an image in the gallery
    # gallery.select(select_image, None, selected_index)
    # when user select the relevance of the image
    # relevance.select(fn=select_image_relevance,
    #                  inputs=[gallery, selected_index, image_relevance_state],
    #                  outputs=image_relevance_state)

    # when user click save button
    # we will flag the current query, selected image, relevance, and username
    save_btn.click(lambda *args: callback.flag(args),
                   [query, gallery, relevance, username],
                   preprocess=False)
    # gallery_embs = []

    gr.Markdown(
        """
        You find more information about this demo on my ✨ github repository [marcelcastrobr](https://github.com/marcelcastrobr/huggingface_course2)
        """
    )

if __name__ == "__main__":
    demo.launch(debug=True)