# throughput-calculator/src/throughput_utils.py
import io
from enum import Enum

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import ScalarFormatter
from PIL import Image

class AttentionType(Enum):
    LOCAL = 0
    GLOBAL = 1

def gqa_kv_per_layer_per_token(n_kv_heads, d_head, kv_parameter_size):
    # K and V each contribute n_kv_heads * d_head cached entries per token.
    return 2 * kv_parameter_size * n_kv_heads * d_head

def mla_kv_per_layer_per_token(d_compressed, kv_parameter_size):
    # MLA caches a single compressed latent of width d_compressed per token,
    # shared between K and V.
    return kv_parameter_size * d_compressed
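
# Illustrative comparison (assumed figures, not tied to any particular model):
# with 8 KV heads, d_head = 128 and 2-byte cache entries,
#   gqa_kv_per_layer_per_token(8, 128, 2) -> 2 * 2 * 8 * 128 = 4096 bytes/layer/token
#   mla_kv_per_layer_per_token(512, 2)    -> 2 * 512         = 1024 bytes/layer/token
# so an MLA latent of width 512 moves 4x less KV data than 8-head GQA here.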

def tokens_per_second(batch_size, bandwidth, total_kv_size, param_size):
    # Memory-bound decode: each step streams the full weights once plus one
    # KV cache per sequence, so throughput = (bytes/s) / (bytes per step).
    return (batch_size * bandwidth) / (batch_size * total_kv_size + param_size)
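
# Roofline-style worked example (assumed hardware and model, not from this
# module): a 7B-parameter model at 2 bytes/weight is ~14e9 bytes of weights.
# On a 1 TB/s device with batch_size = 1 and a negligible KV cache,
#   tokens_per_second(1, 1e12, 0, 14e9) ~= 1e12 / 14e9 ~= 71 tokens/s,
# i.e. every decoded token must read the entire weight set from memory.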

def compute_tps(kv_per_layer_per_token, seq_len, batch_size, total_param_size,
                num_layers, swa_pattern, swa_size, bandwidth):
    tps_values = []
    for ctx_len in seq_len:
        total_kv_size = 0
        for l in range(num_layers):
            # Sliding-window (LOCAL) layers cache at most swa_size tokens;
            # GLOBAL layers cache the full context.
            if swa_pattern[l % len(swa_pattern)] == AttentionType.LOCAL:
                total_kv_size += kv_per_layer_per_token * min(ctx_len, swa_size)
            else:
                total_kv_size += kv_per_layer_per_token * ctx_len
        tps = tokens_per_second(batch_size, bandwidth, total_kv_size, total_param_size)
        tps_values.append(tps)
    return tps_values
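
# Example of the SWA interleaving (assumed configuration): two LOCAL layers per
# GLOBAL layer caps KV reads at swa_size tokens on two thirds of the stack, so
# throughput flattens once ctx_len exceeds the window on those layers. With
# kv_per_layer_per_token = 4096 bytes, 30 layers, swa_size = 4096, 16 GB of
# weights and 1 TB/s of bandwidth:
#   pattern = 2 * [AttentionType.LOCAL] + [AttentionType.GLOBAL]
#   compute_tps(4096, [8192], 1, 16e9, 30, pattern, 4096, 1e12) ~= [60.0]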

def create_throughput_plot(
    model_name,
    memory_bandwidth,
    num_parameters,
    parameter_size,
    kv_parameter_size,
    num_layers,
    num_heads,
    d_model,
    ctx_length,
    local_layers,
    global_layers,
    swa_size,
    gqa_heads,
    mla_d_compressed,
):
    # Normalise user-facing units: GB/s -> bytes/s, billions -> raw count,
    # bits per weight -> total weight bytes.
    memory_bandwidth = float(memory_bandwidth) * 1_000_000_000
    num_parameters = float(num_parameters) * 1_000_000_000
    d_head = d_model // num_heads
    total_param_size = num_parameters * (parameter_size / 8.0)

    # Repeating local/global attention pattern; an empty pattern means every
    # layer attends globally.
    swa_pattern = ([AttentionType.LOCAL] * local_layers +
                   [AttentionType.GLOBAL] * global_layers)
    if len(swa_pattern) == 0:
        swa_pattern = [AttentionType.GLOBAL]

    sns.set_theme(style="whitegrid", context="paper")
    palette = sns.color_palette("viridis", len(gqa_heads) + len(mla_d_compressed))
    plt.figure(figsize=(14, 8), dpi=300)

    # Sweep context lengths from 1e2 to 1e5 tokens on a log grid.
    seq_len = np.logspace(2, 5, 100).astype(int)
    batch_size = 1
    tps_values = []

    gqa_count = len(gqa_heads)
    for i, n_kv_head in enumerate(gqa_heads):
        n_kv_head = int(n_kv_head)
        kv_per_token = gqa_kv_per_layer_per_token(n_kv_head, d_head, kv_parameter_size)
        gqa_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(gqa_tps_values)
        plt.plot(seq_len, gqa_tps_values, label=f"GQA: {n_kv_head} heads",
                 color=palette[i], linewidth=3.5, alpha=0.85)

    plt.axvline(x=ctx_length, color='red', linestyle='--', alpha=0.8, linewidth=2.5,
                label=f"Max Context Length ({ctx_length:,})")
    local_count = swa_pattern.count(AttentionType.LOCAL)
    global_count = swa_pattern.count(AttentionType.GLOBAL)
    if local_count > 0:
        plt.axvline(x=swa_size, color='blue', linestyle='--', alpha=0.8, linewidth=2.5,
                    label=f"Sliding Window Limit ({swa_size:,})")

    for i, d_comp in enumerate(mla_d_compressed):
        d_comp = int(d_comp)
        kv_per_token = mla_kv_per_layer_per_token(d_comp, kv_parameter_size)
        mla_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(mla_tps_values)
        plt.plot(seq_len, mla_tps_values, label=f"MLA: dc = {d_comp}",
                 color=palette[i + gqa_count], linewidth=3.5, alpha=0.85)

    plt.xscale('log')
    # Guard against an empty or degenerate curve set before autoscaling.
    if tps_values and all(np.isfinite(tps_values)):
        y_min = max(0, min(tps_values) * 0.9)
        y_max = max(tps_values) * 1.1
        plt.ylim(y_min, y_max)
    else:
        plt.ylim(15, 40)
    plt.gca().xaxis.set_major_formatter(ScalarFormatter())
    plt.gca().yaxis.set_major_formatter(ScalarFormatter())
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)

    attn_label = "Global" if local_count == 0 else f"SWA {local_count}:{global_count}"
    device_name = model_name.split(':')[0] if ':' in model_name else model_name
    plt.annotate(f"{device_name}\nBandwidth: {memory_bandwidth/1e9:.1f} GB/s\n"
                 f"Parameter Size: {parameter_size:.1f} bits\nAttention Kind: {attn_label}",
                 xy=(0.8, 0.97),
                 xycoords='axes fraction',
                 bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9,
                           edgecolor='darkgray'),
                 va='top',
                 fontsize=11)

    plt.xlabel('Context Length (tokens)', fontsize=14, fontweight='bold')
    plt.ylabel('Tokens per Second', fontsize=14, fontweight='bold')
    plt.tick_params(axis='both', which='major', labelsize=12)

    model_title = model_name.split(':')[1] if ':' in model_name else model_name
    plt.title(f"{model_title}: Tokens Per Second vs. Sequence Length", fontsize=18,
              fontweight='bold', pad=20)
    plt.legend(title="Configuration", frameon=True, framealpha=0.95, fontsize=12,
               title_fontsize=14)
    plt.grid(True, alpha=0.5)

    # Render to an in-memory PNG and hand back a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    return Image.open(buf)
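
if __name__ == "__main__":
    # Minimal smoke test with assumed, Llama-3-8B-like figures; none of these
    # numbers come from this module, and kv_parameter_size is taken here to be
    # bytes per cached element (an assumption about the caller's units).
    img = create_throughput_plot(
        model_name="M3 Max:Llama-3-8B",
        memory_bandwidth=400,      # GB/s
        num_parameters=8,          # billions of weights
        parameter_size=16,         # bits per weight
        kv_parameter_size=2,       # bytes per cached KV element (assumed)
        num_layers=32,
        num_heads=32,
        d_model=4096,
        ctx_length=8192,
        local_layers=0,
        global_layers=1,
        swa_size=4096,
        gqa_heads=[8, 32],
        mla_d_compressed=[512],
    )
    img.save("throughput.png")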