|
import gradio as gr |
|
from PIL import Image |
|
import numpy as np |
|
from scipy.fftpack import dct |
|
from datasets import load_dataset |
|
from PIL import Image |
|
from multiprocessing import cpu_count |
|
|
|
|
|
def perceptual_hash_color(image):
    """Compute a per-channel perceptual hash of an image.

    The image is converted to RGB and resized to 32x32. For each of the
    three channels a 2-D DCT is computed, the top-left 8x8 block of
    coefficients is kept, and every coefficient in that block becomes one
    bit: 1 if it is greater than or equal to the block's median, else 0.
    The first coefficient of the block is left out when the median is
    computed.

    Args:
        image: A PIL image (must provide ``convert`` and ``resize``).

    Returns:
        A 1-D NumPy array of 192 zeros and ones: the 64 bits of the first
        channel, followed by those of the second and third channels.
    """
    image = image.convert("RGB")
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10;
    # LANCZOS selects the same filter and is available in old and new
    # Pillow releases alike.
    image = image.resize((32, 32), Image.LANCZOS)
    pixels = np.asarray(image)

    channel_bits = []
    for channel_index in range(3):
        channel = pixels[:, :, channel_index]
        # 2-D DCT built from two 1-D passes (axis 0, then axis 1); only the
        # 8x8 block at the origin is used for the hash.
        low_freq = dct(dct(channel, axis=0), axis=1)[:8, :8]
        # The coefficient at position 0 is excluded from the median.
        median = np.median(low_freq.flatten()[1:])
        # "* 1" turns the boolean mask into integers.
        channel_bits.append((low_freq >= median).flatten() * 1)
    return np.concatenate(channel_bits)
|
|
|
def hamming_distance(array_1, array_2):
    """Count the positions at which two sequences hold different values.

    The sequences are compared pairwise with ``zip``, so when their lengths
    differ only the common prefix is compared and the surplus elements of
    the longer one are ignored.

    Args:
        array_1: First iterable of comparable elements (e.g. hash bits).
        array_2: Second iterable of comparable elements.

    Returns:
        The number of differing positions, as a plain ``int``.
    """
    # A generator keeps the count lazy instead of building a throwaway list
    # just to take its length.
    return sum(1 for el_1, el_2 in zip(array_1, array_2) if el_1 != el_2)
|
|
|
def search_closest_examples(hash_refs, img_dataset, top_k=9):
    """Find the dataset rows whose stored hash is closest to the references.

    Every reference hash is compared with the ``"hash"`` field of every row
    using the Hamming distance (the count of differing positions, the same
    quantity ``hamming_distance`` computes for equal-length inputs). All
    (reference, row) pairs are ranked together and the ``top_k`` smallest
    are returned.

    Args:
        hash_refs: Sequence of reference hashes (1-D sequences of bits).
        img_dataset: Iterable of rows, each a mapping with a ``"hash"``
            entry. All hashes are assumed to have the same length as the
            reference hashes.
        top_k: Maximum number of results to return (default 9).

    Returns:
        A pair ``(indices, distances)`` of equally long lists of ``int``.
        ``indices[i]`` is a row index into ``img_dataset`` and
        ``distances[i]`` is the distance of the pair that produced it.
        Results are ordered by increasing distance. Because pairs are
        ranked, the same row index can occur more than once when several
        references are given.
    """
    # Read every stored hash exactly once instead of once per reference.
    stored_hashes = [row["hash"] for row in img_dataset]
    if not stored_hashes or len(hash_refs) == 0:
        return [], []

    num_images = len(stored_hashes)
    hash_matrix = np.asarray(stored_hashes)

    # One block of `num_images` distances per reference, laid end to end, so
    # flat position p belongs to row p % num_images.
    distances = np.concatenate(
        [
            np.count_nonzero(hash_matrix != np.asarray(hash_ref), axis=1)
            for hash_ref in hash_refs
        ]
    )

    # A stable sort makes the order of equally distant results deterministic.
    best_flat = np.argsort(distances, kind="stable")[:top_k]
    closests = [int(flat) % num_images for flat in best_flat]
    # Look the distance up by the flat position, not by the row index;
    # otherwise every result would report the first reference's distance.
    return closests, [int(distances[flat]) for flat in best_flat]
|
|
|
def find_closest_images(images, img_dataset):
    """Hash one or several query images and look them up in the dataset.

    Args:
        images: A single image, or a list/tuple of images. A single image
            is treated as a one-element batch.
        img_dataset: Dataset handed on unchanged to
            ``search_closest_examples``.

    Returns:
        The ``(indices, distances)`` pair produced by
        ``search_closest_examples`` for the hashes of the query images.
    """
    batch = images if isinstance(images, (list, tuple)) else [images]
    reference_hashes = list(map(perceptual_hash_color, batch))
    nearest, nearest_distances = search_closest_examples(reference_hashes, img_dataset)
    return nearest, nearest_distances
