File size: 3,394 Bytes
f9a8213
de4f74d
815d67b
 
b6bd42c
0f89a2a
232bcf4
2fe3715
3ce4e0e
815d67b
a79b893
de4f74d
4a52775
 
232bcf4
f9a8213
2fe3715
 
 
 
56cb512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de4f74d
815d67b
 
 
 
b6bd42c
 
de4f74d
40b6e7e
381fd53
 
 
 
 
815d67b
de4f74d
82db274
 
de4f74d
82db274
de4f74d
 
 
 
 
 
 
 
 
 
 
 
 
 
815d67b
 
c9fa560
b369fe5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de4f74d
815d67b
de4f74d
56cb512
f9a8213
1f30c37
833a93d
1f30c37
f9a8213
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
import cv2
from keras.models import load_model
from keras.models import Model
from datasets import load_dataset
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

# --- Module-level setup: everything below runs once, at import time. ---

# Saved Keras model; process_image() taps one of its intermediate layers.
autoencoder = load_model("autoencoder_model.keras")
# Pre-computed encodings used as the search gallery.
# NOTE(review): presumably one row per dataset image, ordered train split
# first and then test split (that is how get_image() maps indices) - confirm.
encoded_images = np.load("X_encoded_compressed.npy")
# Debug output to make the encoding layout visible in the start-up log.
print("Shape of encoded_images:", encoded_images.shape)
print("Sample encoded image:", encoded_images[0])
# Dataset whose "train" and "test" splits are read by get_image().
dataset = load_dataset('eybro/images')

num_clusters = 10  # Number of clusters for the KMeans fit below
# Cluster the encodings; random_state is fixed so repeated runs are reproducible.
# NOTE(review): the fitted `kmeans` object is not referenced anywhere else in
# the visible code - confirm whether this start-up cost is still needed.
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
kmeans.fit(encoded_images)

def find_nearest_neighbors(encoded_images, input_image, top_n=5):
    """
    Find the closest neighbors to the input image in the encoded image space.

    Distances are plain Euclidean (L2) distances computed from the element-wise
    difference, so an encoding compared with itself yields exactly 0.

    Args:
    encoded_images (np.ndarray): Array of encoded images. The first axis indexes
        samples; any remaining axes are flattened into one feature vector per sample.
    input_image (np.ndarray): The encoded input image. It is flattened to a single
        feature vector, so (n_features,) and (1, n_features) are both accepted.
    top_n (int): The maximum number of nearest neighbors to return.

    Returns:
    List of tuples: (index, distance) of the nearest neighbors, closest first.
        Equal distances are ordered by ascending index. Fewer than top_n tuples
        are returned when there are fewer encoded images; an empty list is
        returned when top_n is 0 or there are no encoded images.

    Raises:
    ValueError: If top_n is negative, or if the flattened input does not have the
        same number of features as each flattened encoded image.
    """
    if top_n < 0:
        # A negative bound would otherwise slice from the tail ("all but the
        # last k"), which is never what a caller asking for neighbors means.
        raise ValueError(f"top_n must be non-negative, got {top_n}")

    gallery = np.asarray(encoded_images)
    if top_n == 0 or gallery.size == 0:
        return []

    # One row per sample, regardless of how many feature axes the encodings have.
    gallery = gallery.reshape(gallery.shape[0], -1)
    query = np.asarray(input_image).reshape(-1)
    if gallery.shape[1] != query.size:
        # Checked explicitly: broadcasting would silently accept a size-1 query.
        raise ValueError(
            f"Feature count mismatch: encoded images have {gallery.shape[1]} "
            f"features, input image has {query.size}"
        )

    # Work in a floating dtype so integer (especially unsigned) encodings do not
    # wrap around on subtraction; float inputs are used as-is without a copy.
    work_dtype = np.result_type(gallery.dtype, query.dtype, np.float32)
    deltas = gallery.astype(work_dtype, copy=False) - query.astype(work_dtype, copy=False)
    distances = np.linalg.norm(deltas, axis=1)

    # Stable sort keeps tie-breaking deterministic (lower index first).
    nearest_neighbors = np.argsort(distances, kind="stable")[:top_n]
    return [(index, distances[index]) for index in nearest_neighbors]

