# NOTE: the hosting page reported "Spaces: Runtime error" for this app;
# that status text is not part of the program.
import glob

import gradio as gr
import pandas as pd
import faiss
import clip
import torch
from huggingface_hub import hf_hub_download, snapshot_download
# Markdown/HTML heading rendered at the top of the demo page.
title = r"""
<h1 align="center" id="space-title"> π Search Similar Text/Image in the Dataset</h1>
"""

# Markdown blurb rendered under the heading: data source, current scope
# (text search only) and the projects the code is based on.
description = r"""
In this demo, we use [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/) because LAION is currently not available.
<br>
This demo currently supports text search only.
<br>
The content will be updated to include image search once LAION is available.
The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval) and [autofaiss](https://github.com/criteo/autofaiss)
"""
# --- Index and caption data --------------------------------------------------
# Alternative: load everything from a local folder instead of the Hub.
# INDEX_DIR = "dataset/diffusiondb/text_index_folder"
# IND = faiss.read_index(f"{INDEX_DIR}/text.index")
# TEXT_LIST = pd.concat(
#     pd.read_parquet(file) for file in glob.glob(f"{INDEX_DIR}/metadata/*.parquet")
# )['caption'].tolist()

# Default: pull the files from the Hugging Face dataset repository into the
# current working directory.

# Download the FAISS index file.
hf_hub_download(
    repo_id="Eun02/diffusiondb_faiss_text_index",
    filename="text.index",
    repo_type="dataset",
    local_dir="./",
)

# Download the parquet files that hold the captions.
snapshot_download(
    repo_id="Eun02/diffusiondb_faiss_text_index",
    allow_patterns="*.parquet",
    repo_type="dataset",
    local_dir="./",
)

# Load the index and the caption list.
# root_path = "dataset/diffusiondb/text_index_folder"
root_path = "."
IND = faiss.read_index(f"{root_path}/text.index")

# Parquet files are read in sorted filename order and concatenated; the
# resulting list is later indexed with the ids returned by IND.search.
# NOTE(review): presumably the sorted order is what keeps list positions
# aligned with the index ids -- confirm against how the index was built.
TEXT_LIST = pd.concat(
    pd.read_parquet(file)
    for file in sorted(glob.glob(f"{root_path}/metadata/*.parquet"))
)["caption"].tolist()
# --- CLIP model --------------------------------------------------------------
# The model is loaded once at import time and shared by every request.
device = "cpu"
# clip.load returns a pair; only the model is kept here.
CLIP_MODEL, _ = clip.load("ViT-B/32", device=device)
def get_emb(text, device="cpu"):
    """Embed a single text query with the module-level CLIP model.

    Args:
        text: Query string. It is tokenized with ``truncate=True``, so text
            longer than the tokenizer's context is cut rather than rejected.
        device: Torch device the tokens are moved to before encoding.

    Returns:
        A ``float32`` NumPy array with one row (one query), L2-normalised
        along the last axis.
    """
    text_tokens = clip.tokenize([text], truncate=True).to(device)

    # Inference only. Without no_grad the encoder output is attached to the
    # autograd graph and Tensor.numpy() raises a RuntimeError.
    with torch.no_grad():
        text_features = CLIP_MODEL.encode_text(text_tokens)
        # Unit-normalise each embedding (out of place, no tensor mutation).
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    return text_features.cpu().numpy().astype("float32")
def search_text(dataset, top_k, show_score, query_text, device="cpu"):
    """Find the captions most similar to ``query_text`` and save them to a file.

    Args:
        dataset: Selected dataset name. Currently unused; kept so the
            signature matches the UI inputs.
        top_k: Number of neighbours to request from the index. Converted to
            ``int`` because the UI slider may deliver a float.
        show_score: When true, append ``", <score>"`` (two decimals) to each
            result line.
        query_text: Text to search for.
        device: Torch device used for the query embedding. Defaults to
            ``"cpu"`` so the function can be called with the four UI inputs.

    Returns:
        A tuple ``(result_str, output_path)``: the newline-terminated result
        lines, and the path of the text file they were written to.

    Raises:
        gr.Error: If ``query_text`` is missing, empty or whitespace-only.
    """
    if query_text is None or not query_text.strip():
        raise gr.Error("Query text is missing")

    text_embeddings = get_emb(query_text, device)
    scores, retrieved_texts = IND.search(text_embeddings, int(top_k))

    # One query was embedded, so only the first result row is relevant.
    lines = []
    for score, ind in zip(scores[0], retrieved_texts[0]):
        if ind < 0:
            # FAISS fills missing neighbours with id -1; indexing the list
            # with it would silently return the last caption.
            continue
        item_str = TEXT_LIST[ind].strip()
        if not item_str:
            continue
        lines.append(f"{item_str}, {score:0.2f}" if show_score else item_str)
    result_str = "".join(f"{line}\n" for line in lines)

    # NOTE(review): every call overwrites the same file, so concurrent
    # requests would share it -- confirm whether that matters here.
    file_name = "output"
    output_path = f"./{file_name}.txt"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(result_str)

    return result_str, output_path
# --- UI ----------------------------------------------------------------------
# NOTE(review): the original indentation was lost; which components sit
# inside each gr.Row() is reconstructed here -- confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    # Search options.
    with gr.Row():
        dataset = gr.Dropdown(label="dataset", choices=["DiffusionDB"], value="DiffusionDB")
        # Explicit integer step so "top k" is always a whole number.
        top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8, step=1)
        show_score = gr.Checkbox(label="Show score", value=True)

    query_text = gr.Textbox(label="query text")
    btn = gr.Button()

    # Results: the retrieved text and the same content as a downloadable file.
    with gr.Row():
        result_text = gr.Textbox(label="retrieved text", interactive=False)
        result_file = gr.File(label="output file")

    btn.click(
        fn=search_text,
        inputs=[dataset, top_k, show_score, query_text],
        outputs=[result_text, result_file],
    )

demo.launch()