"""Gradio demo: find dataset frames that look similar to an uploaded image.

The uploaded image is encoded with a pretrained autoencoder and compared by
Euclidean distance against the encodings stored in ``X_encoded_compressed.npy``.
The closest match is reported with a timestamped link to its source video.
"""

import functools

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from keras.models import Model
from keras.models import load_model
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import euclidean_distances

autoencoder = load_model("autoencoder_model.keras")
# NOTE(review): rows are assumed to line up with the dataset's train split
# followed by its test split (see ``get_image``) — confirm.
encoded_images = np.load("X_encoded_compressed.npy")
dataset = load_dataset("eybro/images-split")

# Sub-model exposing the output of layer 4 of the autoencoder; its activations
# are the features compared against ``encoded_images``. Built once here rather
# than on every request.
encoder_model = Model(inputs=autoencoder.input, outputs=autoencoder.layers[4].output)

num_clusters = 10  # Choose the number of clusters
# NOTE(review): ``kmeans`` is fitted at startup but not used anywhere in this
# file; confirm it is still needed before paying the startup cost.
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
kmeans.fit(encoded_images)


@functools.lru_cache(maxsize=1)
def _video_urls_frame():
    """Load the ``eybro/video_urls`` train split once and cache it as a DataFrame."""
    return load_dataset("eybro/video_urls")["train"].to_pandas()


def create_url_from_title(title: str, timestamp: int):
    """Return the URL of the video named *title*, positioned at *timestamp*.

    Args:
        title: Video title, matched exactly against the ``title`` column of the
            ``eybro/video_urls`` dataset.
        timestamp: Offset into the video in seconds, appended as ``t=<n>s``.

    Returns:
        str: The first matching row's ``url`` with a ``t`` query parameter added.

    Raises:
        ValueError: If no row has the given title.
    """
    df = _video_urls_frame()
    matches = df.loc[df["title"] == title, "url"]
    if matches.empty:
        raise ValueError(f"No video URL found for title {title!r}")
    base_url = matches.iloc[0]
    # Extend an existing query string (e.g. ``watch?v=...``) instead of
    # starting a second one, which would yield an invalid URL.
    separator = "&" if "?" in base_url else "?"
    return f"{base_url}{separator}t={timestamp}s"


def find_nearest_neighbors(encoded_images, input_image, top_n=5):
    """
    Find the closest neighbors to the input image in the encoded image space.

    Args:
        encoded_images (np.ndarray): Encoded images, shape (n_samples, n_features).
        input_image (np.ndarray): Encoding of one image; any shape whose total
            size is n_features (it is flattened to a single row).
        top_n (int): The number of nearest neighbors to return.

    Returns:
        List of tuples: (index, distance) of the top_n nearest neighbors,
        ordered from closest to farthest.
    """
    # Distance from every stored encoding to the query, as a flat vector.
    distances = euclidean_distances(encoded_images, input_image.reshape(1, -1)).flatten()
    nearest_neighbors = np.argsort(distances)[:top_n]
    return [(index, distances[index]) for index in nearest_neighbors]


def get_image(index):
    """Return dataset row *index*, counting through "train" first, then "test"."""
    split = len(dataset["train"])
    if index < split:
        return dataset["train"][index]
    return dataset["test"][index - split]


def process_image(image):
    """Encode an uploaded image into the same feature space as ``encoded_images``.

    Args:
        image: Image as delivered by ``gr.Image`` (anything ``np.array`` accepts).

    Returns:
        np.ndarray: Layer-4 activations of the autoencoder for the image,
        max-reduced over the last axis, with a leading batch axis of size 1.
    """
    img = np.array(image)
    # NOTE(review): gr.Image supplies RGB by default, so this swaps the R and B
    # channels. Kept because the stored encodings were presumably produced
    # with the same preprocessing — confirm.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (64, 64))
    img = img.astype("float32") / 255.0
    img = np.expand_dims(img, axis=0)  # add batch axis -> (1, 64, 64, channels)
    encoded_array = encoder_model.predict(img)
    # Max over the last axis (presumably channels); remaining axes are
    # flattened later by find_nearest_neighbors.
    return encoded_array.max(axis=-1)


def _plot_neighbors(indices):
    """Draw up to four neighbor images as a 2x2 grid and return the figure."""
    n = 2
    fig = plt.figure(figsize=(8, 8))
    for i, (top_index, bottom_index) in enumerate(zip(indices[:2], indices[2:])):
        for row, index in enumerate((top_index, bottom_index)):
            ax = plt.subplot(2, n, i + 1 + row * n)
            plt.imshow(get_image(index)["image"])
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    return fig


def inference(image):
    """Gradio handler: describe the dataset frame most similar to *image*.

    Args:
        image: The uploaded image from the ``gr.Image`` input.

    Returns:
        str: ``"<label> <timestamp> \\n<video url>"`` for the closest match.
        If no URL is known for that label, the error message replaces the URL.
    """
    input_image = process_image(image)
    nearest_neighbors = find_nearest_neighbors(encoded_images, input_image, top_n=5)

    # Print the results
    print("Nearest neighbors (index, distance):")
    for neighbor in nearest_neighbors:
        print(neighbor)

    top4 = [int(index) for index, _ in nearest_neighbors[:4]]
    print(f"top 4: {top4}")
    for index in top4:
        im = get_image(index)
        print(im["label"], im["timestamp"])

    best = get_image(top4[0])
    try:
        url = create_url_from_title(best["label"], best["timestamp"])
    except ValueError as err:
        url = str(err)
    result = f"{best['label']} {best['timestamp']} \n{url}"

    # NOTE(review): the figure is not part of the Gradio output (outputs="text").
    # Close it right away so figures do not accumulate across requests; wire it
    # to a gr.Plot output or drop it.
    plt.close(_plot_neighbors(top4))
    return result


demo = gr.Interface(fn=inference, inputs=gr.Image(label='Upload image'), outputs="text")

if __name__ == "__main__":
    demo.launch()