import gradio as gr

from data import download_dataset, tokenize_dataset, load_tokenized_dataset
from infer import get_model_and_tokenizer, batch_embed

# TODO: add instructor models
# "hkunlp/instructor-xl",
# "hkunlp/instructor-large",
# "hkunlp/instructor-base",

# Supported embedding checkpoints as (Hub model id, hidden size) pairs.
models_and_hidden_sizes = [
    # intfloat E5 v2 family
    ("intfloat/e5-small-v2", 384),
    ("intfloat/e5-base-v2", 768),
    ("intfloat/e5-large-v2", 1024),
    # intfloat multilingual E5 family
    ("intfloat/multilingual-e5-small", 384),
    ("intfloat/multilingual-e5-base", 768),
    ("intfloat/multilingual-e5-large", 1024),
    # sentence-transformers MiniLM family
    ("sentence-transformers/all-MiniLM-L6-v2", 384),
    ("sentence-transformers/all-MiniLM-L12-v2", 384),
    ("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", 384),
]

# Dropdown labels. The handlers below recover the model id by taking the
# first whitespace-separated token of the chosen label, so the id must stay
# at the front of the label.
model_options = [
    f"{model_id} (hidden_size = {width})"
    for model_id, width in models_and_hidden_sizes
]


# Optimization level -> human-readable description shown in the dropdown.
opt2desc = dict(
    O2="Most precise, slowest (O2: basic and extended general optimizations, transformers-specific fusions)",
    O3="Less precise, faster (O3: O2 + gelu approx)",
    O4="Least precise, fastest (O4: O3 + fp16/bf16)",
)

# Reverse lookup: description picked in the UI -> optimization level.
desc2opt = dict(zip(opt2desc.values(), opt2desc.keys()))


# Dropdown choices, in the same order as the levels are declared above.
optimization_options = [*opt2desc.values()]


def download_and_tokenize(
    ds_name,
    ds_config,
    column_name,
    ds_split,
    model_choice,
    opt_desc,
    num2skip,
    num2embed,
    progress=gr.Progress(track_tqdm=True),
):
    """Download a dataset, tokenize the selected rows, and report its size.

    Used as the click handler for the "Download and tokenize dataset!" button.

    Args:
        ds_name: Dataset identifier forwarded to ``download_dataset``.
        ds_config: Dataset configuration name forwarded to the data helpers.
        column_name: Name of the column whose text is tokenized.
        ds_split: Dataset split forwarded to ``download_dataset``.
        model_choice: Dropdown label of the form
            ``"<model id> (hidden_size = N)"``; only the model id is used.
        opt_desc: Optimization description; must be a key of ``desc2opt``.
        num2skip: Number of rows to skip, forwarded to ``tokenize_dataset``.
        num2embed: Number of rows to process, forwarded to
            ``tokenize_dataset``.
        progress: Not referenced in the body; presumably injected by Gradio
            so tqdm progress is mirrored in the UI -- confirm against the
            Gradio ``Progress`` documentation.

    Returns:
        A status message containing the number of documents reported by
        ``download_dataset``.

    Raises:
        KeyError: If ``opt_desc`` is not one of the known descriptions.
    """
    downloaded = download_dataset(ds_name, ds_config, ds_split)

    opt_level = desc2opt[opt_desc]

    # The label starts with the model id, followed by a space and the
    # "(hidden_size = N)" suffix, so the first token is the id itself.
    model_name = model_choice.split()[0]

    tokenize_dataset(
        ds_name=ds_name,
        ds_config=ds_config,
        model_name=model_name,
        opt_level=opt_level,
        column_name=column_name,
        num2skip=num2skip,
        num2embed=num2embed,
    )

    # The helper's result is reported as a document count. Accept either a
    # sized object (its length is the count) or a plain integer count, so a
    # helper that already returns the count does not raise TypeError after
    # all of the work above has completed.
    try:
        doc_count = len(downloaded)
    except TypeError:
        doc_count = downloaded

    return f"Downloaded! It has {doc_count} docs."


def embed(
    ds_name,
    ds_config,
    column_name,
    ds_split,
    model_choice,
    opt_desc,
    new_dataset_id,
    num2skip,
    num2embed,
    progress=gr.Progress(track_tqdm=True),
):
    """Embed a previously tokenized dataset and summarize the throughput.

    Used as the click handler for the "Embed texts!" button.

    Args:
        ds_name: Dataset identifier forwarded to ``load_tokenized_dataset``.
        ds_config: Dataset configuration forwarded to
            ``load_tokenized_dataset``.
        column_name: Column name forwarded to ``batch_embed``.
        ds_split: Dataset split forwarded to ``load_tokenized_dataset``.
        model_choice: Dropdown label of the form
            ``"<model id> (hidden_size = N)"``; only the model id is used.
        opt_desc: Optimization description; must be a key of ``desc2opt``.
        new_dataset_id: Identifier of the output dataset, forwarded to
            ``batch_embed``.
        num2skip: Number of rows to skip, forwarded to ``batch_embed``.
        num2embed: Number of rows to embed, forwarded to ``batch_embed``.
        progress: Progress callback; when not ``None`` it is called once
            here and then handed to the model loader and to ``batch_embed``.

    Returns:
        A status message with the document count, elapsed minutes and
        documents per second.

    Raises:
        KeyError: If ``opt_desc`` is not one of the known descriptions.
    """
    tokenized_ds = load_tokenized_dataset(ds_name, ds_config, ds_split)

    opt_level = desc2opt[opt_desc]

    # The model id is the first whitespace-separated token of the label.
    model_name = model_choice.split()[0]

    if progress is not None:
        progress(0.2, "Downloading model and tokenizer...")
    model, tokenizer = get_model_and_tokenizer(model_name, opt_level, progress)

    # Keyword options for the embedding run, gathered in one place.
    embed_options = dict(
        model_name=model_name,
        column_name=column_name,
        new_dataset_id=new_dataset_id,
        opt_level=opt_level,
        num2skip=num2skip,
        num2embed=num2embed,
        progress=progress,
    )
    doc_count, seconds_taken = batch_embed(
        tokenized_ds, model, tokenizer, **embed_options
    )

    minutes_taken = seconds_taken / 60
    docs_per_second = doc_count / seconds_taken
    return f"Embedded {doc_count} docs in {minutes_taken:.2f} minutes ({docs_per_second:.1f} docs/sec)"