def get_image(index):
    """
    Look up a dataset record by its position in the combined train + test order.

    Indices smaller than the size of the "train" split are read from that
    split; any other index is shifted down by the train size and read from
    the "test" split.
    """
    train_size = len(dataset["train"])
    if index >= train_size:
        # Past the end of the train split: re-base the index onto the test split.
        return dataset["test"][index - train_size]
    return dataset["train"][index]

_ENCODER_MODEL = None


def _get_encoder():
    """Return the sub-model ending at autoencoder.layers[4], building it on first use."""
    global _ENCODER_MODEL
    if _ENCODER_MODEL is None:
        _ENCODER_MODEL = Model(inputs=autoencoder.input, outputs=autoencoder.layers[4].output)
    return _ENCODER_MODEL


def process_image(image):
    """
    Encode an image with the autoencoder and max-pool the result over its last axis.

    The image is converted to an array, channel-swapped, resized to 64x64,
    scaled to [0, 1] as float32, given a leading batch axis, and passed through
    the part of the autoencoder that ends at ``autoencoder.layers[4]``.

    Args:
    image: Image data convertible with ``np.array`` (for example the array
        supplied by the Gradio image input).

    Returns:
    np.ndarray: The layer output with its last axis reduced by ``max``.

    Raises:
    ValueError: If ``image`` is None (for example, nothing was uploaded).
    """
    if image is None:
        # Fail with a clear message instead of an opaque OpenCV error.
        raise ValueError("No image provided")

    img = np.array(image)
    # NOTE(review): this swaps the first and third channels. Gradio's Image
    # component normally delivers RGB arrays, so confirm this matches the
    # channel order the stored encodings were produced with.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (64, 64))
    img = img.astype('float32')
    img /= 255.0
    img = np.expand_dims(img, axis=0)  # add the batch axis expected by predict()

    # The sub-model is built once and reused; rebuilding it on every call
    # repeats the graph construction for identical results.
    encoded_array = _get_encoder().predict(img)

    # Reduce over the last axis only; the leading batch axis is preserved.
    # NOTE(review): the resulting shape depends on layer 4's output rank -
    # confirm it flattens to the same feature count as the stored encodings.
    pooled_array = encoded_array.max(axis=-1)
    return pooled_array
    
def inference(image):
    """
    Return the label and timestamp of the dataset image nearest to the query encoding.

    NOTE(review): the uploaded ``image`` is currently ignored. The call to
    ``process_image`` is commented out and the query is hard-coded to
    ``encoded_images[1010]``, so every request returns the same answer.
    Confirm whether this debugging stub should be replaced by the real encoding.

    Args:
    image: Image passed in by the Gradio input component (unused, see note).

    Returns:
    str: "<label> <timestamp>" of the closest neighbor.
    """
    # input_image = process_image(image)

    input_image = encoded_images[1010]

    nearest_neighbors = find_nearest_neighbors(encoded_images, input_image, top_n=5)

    # Print the results
    print("Nearest neighbors (index, distance):")
    for neighbor in nearest_neighbors:
        print(neighbor)

    top4 = [int(i[0]) for i in nearest_neighbors[:4]]
    print(f"top 4: {top4}")

    # Fetch every neighbor record once and reuse it for logging, the result
    # string and the plot, instead of re-reading it from the dataset each time.
    records = [get_image(i) for i in top4]
    for im in records:
        print(im["label"], im["timestamp"])

    result_image = records[0]
    result = f"{result_image['label']} {result_image['timestamp']}"

    # 2x2 grid: neighbors 1-2 on the top row, neighbors 3-4 on the bottom row.
    n = 2
    fig = plt.figure(figsize=(8, 8))
    try:
        for i, (upper, lower) in enumerate(zip(records[:2], records[2:])):
            ax = plt.subplot(2, n, i + 1)
            plt.imshow(upper["image"])
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

            ax = plt.subplot(2, n, i + 1 + n)
            plt.imshow(lower["image"])
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    finally:
        # The figure is neither returned nor shown; close it so figures do
        # not pile up in pyplot's registry across requests.
        plt.close(fig)

    return result
    


# Gradio UI: a single image input wired to inference(), with a text output.
demo = gr.Interface(
    fn=inference,
    inputs=gr.Image(label='Upload image'),
    outputs="text",
)
demo.launch()