image-search / app.py
alymbeks's picture
log to console instead of file
a4b8145
from sentence_transformers import SentenceTransformer, util
from PIL import Image
import pickle
import os
import gradio as gr
import zipfile
import logging
logger = logging.getLogger(__name__)
# Load CLIP model
text_model = SentenceTransformer("clip-ViT-B-32-multilingual-v1")
image_model = SentenceTransformer("clip-ViT-B-32")
image_model.parallel_tokenization = False
img_folder = ".\\photos\\"
if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
os.makedirs(img_folder, exist_ok=True)
photo_filename = "unsplash-25k-photos.zip"
if not os.path.exists(photo_filename):
util.http_get("http://sbert.net/datasets/" +
photo_filename, photo_filename)
# Extract all images
with zipfile.ZipFile(photo_filename, "r") as zf:
for member in zf.infolist():
zf.extract(member, img_folder)
emb_filename = ".\\unsplash-25k-photos-embeddings.pkl"
if not os.path.exists(emb_filename):
util.http_get(
"http://sbert.net/datasets/unsplash-25k-photos-embeddings.pkl", emb_filename
)
with open(emb_filename, "rb") as fIn:
img_names, img_emb = pickle.load(fIn)
img_folder = ".\\photos\\"
duplicates = util.paraphrase_mining_embeddings(img_emb)
def search_text(query, top_k=1):
""" " Search an image based on the text query.
Args:
query ([string]): [query you want search for]
top_k (int, optional): [Amount of images o return]. Defaults to 1.
Returns:
[list]: [list of images that are related to the query.]
"""
log_query = query.encode("utf-8").decode("utf-8")
logger.warning(f"{log_query}, {top_k}")
# First, we encode the query.
query_emb = text_model.encode([query])
# Then, we use the util.semantic_search function, which computes the cosine-similarity
# between the query embedding and all image embeddings.
# It then returns the top_k highest ranked images, which we output
hits = util.semantic_search(query_emb, img_emb, top_k=top_k)[0]
image = []
for hit in hits:
object = Image.open(os.path.join(
".\\photos\\", img_names[hit["corpus_id"]]))
image.append((object, img_names[hit["corpus_id"]]))
return image
iface_search = gr.Interface(
title="Семантический поиск по картинке - d8a.ai",
description="""Демо-версия семантического поиска изображений, использующая
современные алгоритмы искусственного интеллекта для получения высокоточных
результатов поиска. Пользователи могут искать изображения с помощью запросов
на естественном языке и предварительно просматривать результаты.
Это приложение идеально подходит для создателей контента, маркетологов и менеджеров
социальных сетей и обеспечивает более интеллектуальный и интуитивно понятный
способ поиска и управления визуальным контентом.""",
fn=search_text,
allow_flagging="never",
inputs=[
gr.inputs.Textbox(
lines=4,
label="Поисковый текст",
placeholder="Что вы хотите найти?",
default="Горы Кыргызстана",
),
gr.inputs.Slider(minimum=0, maximum=9, default=5,
step=1, label="Количество"),
],
outputs=gr.Gallery(
label="Найденные изображения", show_label=False, elem_id="gallery"
).style(grid=[5], height="auto"),
examples=[[("Горы Кыргызстана"), 5], [("Люди Кыргызстана"), 2],
[("A dog with a ball"), 5]],
)
iface_search.launch()