#!/usr/bin/env python3

import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

from update_readme import generate_url, get_all_files


class Wheel:
    """Metadata parsed from a k2 CPU wheel file name, plus its download URL."""

    # Matches e.g. k2-1.23.4.dev20230223+cpu.torch1.10.0-cp36-...
    # Every numeric field accepts multiple digits (k2 minor "23", torch
    # "1.13.10", "cp310"), and the k2 patch component is optional
    # (k2-1.2.dev...).
    _PATTERN = re.compile(
        r"k2-(?P<major>\d+)\.(?P<minor>\d+)(?:\.(?P<patch>\d+))?"
        r"\.dev(?P<date>\d{8})\+cpu\.torch(?P<torch>\d+\.\d+\.\d+)"
        r"-cp(?P<py>\d+)"
    )

    def __init__(self, full_name: str, url: str):
        """
        Args:
          full_name:
            Wheel file name.
            Example: k2-1.23.4.dev20230223+cpu.torch1.10.0-cp36-cp36m-linux_x86_64.whl
          url:
            URL of the wheel; also what ``str()`` and ``repr()`` return.

        Raises:
          ValueError: if ``full_name`` does not match the expected k2 CPU
            wheel naming scheme.
        """
        m = self._PATTERN.search(full_name)
        if m is None:
            raise ValueError(f"Unrecognized k2 CPU wheel name: {full_name}")

        self.full_name = full_name
        self.k2_major = int(m.group("major"))
        self.k2_minor = int(m.group("minor"))
        # A missing patch component sorts as 0, i.e. k2-1.2 == k2-1.2.0.
        self.k2_patch = int(m.group("patch") or 0)
        self.k2_date = int(m.group("date"))
        self.torch_version = m.group("torch")
        # "cp38" -> 38, "cp310" -> 310; as ints, 3.10 sorts after 3.9.
        self.py_version = int(m.group("py"))
        self.url = url

    def __str__(self):
        return self.url

    def __repr__(self):
        return self.url


def generate_index(filename: str, torch_versions) -> str:
    """Write the toctree index page that links every per-torch-version page.

    Args:
      filename:
        Path of the index file to (over)write.
      torch_versions:
        Iterable of torch version strings. Each one becomes a toctree entry
        ``./<version>.rst``, in the order given.

    Returns:
      The text written to ``filename``.
    """
    entries = "\n".join(f"   ./{version}.rst" for version in torch_versions)

    s = f"""\
Pre-compiled CPU wheels (Linux)
===============================

This page describes pre-compiled ``CPU`` wheels for `k2`_ on Linux.

.. toctree::
   :maxdepth: 2

{entries}
"""
    with open(filename, "w", encoding="utf-8") as f:
        f.write(s)
    return s


def sort_by_wheel(x: Wheel):
    """Sort key: k2 version (major, minor, patch), then build date, then
    the CPython tag number."""
    key = (
        x.k2_major,
        x.k2_minor,
        x.k2_patch,
        x.k2_date,
        x.py_version,
    )
    return key


def sort_by_torch(x):
    """Sort key turning a ``"major.minor.patch"`` torch version into ints.

    Integer tuples order ``"1.10.0"`` after ``"1.9.0"``, which plain string
    comparison would get wrong.
    """
    major, minor, patch = map(int, x.split("."))
    return major, minor, patch


def get_all_torch_versions(wheels: List[Wheel]) -> List[str]:
    """Return the distinct torch versions used by ``wheels``, newest first.

    Versions are compared numerically through ``sort_by_torch``.
    """
    distinct = {w.torch_version for w in wheels}
    return sorted(distinct, key=sort_by_torch, reverse=True)


def get_doc_dir():
    """Return the docs directory for the pre-compiled CPU (Linux) wheels.

    The k2 source tree is located through the ``K2_DIR`` environment
    variable; the returned directory is
    ``$K2_DIR/docs/source/installation/pre-compiled-cpu-wheels-linux``.

    Raises:
      ValueError: if ``K2_DIR`` is unset or empty, or if the directory
        does not exist.
    """
    k2_dir = os.getenv("K2_DIR")
    # An empty value would silently resolve relative to the CWD, so treat it
    # as unset. Environment variable names are case-sensitive: name it exactly.
    if not k2_dir:
        raise ValueError("Please set the environment variable K2_DIR")

    cpu_dir = Path(k2_dir) / "docs/source/installation/pre-compiled-cpu-wheels-linux"

    if not cpu_dir.is_dir():
        raise ValueError(f"{cpu_dir} does not exist or is not a directory")

    print(f"k2 doc cpu_dir: {cpu_dir}")
    return cpu_dir


def remove_all_files(d: str):
    """Delete every file that ``get_all_files`` reports for ``d`` with the
    ``*.rst`` pattern, printing each path before removing it."""
    for rst_path in get_all_files(d, "*.rst"):
        print(f"removing {rst_path}")
        os.remove(rst_path)


def get_all_cpu_wheels():
    """Collect the ``*.whl`` files under the local ``cpu`` directory and
    return what ``generate_url`` produces for them (``main`` treats each
    item as a URL string)."""
    return generate_url(get_all_files("cpu", suffix="*.whl"))


def generate_file(d: str, torch_version: str, wheels: List[Wheel]) -> str:
    """Write ``<d>/<torch_version>.rst`` listing the wheels built for it.

    Args:
      d:
        Output directory.
      torch_version:
        Only wheels whose ``torch_version`` equals this are listed.
      wheels:
        All known wheels; filtered and then sorted newest first by
        ``sort_by_wheel``.

    Returns:
      The text written to the file.
    """
    title = f"torch {torch_version}"
    selected = sorted(
        (w for w in wheels if w.torch_version == torch_version),
        key=sort_by_wheel,
        reverse=True,
    )

    # Title, RST underline, then two blank lines before the bullet list.
    s = f"{title}\n{'=' * len(title)}\n\n\n"
    s += "".join(f"- `{w.full_name} <{w.url}>`_\n" for w in selected)

    with open(Path(d) / f"{torch_version}.rst", "w", encoding="utf-8") as f:
        f.write(s)
    return s


def main():
    """Regenerate the CPU wheel doc pages.

    Steps:
      1. Resolve the docs directory from ``K2_DIR`` via ``get_doc_dir``.
      2. Remove the ``*.rst`` files there that ``get_all_files`` reports.
      3. Write one page per torch version.
      4. Write an ``index.rst`` toctree linking those pages.
    """
    d = get_doc_dir()
    remove_all_files(d)

    # The wheel file name is the last path component of its URL; [-1] also
    # copes with a URL that contains no "/".
    wheels = [
        Wheel(url.rsplit("/", maxsplit=1)[-1], url) for url in get_all_cpu_wheels()
    ]
    torch_versions = get_all_torch_versions(wheels)

    for t in torch_versions:
        generate_file(d, t, wheels)

    generate_index(f"{d}/index.rst", torch_versions)


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()