# Build the UI: an introductory Markdown panel, three rows of input widgets,
# two action buttons and one shared status textbox. Components are created in
# the order they should appear on the page.
with gr.Blocks(title="Bulk embeddings") as demo:
    # NOTE(review): the duration and cost figures in the intro text below are
    # still "__" placeholders.
    gr.Markdown(
        """
        # Bulk Embeddings

        
        This Space allows you to embed a large dataset easily. For instance, this can easily create vectors for Wikipedia \
        articles -- taking about __ hours and costing approximately $__. 
        This utilizes state-of-the-art open-source embedding models, \
        and optimizes them for inference using Hugging Face [optimum](https://github.com/huggingface/optimum). There are various \
        levels of optimizations that can be applied - the quality of the embeddings will degrade as the optimizations increase.  
        Currently available options: O2/O3/O4 on T4/A10 GPUs using onnx runtime.  
        Future options: 
          - OpenVino for CPU inference
          - TensorRT for GPU inference
          - Quantized models
          - Instructor models
          - Text splitting options
          - More control about which rows to embed (skip some, stop early)
          - Dynamic padding

        ## Steps
        1. Upload the dataset to the Hugging Face Hub.
        2. Enter dataset details into the form below.
        3. Choose a model. These are taken from the top of the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
        4. Enter optimization level. See [here](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization#optimization-configuration) for details.
        5. Choose a name for the new dataset.
        6. Hit run!

        ### Note:
        If you have short documents, O3 will be faster than O4. If you have long documents, O4 will be faster than O3. \
            O4 requires the tokenized documents to be padded to max length.
        """
    )

    # Row 1: which dataset, config, text column and split to read.
    with gr.Row():
        ds_name = gr.Textbox(
            lines=1,
            label="Dataset to load from Hugging Face Hub",
            value="wikipedia",
        )
        # NOTE(review): the label invites a blank value; presumably that
        # reaches the handlers as an empty string -- confirm the data helpers
        # treat it as "use the default config".
        ds_config = gr.Textbox(
            lines=1,
            label="Dataset config (leave blank to use default)",
            value="20220301.en",
        )

        column_name = gr.Textbox(lines=1, label="Enter column to embed", value="text")
        ds_split = gr.Dropdown(
            choices=["train", "validation", "test"],
            label="Dataset split",
            value="train",
        )
        # TODO: idx column
        # TODO: text splitting options

    # Row 2: embedding model and optimization level. Both dropdowns default
    # to the first entry of their option lists.
    with gr.Row():
        model_choice = gr.Dropdown(
            choices=model_options, label="Embedding model", value=model_options[0]
        )
        opt_desc = gr.Dropdown(
            choices=optimization_options,
            label="Optimization level",
            value=optimization_options[0],
        )

    # Row 3: output dataset id and row-range controls.
    with gr.Row():
        # NOTE(review): the label asks for a name "including username", but
        # the default value contains no username part.
        new_dataset_id = gr.Textbox(
            lines=1,
            label="New dataset name, including username",
            value="wiki-embeds",
        )

        num2skip = gr.Slider(
            value=0,
            minimum=0,
            maximum=100_000_000,
            step=1,
            label="Number of rows to skip",
        )

        # The minimum of -1 is the "embed everything" sentinel named in the label.
        num2embed = gr.Slider(
            value=30000,
            minimum=-1,
            maximum=100_000_000,
            step=1,
            label="Number of rows to embed (-1 = all)",
        )

        # NOTE(review): num2upload is not listed in the inputs of either
        # click handler below, so its value is currently unused in this file.
        num2upload = gr.Slider(
            value=10000,
            minimum=1000,
            maximum=100000,
            step=1000,
            label="Chunk size for uploading",
        )

    # Row 4: action buttons plus the textbox both handlers write their
    # status message to.
    with gr.Row():
        download_btn = gr.Button(value="Download and tokenize dataset!")
        embed_btn = gr.Button(value="Embed texts!")

        last = gr.Textbox(value="")

    # The inputs are listed in the same order as the positional parameters of
    # download_and_tokenize (its trailing `progress` parameter is not listed).
    download_btn.click(
        fn=download_and_tokenize,
        inputs=[
            ds_name,
            ds_config,
            column_name,
            ds_split,
            model_choice,
            opt_desc,
            num2skip,
            num2embed,
        ],
        outputs=last,
    )

    # Same ordering rule for embed, which additionally takes new_dataset_id
    # between opt_desc and num2skip.
    embed_btn.click(
        fn=embed,
        inputs=[
            ds_name,
            ds_config,
            column_name,
            ds_split,
            model_choice,
            opt_desc,
            new_dataset_id,
            num2skip,
            num2embed,
        ],
        outputs=last,
    )


if __name__ == "__main__":
    # Configure the queue first (concurrency_count=20), then launch with
    # show_error and debug switched on.
    queued_demo = demo.queue(concurrency_count=20)
    queued_demo.launch(show_error=True, debug=True)