import multiprocessing
import random
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from PIL.Image import Image, ANTIALIAS
import gradio as gr
from faiss import METRIC_INNER_PRODUCT
import requests
import pandas as pd
import os
import backoff
from functools import lru_cache
from huggingface_hub import list_models, ModelFilter, login
import copy

# Environment flag for the Hugging Face Hub tooling (the name suggests it turns on
# the hf_transfer download backend).
# NOTE(review): this assignment runs after the `datasets` / `huggingface_hub`
# imports above; confirm the flag is still honoured when set at this point.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# CPU count of the host; not referenced anywhere else in the code visible here.
cpu_count = multiprocessing.cpu_count()

# Sentence-transformers model loaded once at import time; `get_nearest_k_examples`
# calls `model.encode` on the search query.
model = SentenceTransformer("clip-ViT-B-16")

def resize_image(image: Image, size: int = 224) -> Image:
    """Resize an image so its shorter side equals ``size``, keeping the aspect ratio.

    Args:
        image: Source image; only ``image.size`` and ``image.resize`` are used.
        size: Target length, in pixels, of the shorter side.

    Returns:
        The resized image. A square input becomes ``size`` x ``size``.
    """
    w, h = image.size
    if w == h:
        return image.resize((size, size), ANTIALIAS)
    if w > h:
        # Landscape: height is the shorter side, so pin it and scale the width.
        scale = size / float(h)
        return image.resize((int(float(w) * scale), size), ANTIALIAS)
    # Portrait: width is the shorter side, so pin it and scale the height.
    # The height has to be derived from `h`; deriving it from `w` always gives
    # `size` and squashes every portrait image into a square.
    scale = size / float(w)
    return image.resize((size, int(float(h) * scale)), ANTIALIAS)

# Load the pre-embedded image dataset, drop rows that have no embedding, then
# build a FAISS index over the "embedding" column using inner-product similarity.
dataset = load_dataset(
    "davanstrien/ia-loaded-embedded-gpu", split="train"
).filter(lambda row: row["embedding"] is not None)
dataset.add_faiss_index("embedding", metric_type=METRIC_INNER_PRODUCT)

def get_nearest_k_examples(input, k):
    """Return the ``k`` dataset images nearest to a query, each with a date caption.

    Args:
        input: Query handed to ``model.encode`` to produce the search embedding.
        k: Number of neighbours to retrieve from the "embedding" index.

    Returns:
        A list of ``(image, caption)`` tuples, where the caption combines the
        row's ``last_modified_date`` and ``crawl_date`` values.
    """
    # Defensive: a caller such as a UI slider may pass a float, while the
    # nearest-neighbour search expects an integer count.
    k = int(k)
    query = model.encode(input)
    # TODO: consider a FAISS range search with a similarity threshold instead of
    # a fixed top-k.
    _, retrieved_examples = dataset.get_nearest_examples("embedding", query=query, k=k)
    images = retrieved_examples["image"][:k]
    metadata = [
        f"last_modified {modified}, crawl date:{crawl}"
        for modified, crawl in zip(
            retrieved_examples["last_modified_date"],
            retrieved_examples["crawl_date"],
        )
    ]
    return list(zip(images, metadata))

def return_random_sample(k=27):
    """Return up to ``k`` randomly chosen dataset images, resized and in RGB mode.

    Args:
        k: Number of images requested. It is capped at the dataset size because
            ``random.sample`` raises ``ValueError`` when asked for more items
            than the population holds.

    Returns:
        A list of images, each passed through ``resize_image`` and converted
        with ``.convert("RGB")``.
    """
    total = len(dataset)
    indices = random.sample(range(total), min(k, total))
    images = dataset[indices]["image"]
    return [resize_image(image).convert("RGB") for image in images]

@lru_cache(maxsize=1)
def get_valid_hub_image_classification_model_ids():
    """Return the ids of all image-classification models listed on the Hub.

    The listing walks every matching model on the Hub, so the result is cached
    for the lifetime of the process; models published after the first call are
    not seen until restart. A failed listing raises and is not cached.

    Returns:
        A frozenset of model id strings (immutable because the cached object
        is shared between callers).
    """
    models = list_models(limit=None, filter=ModelFilter(task="image-classification"))
    # NOTE(review): the element expression of this comprehension was missing in
    # the source; `.id` is used as the model identifier — confirm against the
    # installed huggingface_hub version (older releases also expose `.modelId`).
    return frozenset(info.id for info in models)

