import os
import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from PIL import Image

# Define a function to search for similar images
def search_similar_images(img_path, num_results=10, images_dir="images"):
    """Return the dataset images whose grayscale histograms best match the query.

    Every image is converted to grayscale, resized to 300x300 and reduced to a
    256-bin histogram min-max scaled to [0, 1]. Candidates are ranked by the
    correlation (``cv2.HISTCMP_CORREL``) between their histogram and the query's.

    Args:
        img_path: Path of the query image.
        num_results: Maximum number of images to return; values <= 0 yield [].
        images_dir: Directory holding the candidate images (.jpg/.jpeg/.png,
            matched case-insensitively). Unreadable files are skipped.

    Returns:
        A list of up to ``num_results`` BGR images (as loaded by
        ``cv2.imread``), most similar first.

    Raises:
        ValueError: If the query image cannot be read.
    """
    query_image = cv2.imread(img_path)
    # cv2.imread signals failure by returning None rather than raising.
    if query_image is None:
        raise ValueError(f"Could not read query image: {img_path}")
    if num_results <= 0:
        # Guard explicitly: a slice like [-0:] would return *every* image.
        return []

    query_hist = _normalized_gray_hist(query_image)

    images = []
    similarities = []
    # sorted() makes the scan order (and therefore tie-breaking) deterministic.
    for image_file in sorted(os.listdir(images_dir)):
        if not image_file.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        image = cv2.imread(os.path.join(images_dir, image_file))
        if image is None:
            continue  # corrupt/unreadable file: skip instead of crashing in cvtColor
        images.append(image)
        similarities.append(
            cv2.compareHist(query_hist, _normalized_gray_hist(image), cv2.HISTCMP_CORREL)
        )

    # Rank by descending similarity so the best match comes first; a stable
    # sort keeps file order among equal scores.
    ranking = np.argsort(-np.asarray(similarities, dtype=float), kind="stable")
    return [images[i] for i in ranking[:num_results]]


