import json
import os
from tempfile import NamedTemporaryFile, mktemp

import argilla as rg
from huggingface_hub import HfApi

from defaults import REMOTE_CODE_PATHS, SEED_DATA_PATH


hf_api = HfApi()

# Base text that create_readme() starts every generated dataset README from.
# Read explicitly as UTF-8 so the result does not depend on the platform locale.
with open("DATASET_README_BASE.md", encoding="utf-8") as f:
    DATASET_README_BASE = f.read()


def create_readme(domain_seed_data, project_name, domain):
    """Render the project's dataset README and write it to a temporary file.

    The README is ``DATASET_README_BASE`` followed by a heading with the project
    name and domain, then optional "Perspectives", "Topics" and "Examples"
    sections built from ``domain_seed_data``. A section is omitted when its key
    is missing or its value is falsy (e.g. an empty list).

    Args:
        domain_seed_data: Mapping that may hold "perspectives" and "topics"
            (iterables of items rendered as bullet points) and "examples"
            (iterable of mappings with "question" and "answer" keys).
        project_name: Project name used as the top-level heading.
        domain: Domain name shown under the heading.

    Returns:
        Path of a UTF-8 encoded ``.md`` temporary file holding the README. The
        file is not deleted automatically; the caller is responsible for it.
    """
    parts = [DATASET_README_BASE, f"# {project_name}\n\n## Domain: {domain}"]

    perspectives = domain_seed_data.get("perspectives")
    if perspectives:
        parts.append("\n\n## Perspectives\n\n")
        parts.extend(f"- {p}\n" for p in perspectives)

    topics = domain_seed_data.get("topics")
    if topics:
        parts.append("\n\n## Topics\n\n")
        parts.extend(f"- {t}\n" for t in topics)

    examples = domain_seed_data.get("examples")
    if examples:
        parts.append("\n\n## Examples\n\n")
        parts.extend(
            f"### {example['question']}\n\n{example['answer']}\n\n"
            for example in examples
        )

    # NamedTemporaryFile creates the file atomically under a unique name
    # (mktemp() only returns a name and is open to races). delete=False keeps
    # the file on disk after closing so the caller can upload it by path.
    with NamedTemporaryFile(
        mode="w", suffix=".md", delete=False, encoding="utf-8"
    ) as f:
        f.write("".join(parts))
        return f.name


def setup_dataset_on_hub(repo_id, hub_token):
    """Make sure a dataset repository called ``repo_id`` exists on the Hub.

    Passing ``exist_ok=True`` means an already existing repo is not an error.
    """
    repo_options = {"repo_type": "dataset", "exist_ok": True}
    hf_api.create_repo(repo_id=repo_id, token=hub_token, **repo_options)


def push_dataset_to_hub(
    domain_seed_data_path,
    project_name,
    domain,
    pipeline_path,
    hub_username,
    hub_token: str,
):
    """Create the project's dataset repo and upload its seed data and README.

    Args:
        domain_seed_data_path: Local path of the seed-data JSON file. It is
            uploaded as ``seed_data.json`` and also parsed to build the README.
        project_name: Repository name; the repo id is
            ``{hub_username}/{project_name}``.
        domain: Domain name shown in the README.
        pipeline_path: Not used by this function; kept for call compatibility.
        hub_username: Hub user or organization owning the repository.
        hub_token: Hub access token used for every API call.
    """
    repo_id = f"{hub_username}/{project_name}"

    # Parse the seed data before touching the Hub so a malformed file fails
    # early instead of leaving a partially populated repo behind.
    with open(domain_seed_data_path, encoding="utf-8") as f:
        domain_seed_data = json.load(f)

    setup_dataset_on_hub(repo_id=repo_id, hub_token=hub_token)

    # upload the raw seed data file
    hf_api.upload_file(
        path_or_fileobj=domain_seed_data_path,
        path_in_repo="seed_data.json",
        token=hub_token,
        repo_id=repo_id,
        repo_type="dataset",
    )

    # create_readme writes to a temporary file it does not clean up, so remove
    # it once the upload has finished (or failed).
    readme_path = create_readme(
        domain_seed_data=domain_seed_data, project_name=project_name, domain=domain
    )
    try:
        hf_api.upload_file(
            path_or_fileobj=readme_path,
            path_in_repo="README.md",
            token=hub_token,
            repo_id=repo_id,
            repo_type="dataset",
        )
    finally:
        os.remove(readme_path)