def predict_subset(model_id, token):
    """Classify a random sample of dataset images with a Hub model.

    Args:
        model_id: Hub id of an image-classification model.
        token: Hugging Face access token, sent as a Bearer token with each request.

    Returns:
        A ``(gallery, df)`` tuple: ``gallery`` is a list of ``(url, caption)``
        pairs and ``df`` is a DataFrame with ``labels`` and ``freqs`` columns
        counting how often each predicted label occurred.

    Raises:
        gr.Error: If ``model_id`` is not a known image-classification model, or
            if the token is rejected.
    """
    from collections import Counter

    valid_model_ids = get_valid_hub_image_classification_model_ids()
    if model_id not in valid_model_ids:
        raise gr.Error(
            f"model_id {model_id} is not a valid image classification model id"
        )
    # NOTE(review): the statement guarded by this handler was missing from the
    # source; `login(token)` (imported at file level) is the reconstruction —
    # confirm. Be aware that `login` may persist the token on the host.
    try:
        login(token)
    except ValueError as err:
        raise gr.Error("Invalid Hub token") from err
    # NOTE(review): the endpoint prefix was missing from the source; this is the
    # public hosted Inference API URL — confirm it is the intended endpoint.
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {token}"}

    # Retry with exponential backoff while the API answers 503, giving up after
    # 30 seconds and returning the last response.
    @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
    def _query(url):
        # A timeout keeps one stalled request from hanging the whole batch.
        return requests.post(api_url, headers=headers, data=url, timeout=30)

    def query(url):
        # Best effort per image: any network, decoding or payload-shape problem
        # yields an empty prediction instead of aborting the batch.
        try:
            data = _query(url).json()
            top = data[0]
            return {"score": top["score"], "label": top["label"]}
        except Exception:
            return {"score": None, "label": None}

    # A fresh copy of the dataset is loaded here rather than reusing the
    # module-level one (presumably because that one carries a FAISS index —
    # verify before changing).
    source = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
    sample_size = min(10, len(source))
    sample = source.select(random.sample(range(len(source)), sample_size))

    predictions = [query(row["url"]) for row in sample]
    gallery = [
        (url, f"{prediction['label'], prediction['score']}")
        for url, prediction in zip(sample["url"], predictions)
    ]

    label_counts = Counter(d["label"] for d in predictions)
    df = pd.DataFrame(
        {
            "labels": list(label_counts),
            "freqs": list(label_counts.values()),
        }
    )
    return gallery, df

# Gradio UI: a Blocks app with four tabs ("Random Image Gallery", "Image Search",
# "Image Classification Model Tester", "Export to Label Studio format"), followed
# by the queue configuration for the app.
# NOTE(review): this block does not parse as written. Several string literals are
# unterminated, some lines end in argument lists with no call in front of them
# (e.g. `), [], [gallery])`), one call is never closed, and the Markdown links
# have no targets. Text appears to have been lost from the file; restore it from
# the original before editing. The code below is left exactly as found.
with gr.Blocks() as demo:
        """# ARCH Image Dataset Explorer 
    This [Gradio]( [Space]( allows you to explore an image dataset exported from [ARCH: Archive Research Compute Hub]( from the Internet Archive
    Each tab allows you to explore the dataset in a slightly different way by making use of Machine Learning models and tools from the Hugging Face ecosystem.
    **NOTE**: Images in the dataset are sourced from a collection generated from the web and may contain images that are Not Suitable for All.
    with gr.Tab("Random Image Gallery"):
            """## Random image gallery
        This tab allows you to explore images in your ARCH collection. You can refresh the images by clicking the refresh button.
        **Please note** not all images will be displayed as some images may not available via the original URLS anymore."""
        button = gr.Button("Refresh")
        gallery = gr.Gallery().style(grid=9, height="1400"), [], [gallery])
    with gr.Tab("Image Search"):
            """## Image search 
        You can search for images by entering a search term and clicking the search button.
        You can also change the number of images to be returned.
        This model uses the [clip-ViT-B-16]( model to embed your images and search term"""
        text = gr.Textbox(label="Search for images")
        k = gr.Slider(minimum=3, maximum=18, step=1)
        button = gr.Button("search")
        gallery = gr.Gallery().style(grid=3), [text, k], [gallery])
        # gr.Markdown(
        #     """### More info
        # ![]("""
        # )

    with gr.Tab("Image Classification Model Tester"):
            """## Image classification model tester
        You can use this to test out [image classification models]( on the Hugging Face Hub:
        - To use this tab you will need to have a Hugging Face account and a valid token. 
        - You can get a token from your [Hugging Face account page]( 
        - Input this token into the token box and then input a valid image classification model id from the Hub. For example `microsoft/resnet-50`. You can use the [Hub]( to find suitable models.
        This tab uses Hugging Face's [Inference API]( to make predictions. It will randomly select 10 images from your dataset and make predictions on them using your chosen model.
        **Please note** the predictions will take some time since the model needs to be loaded for inference first. If you make a second batch of prediction using the same model the predictions should be quicker."""
        token = gr.Textbox(label="token", type="password")

        model_id = gr.Textbox(
            label="model_id", value="davanstrien/autotrain-wikiart-sample2-42615108993"
        button = gr.Button("predict")
        gr.Markdown("## Results")
        plot = gr.BarPlot(x="labels", y="freqs", width=600, height=400, vertical=False)
        gallery = gr.Gallery(), [model_id, token], [gallery, plot])
    with gr.Tab("Export to Label Studio format"):
        ## Export to Label Studio format
        <img align=left src="">
        This will export the current dataset to a csv file which can be imported into [Label Studio]( You can then import this into Label Studio to label your images by hand.
        You can run Label Studio using Hugging Face Spaces using this [Spaces template]("""

        dataset2 = copy.deepcopy(dataset)
        dataset2 = dataset2.remove_columns("image")
        dataset2 = dataset2.rename_column("url", "image")
        csv = dataset2.to_csv("label_studio.csv")
        csv_file = gr.File("label_studio.csv"), [], [csv_file])
demo.queue(concurrency_count=8, max_size=5)