|
|
|
def compute_hash_from_image(img):
    """Compute a 64-bit perceptual hash of the grayscale version of an image.

    The image is converted to mode ``"L"`` and resized to 32x32, a 2-D DCT
    is computed, and the top-left 8x8 block of coefficients is kept. Each
    coefficient becomes one bit: 1 if it is greater than or equal to the
    block's median, else 0. The first coefficient of the block is left out
    when the median is computed.

    Args:
        img: A PIL image (must provide ``convert`` and ``resize``).

    Returns:
        A 1-D NumPy array of 64 zeros and ones.
    """
    img = img.convert("L")
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10;
    # LANCZOS selects the same filter and is available in old and new
    # Pillow releases alike.
    img = img.resize((32, 32), Image.LANCZOS)
    pixels = np.asarray(img)

    # 2-D DCT built from two 1-D passes (axis 0, then axis 1); only the 8x8
    # block at the origin is used for the hash.
    low_freq = dct(dct(pixels, axis=0), axis=1)[:8, :8]
    # The coefficient at position 0 is excluded from the median.
    median = np.median(low_freq.flatten()[1:])

    # Named `hash_bits` rather than `hash` so the builtin is not shadowed;
    # "* 1" turns the boolean mask into integers.
    hash_bits = (low_freq >= median).flatten() * 1
    return hash_bits
|
|
|
|
|
def process_dataset(dataset_name, dataset_split, dataset_column_image):
    """Load one split of a dataset and add a perceptual hash to every row.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        dataset_split: Split to load (e.g. ``"train"``).
        dataset_column_image: Name of the column that holds the images.

    Returns:
        The loaded split with an extra ``"hash"`` column containing
        ``perceptual_hash_color`` of the image in each row.
    """
    # Ask for the split directly so only that split is prepared, instead of
    # building every split and then indexing into the resulting dict.
    img_dataset = load_dataset(dataset_name, split=dataset_split)

    def add_hash(example):
        """Attach the perceptual hash of the row's image under "hash"."""
        example["hash"] = perceptual_hash_color(example[dataset_column_image])
        return example

    # Use half of the available cores, but always at least one worker.
    num_workers = max(cpu_count() // 2, 1)
    return img_dataset.map(add_hash, num_proc=num_workers)
|
|
|
|
|
def compute(dataset_name, dataset_split, dataset_column_image, img):
    """Return the dataset images most similar to the query image.

    Args:
        dataset_name: Name of the dataset to search.
        dataset_split: Split of that dataset to consider.
        dataset_column_image: Name of the column that holds the images.
        img: Query image (or list/tuple of images), or ``None`` when no
            image was supplied.

    Returns:
        A list with the content of ``dataset_column_image`` for each of the
        closest rows, ordered as returned by ``find_closest_images``. The
        list is empty when ``img`` is ``None``.
    """
    # Nothing to compare: skip loading and hashing the dataset, and avoid
    # the AttributeError that hashing ``None`` would raise.
    if img is None:
        return []

    img_dataset = process_dataset(dataset_name, dataset_split, dataset_column_image)
    closest_idx, _ = find_closest_images(img, img_dataset)
    return [img_dataset[i][dataset_column_image] for i in closest_idx]
|
|
|
|
|
# User interface: three text boxes describing the dataset and an image input
# on the left, a gallery showing the closest dataset images on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Find if your images are in a public dataset!")
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            dataset_name = gr.Textbox(label="Enter the name of a dataset containing images", value="huggan/few-shot-pokemon")
            dataset_split = gr.Textbox(label="Enter the split of this dataset to consider", value="train")
            dataset_column_image = gr.Textbox(label="Enter the name of the column of this dataset that contains images", value="image")
            img = gr.Image(label="Input your image that will be compared against images of the dataset", type="pil")
            # NOTE(review): `.style(...)` is, as far as I know, deprecated in
            # later Gradio 3.x releases and gone in 4.x - confirm against the
            # pinned Gradio version before upgrading.
            btn = gr.Button("Find").style(full_width=True)

        with gr.Column(scale=2, min_width=600):
            gallery_similar = gr.Gallery(label="similar images")
            gallery_similar.style(grid=3)

    # Clicking "Find" runs `compute` on the four inputs and shows the
    # returned images in the gallery.
    event = btn.click(compute, [dataset_name, dataset_split, dataset_column_image, img], gallery_similar)


# Start the server only when run as a script. The dataset is hashed with
# several worker processes; on platforms where workers re-import the main
# module, an unguarded launch() would run again inside every worker.
if __name__ == "__main__":
    demo.launch()