import json
from argparse import ArgumentParser

import datasets
import torch
import transformers
from transformers import AutoModelForCausalLM, BatchEncoding

"""
Usage examples (with the best batch sizes on A100-80GB-400W)
============================================================
python -m  benchmark_hf_model  --model_name_or_path="Deci/DeciLM-7B"  --batch_size=352
python -m  benchmark_hf_model  --model_name_or_path="mistralai/Mistral-7B-v0.1"  --batch_size=192   --model_kwargs_json='{"use_flash_attention_2": true}'
python -m  benchmark_hf_model  --model_name_or_path="meta-llama/Llama-2-7b-hf"  --batch_size=48     --model_kwargs_json='{"use_flash_attention_2": true}'
"""


def parse_args():
    """Define the benchmark's command-line options and parse ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model_name_or_path``,
        ``warmup_iters``, ``iterations``, ``batch_size``, ``prompt_length``,
        ``max_new_tokens``, ``precision`` and ``model_kwargs_json``.
    """
    cli = ArgumentParser()
    cli.add_argument("--model_name_or_path", type=str, required=True)

    # Integer knobs (flag, default) controlling run length and input/output sizes,
    # registered in the same order as they appear in --help.
    integer_options = (
        ("--warmup_iters", 10),
        ("--iterations", 5),
        ("--batch_size", 32),
        ("--prompt_length", 512),
        ("--max_new_tokens", 512),
    )
    for flag, default_value in integer_options:
        cli.add_argument(flag, type=int, default=default_value)

    cli.add_argument(
        "--precision",
        type=str,
        default="bf16",
        help="Model precision, from: fp32, fp16 or bf16",
    )
    # Optional JSON text; main() decodes it and forwards it to from_pretrained.
    cli.add_argument("--model_kwargs_json", type=str, default=None)

    return cli.parse_args()


def _generate(model, inputs, max_new_tokens):
    """Run one greedy ``model.generate`` call on ``inputs`` with autograd disabled.

    ``eos_token_id`` is set to a negative id that no real token can match, so
    generation is not cut short by an end-of-sequence token; the throughput
    figures in ``main`` assume every sequence produces ``max_new_tokens`` tokens.
    """
    with torch.no_grad():
        return model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=-1234,
        )


def main():
    """Benchmark end-to-end ``generate`` latency and throughput of a causal LM.

    Reads the CLI options via ``parse_args`` and loads the model in the requested
    precision. It then moves the model to CUDA and runs ``--warmup_iters`` untimed
    warmup calls, followed by ``--iterations`` timed calls. Timing uses CUDA
    events. For every timed call, and for their average, it prints wall time and
    generated tokens per second.

    Raises:
        ValueError: if ``--precision`` is not one of fp32/fp16/bf16, if
            ``--iterations`` is smaller than 1, or if ``--model_kwargs_json``
            does not decode to a JSON object.
    """
    args = parse_args()
    transformers.logging.set_verbosity_error()
    datasets.logging.set_verbosity_error()

    dict_precisions = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if args.precision not in dict_precisions:
        raise ValueError(
            f"Non valid precision {args.precision}, choose from: fp16, fp32, bf16"
        )
    dtype = dict_precisions[args.precision]

    # Fail fast, before the slow model load: the average at the end divides by
    # the number of timed iterations.
    if args.iterations < 1:
        raise ValueError(f"--iterations must be at least 1, got {args.iterations}")

    model_kwargs = {}
    if args.model_kwargs_json is not None:
        model_kwargs = json.loads(args.model_kwargs_json)
        # The decoded value is splatted into from_pretrained(**model_kwargs),
        # which only works for a JSON object (dict).
        if not isinstance(model_kwargs, dict):
            raise ValueError(
                f"--model_kwargs_json must be a JSON object, got {args.model_kwargs_json!r}"
            )

    print("loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=dtype,
        **model_kwargs,
    )
    # Best-effort diagnostic: print the first layer's attention module when the
    # model exposes model.model.layers[0].self_attn. Catch Exception (not a bare
    # except) so Ctrl-C / SystemExit are not swallowed.
    try:
        print(model.model.layers[0].self_attn)
    except Exception:
        print("couldn't print the model's attention module")

    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    model.cuda()
    model.eval()

    # Synthetic prompt: token id 1 repeated prompt_length times, copied to every
    # row of a (batch_size, prompt_length) input_ids tensor.
    prompt = torch.ones(args.prompt_length, dtype=torch.long)
    inputs = BatchEncoding({"input_ids": prompt.repeat(args.batch_size, 1)})
    inputs = inputs.to(model.device)

    # Warmup: untimed calls that each generate a single new token.
    # NOTE(review): with max_new_tokens=1 the warmup mostly exercises the prefill
    # path; incremental decoding is first hit inside the timed loop — confirm
    # this is intended.
    print(f"warming up for {args.warmup_iters} iterations...")
    for _ in range(args.warmup_iters):
        _generate(model, inputs, max_new_tokens=1)
    print('finished warmup')
    torch.cuda.synchronize()

    batch_suffix = f' x {args.batch_size} batch' if args.batch_size > 1 else ''
    print(
        f"prefill ({args.prompt_length} tokens{batch_suffix}) + generation ({args.max_new_tokens} tokens{batch_suffix}):")
    # Assumes each sequence yields exactly max_new_tokens tokens (see _generate).
    tokens_generated = args.max_new_tokens * args.batch_size
    prefill_and_generation = []
    for gen_iter in range(args.iterations):
        starter.record()
        _generate(model, inputs, max_new_tokens=args.max_new_tokens)
        ender.record()
        # elapsed_time is only valid once both events have completed on the GPU.
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender) / 1000  # elapsed_time is in milliseconds
        prefill_and_generation.append(t)
        print(f"    iter {gen_iter + 1}:  {t:.03f} sec total, {tokens_generated / t:.02f} generated tokens/sec")
    aver = sum(prefill_and_generation) / len(prefill_and_generation)
    print(f"    average: {aver:.03f} sec total, {tokens_generated / aver:.02f} generated tokens/sec")
    print(f"These results are obtained for model '{args.model_name_or_path}' with {args.batch_size=}.")


if __name__ == "__main__":
    # Run the benchmark only when executed as a script or via `python -m`
    # (see the usage examples near the top of the file), not when imported.
    main()