# -*- coding: utf-8 -*-
"""Text_to_Image_Demo.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1mkGloXbrNHKFh99ryB6PQDyCJ3u4RqD5

## Generate Images from Text
"""
# Important installations
# pip install openai
# pip install python-dotenv
# pip install transformers datasets -q
# pip install streamlit

import os
import openai

# open_ai_key_file = "openai_api_key_llm_2023.txt"
# with open(open_ai_key_file, "r") as f:
#     OPENAI_KEY = f.readline().strip()

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# Read 100 flower names from 100flowers.txt
# openai.api_key = OPENAI_KEY
with open('./100flowers.txt', 'r') as file1:
    flower_names = [line.strip() for line in file1]

from openai import OpenAI
from PIL import Image
import urllib.request
from io import BytesIO

# client = OpenAI(api_key=OPENAI_KEY)

# Code to generate an image for each flower name in 100flowers.txt
# for prompt in flower_names:
#   response = client.images.generate(
#     model="dall-e-3",
#     prompt=prompt,
#     size="1024x1024",
#     quality="standard",
#     n=1,
#   )
#
#   # Save each generated image as a PNG in the Flowers folder
#   image_url = response.data[0].url
#   with urllib.request.urlopen(image_url) as image_response:
#       img = Image.open(BytesIO(image_response.read()))
#
#   img.save(f'./Flowers/{prompt}.png')
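
# A hedged one-off sanity check of the generation call above, kept commented
# out like the loop since the PNGs already exist in ./Flowers. It assumes
# `client` has been constructed (see the commented OpenAI(...) line above).
# test_response = client.images.generate(
#     model="dall-e-3", prompt=flower_names[0], size="1024x1024", n=1
# )
# print(test_response.data[0].url)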


# Flower names recovered from the PNG filenames in the Flowers folder
directory = './Flowers'
png_files = [os.path.splitext(file)[0].strip() for file in os.listdir(directory) if file.endswith('.png')]


# Note: only Dataset is imported; `datasets.Image` would shadow PIL's Image.
from datasets import Dataset

# Gets the list of candidate PNG filenames (relative to the images directory)
def get_paths_to_images(images_directory):
    return [file for file in os.listdir(images_directory) if file.endswith('.png')]

# Creates a dataset whose "image" column holds the candidate filenames
def load_dataset(images_directory):
    paths_images = get_paths_to_images(images_directory)
    return Dataset.from_dict({"image": paths_images})

path_images = "./Flowers"
dataset = load_dataset(path_images)

from transformers import AutoFeatureExtractor, AutoModel

model_ckpt = "jafdxc/vit-base-patch16-224-finetuned-flower"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
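
# Hedged check: the CLS embedding width used below is available from the model
# config rather than hard-coded (768 for a ViT-base checkpoint).
print(model.config.hidden_size)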

import torchvision.transforms as T
import torch


# Data transformation chain.
transformation_chain = T.Compose(
    [
        # Resize the shorter side to 256, i.e. (256 / 224) times the
        # extractor's 224-pixel input size, then take a center crop.
        T.Resize(int((256 / 224) * extractor.size["height"])),
        T.CenterCrop(extractor.size["height"]),
        T.ToTensor(),
        T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]
)
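
# Hedged sanity check, assuming ./Flowers holds at least one PNG: the chain
# should map any candidate image to a (3, 224, 224) float tensor (224 being
# the usual ViT input size).
sample_tensor = transformation_chain(Image.open(f"./Flowers/{png_files[0]}.png").convert("RGB"))
print(sample_tensor.shape)
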
def extract_embeddings(model: torch.nn.Module):
    """Utility to compute embeddings."""
    device = model.device

    def pp(batch):
        images = batch["image"]
        image_batch_transformed = torch.stack(
            [transformation_chain(Image.open("./Flowers/" + image)) for image in images]
        )
        new_batch = {"pixel_values": image_batch_transformed.to(device)}
        with torch.no_grad():
            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
        return {"embeddings": embeddings}

    return pp
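
# A minimal sketch of what `pp` returns, run eagerly on one candidate image
# (the choice of png_files[0] is arbitrary); the embedding should come out as
# (1, hidden_size).
# single = extract_embeddings(model)({"image": [png_files[0] + ".png"]})
# print(single["embeddings"].shape)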



import numpy as np

# Map the embedding-extraction utility over the candidate images.
batch_size = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
extract_fn = extract_embeddings(model.to(device))
candidate_subset_emb = dataset.map(extract_fn, batched=True, batch_size=batch_size)

# Stack all candidate embeddings into one (num_images, hidden_size) tensor.
all_candidate_embeddings = torch.from_numpy(np.array(candidate_subset_emb["embeddings"]))

print(f"Number of candidate images: {all_candidate_embeddings.shape[0]}")

def compute_scores(emb_one, emb_two):
    """Computes cosine similarity between a batch of embeddings and a query
    embedding, broadcast row-wise."""
    scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
    return scores.numpy().tolist()
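
# Toy example of the broadcasting `compute_scores` relies on: a (3, 4)
# candidate matrix scored against a (1, 4) query gives one score per row.
# a = torch.eye(3, 4)
# q = torch.tensor([[1., 0., 0., 0.]])
# print(torch.nn.functional.cosine_similarity(a, q))  # tensor([1., 0., 0.])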


def fetch_similar(image, top_k=5):
    """Fetches the indices of the `top_k` candidate images most similar to `image`."""
    # Prepare the input query image for embedding computation.
    image_transformed = transformation_chain(image).unsqueeze(0)
    new_batch = {"pixel_values": image_transformed.to(device)}

    # Compute the embedding.
    with torch.no_grad():
        query_embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()

    # Compute similarity scores against all candidate images in one go, and
    # map each candidate's dataset index to its score with the query image.
    sim_scores = compute_scores(all_candidate_embeddings, query_embeddings)
    similarity_mapping = dict(zip([str(index) for index in range(all_candidate_embeddings.shape[0])], sim_scores))

    # Sort the mapping by score and keep the `top_k` candidate ids.
    similarity_mapping_sorted = dict(
        sorted(similarity_mapping.items(), key=lambda x: x[1], reverse=True)
    )
    id_entries = list(similarity_mapping_sorted.keys())[:top_k]

    # Keys are plain stringified indices, so they convert directly to ints.
    ids = [int(entry) for entry in id_entries]
    return ids
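
# Hedged self-test (commented out): querying with a candidate image itself
# should return that image's own index as the top hit.
# query_img = Image.open("./Flowers/" + candidate_subset_emb[0]["image"]).convert("RGB")
# print(fetch_similar(query_img, top_k=3))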

# Streamlit is imported ahead of `plot_images`, which renders with it.
import streamlit as st


def plot_images(images):

    count = 0
    for image, name in images:
        if name == 'original':
            st.write("Showing the original image")
            st.image(image, caption=name, channels='RGB')
        else:
            count += 1
            st.write(f"Showing similar image {count}")
            img = Image.open(image)
            st.image(img, caption=name, channels='RGB')

# Streamlit webpage code

# Image similarity search UI
st.title("Flower Type Demo")
st.subheader("Upload an image of a flower and get the 5 most similar flowers from our dataset")

upload_file = st.file_uploader('Upload a Flower Image')

images = []

if upload_file:
    # Force RGB in case the upload has an alpha channel.
    test_sample = Image.open(upload_file).convert('RGB')

    sim_ids = fetch_similar(test_sample)

    for sim_id in sim_ids:
        images.append(("./Flowers/" + candidate_subset_emb[sim_id]["image"],
                       candidate_subset_emb[sim_id]["image"]))

    images.insert(0, (test_sample, 'original'))
    plot_images(images)
    st.write("")