gaunernst committed on
Commit
9b38236
·
1 Parent(s): 63f212d

initial prototype

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoConfig
3
+
4
+
5
# TODO: access token for gated repo
# TODO: add note about sliding window attention
def calculate(name: str, ctx_len: int, num_users: int, dtype: str):
    """Estimate the KV-cache size (in GB) for a Hugging Face model.

    Args:
        name: Model id passed to ``AutoConfig.from_pretrained``.
        ctx_len: Context length (tokens) held in the cache per user.
        num_users: Number of concurrent users/sequences.
        dtype: KV-cache storage dtype, either ``"fp16/bf16"`` or ``"fp8"``.

    Returns:
        Tuple ``(kv_cache_size_gb, model_config)`` where ``model_config`` is a
        list of ``[key, value]`` rows shown in the UI dataframe.

    Raises:
        ValueError: If ``dtype`` is not one of the supported choices.
    """
    cfg = AutoConfig.from_pretrained(name, trust_remote_code=True)
    # DeepSeek V2/V3 use Multi-head Latent Attention (MLA): the cache stores a
    # compressed latent plus rope key dims instead of full per-head K/V.
    use_mla = cfg.architectures[0].startswith(("DeepseekV2", "DeepseekV3"))

    # Multimodal configs nest the language-model config under `text_config`.
    if hasattr(cfg, "text_config"):
        cfg = cfg.text_config

    num_layers = cfg.num_hidden_layers
    model_config = [
        ["num_layers", num_layers],
        ["max_ctx_len", cfg.max_position_embeddings],
    ]

    # TODO: show attention type, show calculation
    if use_mla:
        kv_lora_rank = cfg.kv_lora_rank
        qk_rope_head_dim = cfg.qk_rope_head_dim
        nelems_per_token = num_layers * (kv_lora_rank + qk_rope_head_dim)

        model_config.append(["kv_lora_rank", kv_lora_rank])
        model_config.append(["qk_rope_head_dim", qk_rope_head_dim])

    else:
        num_kv_heads = cfg.num_key_value_heads
        # Some configs don't set head_dim explicitly; derive it when absent.
        head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
        nelems_per_token = num_layers * num_kv_heads * head_dim * 2  # K and V

        model_config.append(["num_kv_heads", num_kv_heads])
        model_config.append(["head_dim", head_dim])

    if dtype == "fp16/bf16":
        nbytes_per_elem = 2
    elif dtype == "fp8":
        nbytes_per_elem = 1 + 2 / cfg.hidden_size  # assume per-token scaling
    else:
        # Previously an unsupported dtype (e.g. None from an unset dropdown)
        # fell through and raised UnboundLocalError on nbytes_per_elem.
        # Fail fast with an actionable message instead.
        raise ValueError(f"Unsupported KV cache dtype: {dtype!r}")

    kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
    return kv_cache_size, model_config
44
+
45
+
46
# Gradio UI: one form mapping the four inputs of `calculate` to its two outputs.
demo = gr.Interface(
    fn=calculate,
    inputs=[
        gr.Textbox(label="model_id", value="google/gemma-3-1b-it"),
        gr.Number(label="Context length", value=128_000),
        gr.Number(label="No. of users", value=1),
        # Explicit default: without `value=`, gradio submits None for an
        # untouched dropdown, which `calculate` would reject.
        gr.Dropdown(
            label="KV cache dtype", choices=["fp16/bf16", "fp8"], value="fp16/bf16"
        ),
    ],
    outputs=[
        gr.Number(label="KV cache size (GB)", precision=2),
        gr.Dataframe(
            label="Model config", headers=["Key", "Value"], datatype=["str", "int"]
        ),
    ],
)

# Launch only when executed as a script, so importing this module
# (e.g. from tests or another app) does not start a server.
if __name__ == "__main__":
    demo.launch()