def _normalized_gray_hist(image, size=(300, 300)):
    """Return the 256-bin grayscale histogram of a BGR image, min-max scaled to [0, 1].

    The image is resized to ``size`` first so all images are compared on the
    same pixel count.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    resized = cv2.resize(gray, size)
    hist = cv2.calcHist([resized], [0], None, [256], [0, 256])
    return cv2.normalize(hist, hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)

def display_results(similar_images, cols=5):
    """Render images as a grid (``cols`` per row) inside the Streamlit page.

    Args:
        similar_images: Sequence of BGR images, as returned by
            ``search_similar_images``.
        cols: Number of images per grid row.
    """
    if len(similar_images) == 0:
        st.info("No similar images found.")
        return

    # Ceiling division: enough rows for every image, no fixed 2x5 limit.
    rows = (len(similar_images) + cols - 1) // cols
    # squeeze=False keeps `axs` 2-D even when there is a single row.
    fig, axs = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)

    # Hide axes on every cell first so unused trailing cells stay blank.
    for ax in axs.flat:
        ax.axis("off")

    for ax, img in zip(axs.flat, similar_images):
        # OpenCV images are BGR; matplotlib expects RGB.
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    st.pyplot(fig)
    # Release the figure; otherwise every Streamlit rerun leaks one.
    plt.close(fig)

def app():
    """Streamlit entry point: build the page and dispatch on the chosen input mode."""
    # set_page_config must be the first Streamlit command on the page.
    st.set_page_config(page_title="Similar Images Search", page_icon=":mag:")
    st.title("Find Similar Images")
    st.markdown("This app allows you to search for similar images.")

    option = st.radio("Select an option:", ("Browse a single image", "Select from test dataset"))
    if option == "Browse a single image":
        _browse_single_image()
    else:
        _select_from_test_dataset()

    # Footer is rendered regardless of the branch taken above.
    st.markdown("---")
    st.markdown("Created by Pruthul")


def _show_similar_images(results):
    """Render the 'Similar Images' section for a list of search results."""
    st.markdown("---")
    st.subheader("Similar Images")
    display_results(results)


def _browse_single_image():
    """Let the user upload an image and show the dataset images most similar to it."""
    uploaded_file = st.file_uploader("Choose an image:", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # search_similar_images() reads from a path, so persist the upload first.
    # A unique temp file (rather than a fixed "temp.jpg" in the CWD) keeps
    # concurrent sessions from overwriting each other's query, and is removed
    # once the search is done.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        temp_file_path = tmp.name
    try:
        image.save(temp_file_path)
        results = search_similar_images(temp_file_path)
    finally:
        os.remove(temp_file_path)

    _show_similar_images(results)


def _select_from_test_dataset():
    """Let the user pick a dataset image in the sidebar and show images similar to it."""
    test_dataset_dir = "images"
    try:
        # Only offer image files, in a stable (sorted) order.
        test_dataset_files = sorted(
            name
            for name in os.listdir(test_dataset_dir)
            if name.lower().endswith((".jpg", ".jpeg", ".png"))
        )
    except FileNotFoundError:
        test_dataset_files = []

    st.sidebar.title("Test Dataset")
    if not test_dataset_files:
        st.sidebar.write(f"No images found in '{test_dataset_dir}'.")
        return

    # Display the file names as text instead of images.
    for name in test_dataset_files:
        st.sidebar.write(name)

    selected_file = st.sidebar.selectbox("Select an image:", test_dataset_files)
    selected_image_path = os.path.join(test_dataset_dir, selected_file)

    # Context manager closes the file handle once st.image has consumed it.
    with Image.open(selected_image_path) as selected_image:
        st.image(selected_image, caption="Selected Image", use_column_width=True)

    # Results are shown for every selection, including the first entry.
    _show_similar_images(search_similar_images(selected_image_path))


# `streamlit run` executes this script as __main__, so the UI still launches
# there, while a plain import (e.g. to reuse search_similar_images) does not.
if __name__ == "__main__":
    app()



####  == Alternative: image feature extraction with ResNet50 (disabled: it failed when run under Streamlit) == ###

# from sklearn.neighbors import NearestNeighbors
# from tensorflow.keras.preprocessing.image import load_img, img_to_array
# from tensorflow.keras.applications.resnet50 import ResNet50,preprocess_input

# def extract_features(image_path, model):
#     # Load and preprocess the image
#     image = load_img(image_path, target_size=(224, 224))
#     image_array = img_to_array(image)
#     image_array = preprocess_input(image_array)
#     # Extract the features using the ResNet50 model
#     features = model.predict(image_array.reshape(1, 224, 224, 3))
    
#     # Flatten the features and return them as a 1D array
#     features = features.flatten()
#     return features


# # Define a function to search for similar images
# def search_similar_images(img_path, num_results=10):
#     #load feature dictionary
#     with open("features.npy", "rb") as f:
#         image_dict = np.load(f, allow_pickle=True).item()

#     # Fit a nearest neighbor model on the features
#     nn_model = NearestNeighbors(n_neighbors=num_results, metric='cosine')

#     # Convert the dictionary to a matrix of feature vectors
#     features_list = np.array(list(image_dict.values()))

#     # Fit the model to the feature matrix
#     nn_model.fit(features_list)

#     # Define the file path to the test image
#     test_image_path = img_path

#     # Load all images from the 'images' directory
#     images_dir = "images"
#     image_files = os.listdir(images_dir)
#     images = []
#     for image_file in image_files:
#         if image_file.endswith(".jpg") or image_file.endswith(".png"):
#             image_path = os.path.join(images_dir, image_file)
#             image = cv2.imread(image_path)
#             images.append(image)

#     model = ResNet50(weights='imagenet', include_top=False)
#     # Extract features from the test image
#     test_image_features = extract_features(test_image_path, model)

#     # Reshape the test image features to match the shape of the feature vectors
#     test_image_features = test_image_features.reshape(1, -1)

#     # # Find the 10 most similar images to the test image
#     distances, indices = nn_model.kneighbors(test_image_features)

#     # Create a list of the top num_results most similar images
#     results = []
#     for index in indices[0]:
#         results.append(images[index])

#     # Return the list of results
#     return results