# CLIP-Retrieval / app.py
# Author: Eun0 — commit f9767c2 "Remove unused package" (3.73 kB)
import glob
import gradio as gr
import pandas as pd
import faiss
import clip
import torch
# Page heading and intro text; both are rendered with gr.Markdown in the UI.
# The heading emoji is U+1F50D (magnifying glass); it previously appeared as the
# mis-decoded byte sequence "πŸ”".
title = r"""
<h1 align="center" id="space-title"> 🔍 Search Similar Text/Image in the Dataset</h1>
"""
description = r"""
In this demo, we use [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/) because LAION is currently not available.
<br>
This demo currently supports text search only.
<br>
The content will be updated to include image search once LAION is available.
The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval) and [autofaiss](https://github.com/criteo/autofaiss)
"""
# Load the FAISS text index and the caption table it was built from.
# The files are fetched from a Hugging Face dataset repo into the working
# directory. To use a locally built index instead, skip the two downloads and
# point root_path at that folder (previously
# "dataset/diffusiondb/text_index_folder"), which must contain text.index and
# metadata/*.parquet.
import warnings

from huggingface_hub import hf_hub_download, snapshot_download

INDEX_REPO_ID = "Eun02/diffusiondb_faiss_text_index"

# Download index file
hf_hub_download(
    repo_id=INDEX_REPO_ID,
    filename="text.index",
    repo_type="dataset",
    local_dir="./",
)
# Download text file(s): only the *.parquet caption shards are fetched. The
# loader below expects them under ./metadata/.
snapshot_download(
    repo_id=INDEX_REPO_ID,
    allow_patterns="*.parquet",
    repo_type="dataset",
    local_dir="./",
)

# Load index and text data
root_path = "."
IND = faiss.read_index(f"{root_path}/text.index")

# Row i of the concatenated captions must correspond to vector id i in the
# index, so shards are read in a deterministic (sorted) order.
# NOTE(review): the sort is lexicographic. Confirm the shard names are
# zero-padded so this matches the order the index was built in.
_metadata_files = sorted(glob.glob(f"{root_path}/metadata/*.parquet"))
if not _metadata_files:
    # Without this check, pd.concat fails with an opaque "No objects to
    # concatenate".
    raise FileNotFoundError(
        f"No caption metadata found matching {root_path}/metadata/*.parquet"
    )
# Only the caption column is used, so skip loading the other columns.
TEXT_LIST = pd.concat(
    pd.read_parquet(file, columns=["caption"]) for file in _metadata_files
)["caption"].tolist()

if IND.ntotal != len(TEXT_LIST):
    # A size mismatch means search ids will map to the wrong captions.
    warnings.warn(
        f"FAISS index has {IND.ntotal} vectors but {len(TEXT_LIST)} captions "
        "were loaded; search results may be misaligned."
    )
# CLIP ViT-B/32 is loaded once at import time and run on the CPU. get_emb
# defaults to the same device string. Only the model is kept; the image
# preprocessing transform that clip.load also returns is not needed for text
# search.
device = "cpu"
CLIP_MODEL = clip.load("ViT-B/32", device=device)[0]
@torch.inference_mode
def get_emb(text, device="cpu"):
    """Embed one string with the CLIP text encoder.

    The string is tokenized with truncation enabled and encoded by
    ``CLIP_MODEL`` on ``device``. The result is scaled to unit L2 norm and
    returned as a float32 NumPy array with a single row, shape
    ``(1, embed_dim)``, ready to pass to the FAISS index as a one-query batch.
    """
    tokens = clip.tokenize([text], truncate=True).to(device)
    features = CLIP_MODEL.encode_text(tokens)
    # Divide each row by its own L2 norm so the embedding has unit length.
    unit_features = features / features.norm(dim=-1, keepdim=True)
    return unit_features.cpu().numpy().astype('float32')
@torch.inference_mode
def search_text(dataset, top_k, show_score, query_text, device="cpu"):
    """Return the captions whose embeddings are nearest to ``query_text``.

    Args:
        dataset: Value of the "dataset" dropdown. Currently unused, since only
            one dataset is offered.
        top_k: Number of neighbours requested from the FAISS index. Coerced to
            ``int`` because FAISS needs an integer ``k`` and the UI slider value
            is not guaranteed to be one.
        show_score: If true, append ", <score>" (two decimals) to each caption.
        query_text: Text to embed with CLIP and search for.
        device: Torch device passed to ``get_emb``. It defaults to "cpu" because
            the Gradio click handler wires up only the first four inputs.

    Returns:
        ``(result_str, output_path)``: the newline-terminated results, one
        caption per line, and the path of a text file holding the same text.

    Raises:
        gr.Error: If ``query_text`` is missing, empty, or only whitespace.
    """
    if not query_text or not query_text.strip():
        raise gr.Error("Query text is missing")

    text_embeddings = get_emb(query_text, device)
    scores, ids = IND.search(text_embeddings, int(top_k))

    lines = []
    # The query batch holds a single vector, so only row 0 is relevant.
    for score, ind in zip(scores[0], ids[0]):
        # FAISS pads with id -1 when it has fewer than k results. Without this
        # check, TEXT_LIST[-1] would silently return the last caption.
        if ind < 0:
            continue
        item_str = TEXT_LIST[ind].strip()
        if not item_str:  # skip blank captions
            continue
        if show_score:
            item_str = f"{item_str}, {score:0.2f}"
        lines.append(f"{item_str}\n")
    result_str = "".join(lines)

    # Same content as a downloadable file. Captions may be non-ASCII, so don't
    # rely on the platform's default encoding.
    output_path = "./output.txt"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(result_str)
    return result_str, output_path
# Gradio UI: search options, a query box and button, then the results shown
# both as text and as a downloadable file. Component creation order inside
# the Blocks/Row contexts determines the page layout.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    # Search options. "DiffusionDB" is the only dataset offered, and
    # search_text does not read this value.
    with gr.Row():
        dataset = gr.Dropdown(label="dataset", choices=["DiffusionDB"], value="DiffusionDB")
        top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8)
        show_score = gr.Checkbox(label="Show score", value=True)
    query_text = gr.Textbox(label="query text")
    btn = gr.Button()
    # Outputs: the retrieved captions and the file search_text writes them to.
    with gr.Row():
        result_text = gr.Textbox(label="retrieved text", interactive=False)
        result_file = gr.File(label="output file")
    # The four inputs map positionally onto search_text(dataset, top_k,
    # show_score, query_text, ...); its two return values fill the outputs.
    # NOTE(review): no component supplies search_text's fifth parameter
    # (device), so that parameter needs a default or each click raises
    # TypeError.
    btn.click(
        fn=search_text,
        inputs=[dataset, top_k, show_score, query_text],
        outputs=[result_text, result_file],
    )
demo.launch()