def push_pipeline_to_hub(
    pipeline_path,
    hub_username,
    hub_token: str,
    project_name,
):
    """Upload the pipeline script and its supporting code files to the project repo.

    The file at ``pipeline_path`` is stored as ``pipeline.py``; every entry of
    ``REMOTE_CODE_PATHS`` is uploaded under the same path string it has locally.
    """
    repo_id = f"{hub_username}/{project_name}"

    # (local file, destination inside the repo), uploaded in this order.
    uploads = [(pipeline_path, "pipeline.py")]
    uploads.extend((code_path, code_path) for code_path in REMOTE_CODE_PATHS)

    for local_path, repo_path in uploads:
        hf_api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_path,
            token=hub_token,
            repo_id=repo_id,
            repo_type="dataset",
        )

    print(f"Dataset uploaded to {repo_id}")


def pull_seed_data_from_repo(repo_id, hub_token):
    """Download the seed-data file from a Hub dataset repo and return it parsed.

    Args:
        repo_id: Id of the dataset repository on the Hub.
        hub_token: Hub access token used for the download.

    Returns:
        The JSON content of the ``SEED_DATA_PATH`` file in the repository.
    """
    # hf_hub_download returns the local path of the downloaded file (by default
    # inside the Hugging Face cache, not the working directory), so read from
    # that path instead of re-opening SEED_DATA_PATH locally.
    local_path = hf_api.hf_hub_download(
        repo_id=repo_id, token=hub_token, repo_type="dataset", filename=SEED_DATA_PATH
    )
    with open(local_path, encoding="utf-8") as f:
        return json.load(f)


def push_argilla_dataset_to_hub(
    name: str,
    repo_id: str,
    url: str,
    api_key: str,
    hub_token: str,
    workspace: str = "admin",
):
    """Copy an Argilla feedback dataset to a Hugging Face Hub repository.

    Connects to the Argilla server at ``url``, loads dataset ``name`` from
    ``workspace``, pulls it into a local dataset and pushes that to ``repo_id``
    authenticated with ``hub_token``.
    """
    rg.init(api_url=url, api_key=api_key)
    (
        rg.FeedbackDataset.from_argilla(name=name, workspace=workspace)
        .pull()
        .push_to_huggingface(repo_id=repo_id, token=hub_token)
    )


def push_pipeline_params(
    pipeline_params,
    hub_username,
    hub_token: str,
    project_name,
):
    """Serialize pipeline parameters as JSON and upload them to the project repo.

    Args:
        pipeline_params: JSON-serializable object; stored as
            ``pipeline_params.json``.
        hub_username: Hub user or organization owning the repository.
        hub_token: Hub access token used for the upload.
        project_name: Repository name; the repo id is
            ``{hub_username}/{project_name}``.
    """
    repo_id = f"{hub_username}/{project_name}"

    # upload_file accepts raw bytes, so the JSON is sent straight from memory
    # rather than through a temporary file that would be left on disk.
    payload = json.dumps(pipeline_params).encode("utf-8")

    # upload the pipeline parameters to the hub
    hf_api.upload_file(
        path_or_fileobj=payload,
        path_in_repo="pipeline_params.json",
        token=hub_token,
        repo_id=repo_id,
        repo_type="dataset",
    )

    print(f"Pipeline params uploaded to {repo_id}")