|
import json |
|
import math |
|
import os |
|
import time |
|
from argparse import ArgumentParser |
|
from collections import defaultdict |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
|
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
|
|
# Process-wide environment settings, applied at import time so they are in
# place before any tokenizer is used or any CUDA memory is allocated.
#   TOKENIZERS_PARALLELISM="0"  -> turn off the Hugging Face tokenizers
#                                  internal parallelism.
#   PYTORCH_CUDA_ALLOC_CONF     -> let PyTorch's CUDA caching allocator use
#                                  expandable segments.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "0",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)
|
|
|
|
|
class TorchTracemalloc:
    """Context manager measuring peak CUDA memory growth over its body.

    On entry it records the currently allocated CUDA memory and resets
    PyTorch's peak-memory counters. On exit it computes
    ``(peak allocated - allocated at entry)`` in whole MiB (floor division),
    stores it on the instance as ``peaked``, and appends it to the
    class-level ``track_memory_consumption`` list. Exceptions raised inside
    the ``with`` body are not suppressed.
    """

    # Shared by every instance: one MiB value is appended per completed
    # ``with`` block. Read the last element (or ``instance.peaked``) for a
    # single measurement; averaging the list mixes unrelated runs.
    track_memory_consumption = []

    def __enter__(self):
        self.begin = torch.cuda.memory_allocated()
        # reset_max_memory_allocated() is deprecated and merely forwards to
        # reset_peak_memory_stats() (with a FutureWarning); call it directly.
        torch.cuda.reset_peak_memory_stats()
        self.peak = None
        self.peaked = None
        return self

    def __exit__(self, *exc):
        self.peak = torch.cuda.max_memory_allocated()
        self.peaked = (self.peak - self.begin) // 1024**2
        TorchTracemalloc.track_memory_consumption.append(self.peaked)
        # Falsy return value: let any exception from the body propagate.
        return False
|
|
|
|
|
def save_bar_chart(title, x, y, ylabel, xlabel, output_path):
    """Draw ``y`` as bars over the category labels ``x`` and save the image.

    Any error is printed instead of raised, and the current pyplot figure is
    closed in every case.

    Args:
        title: Chart title; also used in the error message.
        x: Tick labels, one per bar.
        y: Bar heights, aligned with ``x``.
        ylabel: Label for the y axis.
        xlabel: Label for the x axis.
        output_path: Destination path handed to ``plt.savefig``.
    """
    try:
        plt.style.use("ggplot")
        bar_positions = np.arange(len(x))
        plt.figure(figsize=(10, 6))
        plt.bar(bar_positions, height=y, width=0.4, color="skyblue")
        plt.xticks(bar_positions, x)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(output_path)
    except Exception as exc:
        print(f"Error saving chart {title}: {exc!s}")
    finally:
        plt.close()
|
|
|
|
|
def format_response(dialog, response):
    """Return ``dialog`` extended with ``response`` as an assistant message.

    The input list is left untouched: a shallow copy is made, so the message
    dicts themselves are shared with the original list.

    Args:
        dialog: List of ``{"role": ..., "content": ...}`` message dicts.
        response: Text to append as the assistant's reply.

    Returns:
        A new list ending in ``{"role": "assistant", "content": response}``.
    """
    assistant_turn = {"role": "assistant", "content": response}
    return dialog.copy() + [assistant_turn]
|
|
|
|
|
# Script entry: parse CLI options, load a Llama checkpoint configured for a
# quantized KV cache, run batched sampling over the dialogs in --prompts_path
# for each requested batch size, and report peak GPU memory, generated
# tokens/s and time-to-first-token (printed, charted and saved as JSON).
import sys
import traceback

from transformers import LogitsProcessor, LogitsProcessorList


class _FirstTokenTimer(LogitsProcessor):
    """Record the wall-clock time of the first decoding step of ``generate``.

    ``generate`` runs its logits processors after each forward pass, so the
    first call happens once the first new token's logits exist. Scores are
    returned unchanged.
    """

    def __init__(self):
        self.first_token_time = None

    def __call__(self, input_ids, scores):
        if self.first_token_time is None:
            # Wait for the queued prefill kernels so the timestamp is not
            # taken while the GPU is still running asynchronously.
            torch.cuda.synchronize()
            self.first_token_time = time.perf_counter()
        return scores


