Spaces:
Runtime error
Runtime error
from sentence_transformers import SentenceTransformer, util | |
from PIL import Image | |
import pickle | |
import os | |
import gradio as gr | |
import zipfile | |
import logging | |
logger = logging.getLogger(__name__) | |
# Load CLIP model | |
text_model = SentenceTransformer("clip-ViT-B-32-multilingual-v1") | |
image_model = SentenceTransformer("clip-ViT-B-32") | |
image_model.parallel_tokenization = False | |
img_folder = ".\\photos\\" | |
if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0: | |
os.makedirs(img_folder, exist_ok=True) | |
photo_filename = "unsplash-25k-photos.zip" | |
if not os.path.exists(photo_filename): | |
util.http_get("http://sbert.net/datasets/" + | |
photo_filename, photo_filename) | |
# Extract all images | |
with zipfile.ZipFile(photo_filename, "r") as zf: | |
for member in zf.infolist(): | |
zf.extract(member, img_folder) | |
emb_filename = ".\\unsplash-25k-photos-embeddings.pkl" | |
if not os.path.exists(emb_filename): | |
util.http_get( | |
"http://sbert.net/datasets/unsplash-25k-photos-embeddings.pkl", emb_filename | |
) | |
with open(emb_filename, "rb") as fIn: | |
img_names, img_emb = pickle.load(fIn) | |
img_folder = ".\\photos\\" | |
duplicates = util.paraphrase_mining_embeddings(img_emb) | |
def search_text(query, top_k=1): | |
""" " Search an image based on the text query. | |
Args: | |
query ([string]): [query you want search for] | |
top_k (int, optional): [Amount of images o return]. Defaults to 1. | |
Returns: | |
[list]: [list of images that are related to the query.] | |
""" | |
log_query = query.encode("utf-8").decode("utf-8") | |
logger.warning(f"{log_query}, {top_k}") | |
# First, we encode the query. | |
query_emb = text_model.encode([query]) | |
# Then, we use the util.semantic_search function, which computes the cosine-similarity | |
# between the query embedding and all image embeddings. | |
# It then returns the top_k highest ranked images, which we output | |
hits = util.semantic_search(query_emb, img_emb, top_k=top_k)[0] | |
image = [] | |
for hit in hits: | |
object = Image.open(os.path.join( | |
".\\photos\\", img_names[hit["corpus_id"]])) | |
image.append((object, img_names[hit["corpus_id"]])) | |
return image | |
iface_search = gr.Interface( | |
title="Семантический поиск по картинке - d8a.ai", | |
description="""Демо-версия семантического поиска изображений, использующая | |
современные алгоритмы искусственного интеллекта для получения высокоточных | |
результатов поиска. Пользователи могут искать изображения с помощью запросов | |
на естественном языке и предварительно просматривать результаты. | |
Это приложение идеально подходит для создателей контента, маркетологов и менеджеров | |
социальных сетей и обеспечивает более интеллектуальный и интуитивно понятный | |
способ поиска и управления визуальным контентом.""", | |
fn=search_text, | |
allow_flagging="never", | |
inputs=[ | |
gr.inputs.Textbox( | |
lines=4, | |
label="Поисковый текст", | |
placeholder="Что вы хотите найти?", | |
default="Горы Кыргызстана", | |
), | |
gr.inputs.Slider(minimum=0, maximum=9, default=5, | |
step=1, label="Количество"), | |
], | |
outputs=gr.Gallery( | |
label="Найденные изображения", show_label=False, elem_id="gallery" | |
).style(grid=[5], height="auto"), | |
examples=[[("Горы Кыргызстана"), 5], [("Люди Кыргызстана"), 2], | |
[("A dog with a ball"), 5]], | |
) | |
iface_search.launch() | |