parser = ArgumentParser("chat_with_llama")

parser.add_argument(
    "--llama", type=str, default="3-instruct", choices=["2", "3-instruct"]
)
parser.add_argument("--prompts_path", type=str, default="chats_sys_none.json")
# Only echoed in the final summary; the checkpoint is selected by --llama.
parser.add_argument("--model_size", type=int, default=8, choices=[7, 8, 13])
parser.add_argument("--num_new_tokens", type=int, default=512)
parser.add_argument(
    "--temperature", type=float, default=0.4, help="Temperature for sampling"
)
parser.add_argument("--window_length", type=int, default=32)
parser.add_argument("--kv_bits", type=int, default=1)
parser.add_argument("--output_path", type=str, default="./output")
parser.add_argument(
    "--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"]
)
parser.add_argument(
    "--batch_sizes",
    type=int,
    nargs="+",
    default=[30],
    help="Batch sizes to benchmark (default: 30)",
)
args = parser.parse_args()
bits = args.kv_bits

try:
    # argparse yields strings: comparing against the int 2 was always False
    # and silently loaded Llama 3 even for --llama 2.
    if args.llama == "2":
        model_name = "NousResearch/Llama-2-7b-hf"
        codebook_prefix = "llama-2-7b"
    else:
        model_name = "NousResearch/Meta-Llama-3-8B-Instruct"
        codebook_prefix = "llama-3-8b"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Register a pad token so prompts can be padded into batches.
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    config = AutoConfig.from_pretrained(model_name)
    # Extra config fields; presumably read by a custom KV-quantizing model
    # implementation that is not visible here — TODO confirm.
    config.quantizer_path = f"codebooks/{codebook_prefix}_{bits}bit.xmad"
    config.window_length = args.window_length

    dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[args.dtype]

    model = AutoModelForCausalLM.from_pretrained(
        model_name, config=config, torch_dtype=dtype, device_map="auto"
    )

    # The added <PAD> token can fall outside the checkpoint's embedding table.
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))

    # Left padding keeps every prompt flush against the first generated
    # position, so output_tokens[:, prompt_len:] is exactly the new tokens.
    tokenizer.padding_side = "left"
    model.config.pad_token_id = tokenizer.pad_token_id

    with open(args.prompts_path, "r", encoding="utf-8") as file:
        dialogs = json.load(file)

    num_dialogs = len(dialogs)
    print(f"Loaded {num_dialogs} dialogues...")

    batch_inputs = [
        tokenizer.apply_chat_template(
            dialog, tokenize=False, add_generation_prompt=True
        )
        for dialog in dialogs
    ]

    # Stop on EOS and, if the vocabulary actually contains it, on the
    # "<|eot_id|>" token. Unknown tokens come back as unk_token_id (or None)
    # and must not become terminators.
    terminators = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != tokenizer.unk_token_id:
        terminators.append(eot_id)

    batch_sizes = args.batch_sizes

    memory_avg = []
    tokens_per_sec_avg = []
    time_to_first_token_avg = []
    responses_by_batch_size = defaultdict(list)

    os.makedirs(args.output_path, exist_ok=True)

    for batch_size in batch_sizes:
        print(f"\nProcessing with batch size: {batch_size}")

        total_time = 0.0
        total_tokens = 0
        batch_ttfts = []
        num_batches = math.ceil(num_dialogs / batch_size)

        with TorchTracemalloc():
            for i in range(0, num_dialogs, batch_size):
                batch = batch_inputs[i : i + batch_size]

                try:
                    encoded_inputs = tokenizer(
                        batch,
                        padding=True,
                        truncation=False,
                        # Text from apply_chat_template(tokenize=False) already
                        # contains the template's special tokens (e.g. BOS);
                        # adding them again would duplicate BOS.
                        add_special_tokens=False,
                        return_tensors="pt",
                    )
                    input_ids = encoded_inputs["input_ids"].to(model.device)
                    attention_mask = encoded_inputs["attention_mask"].to(
                        model.device
                    )
                    prompt_len = input_ids.shape[1]
                    ttft_timer = _FirstTokenTimer()

                    torch.cuda.synchronize()
                    start_time = time.perf_counter()

                    with torch.no_grad():
                        output_tokens = model.generate(
                            input_ids,
                            attention_mask=attention_mask,
                            max_new_tokens=args.num_new_tokens,
                            num_return_sequences=1,
                            do_sample=True,
                            temperature=args.temperature,
                            pad_token_id=tokenizer.pad_token_id,
                            eos_token_id=terminators,
                            logits_processor=LogitsProcessorList([ttft_timer]),
                        )

                    torch.cuda.synchronize()
                    end_time = time.perf_counter()

                    total_time += end_time - start_time
                    if ttft_timer.first_token_time is not None:
                        batch_ttfts.append(ttft_timer.first_token_time - start_time)

                    # Count only generated tokens: drop the prompt columns and
                    # the pad fill written after a sequence hits a terminator.
                    new_tokens = output_tokens[:, prompt_len:]
                    total_tokens += int(
                        (new_tokens != tokenizer.pad_token_id).sum().item()
                    )

                    # Decode only the continuation so the saved assistant turn
                    # does not repeat the prompt text.
                    decoded_outputs = tokenizer.batch_decode(
                        new_tokens, skip_special_tokens=True
                    )
                    for j, response in enumerate(decoded_outputs):
                        responses_by_batch_size[batch_size].append(
                            format_response(dialogs[i + j], response)
                        )

                except Exception as e:
                    print(
                        f"Error processing batch {i // batch_size + 1}"
                        f"/{num_batches}: {str(e)}"
                    )
                finally:
                    # Release cached blocks even after a failed (e.g. OOM) batch.
                    torch.cuda.empty_cache()

        # TorchTracemalloc appends one value per `with` block; use this run's
        # peak instead of averaging across all batch sizes measured so far.
        peak_memory = TorchTracemalloc.track_memory_consumption[-1]
        memory_avg.append(peak_memory)

        tokens_per_sec = total_tokens / total_time if total_time > 0 else 0
        tokens_per_sec_avg.append(tokens_per_sec)

        # Mean over batches of (first-token timestamp - generate() start).
        time_to_first_token = float(np.mean(batch_ttfts)) if batch_ttfts else 0
        time_to_first_token_avg.append(time_to_first_token)

        print(f"Actual Batch Size Used: {batch_size}")
        print(f"GPU Memory Consumption (MiB): {peak_memory:.2f} MiB")
        print(f"Tokens per Second: {tokens_per_sec:.2f}")
        print(f"TTFT (seconds): {time_to_first_token:.4f} seconds")

    for batch_size, responses in responses_by_batch_size.items():
        output_file = os.path.join(
            args.output_path, f"batch_{batch_size}_responses.json"
        )
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(responses, f, indent=2)

    save_bar_chart(
        title="GPU Memory Consumption as a Function of Batch Size",
        x=batch_sizes,
        y=memory_avg,
        xlabel="Batch Size",
        ylabel="GPU Memory Consumption (MiB)",
        output_path=os.path.join(args.output_path, "memory_usage.png"),
    )

    save_bar_chart(
        title="Number of Tokens per Second as a Function of Batch Size",
        x=batch_sizes,
        y=tokens_per_sec_avg,
        xlabel="Batch Size",
        ylabel="Tokens per Second",
        output_path=os.path.join(args.output_path, "tokens_per_second.png"),
    )

    save_bar_chart(
        title="Time to First Token (TTFT) as a Function of Batch Size",
        x=batch_sizes,
        y=time_to_first_token_avg,
        xlabel="Batch Size",
        ylabel="TTFT (seconds)",
        output_path=os.path.join(args.output_path, "time_to_first_token.png"),
    )

    print("\nBenchmarking Results:")
    print(f"Batch Sizes: {batch_sizes}")
    print(f"GPU Memory Consumption (MiB): {memory_avg}")
    print(f"Tokens per Second: {tokens_per_sec_avg}")
    print(f"Time to First Token (seconds): {time_to_first_token_avg}")

    print(
        f"\nModel size: {args.model_size}, Max New Tokens: {args.num_new_tokens}, KV bits: {bits}"
    )
    print(f"Results and responses saved in: {args.output_path}")

except Exception as e:
    print(f"An error occurred during script execution: {str(e)}")
    # Show where it failed and exit non-zero so callers/CI notice the failure.
    traceback.print_exc()
    sys.exit(1)
|
|