taufeeque committed on
Commit
b2a4148
·
1 Parent(s): 3d3e3c9

Improve UI

Browse files
Files changed (1) hide show
  1. Code_Browser.py +19 -18
Code_Browser.py CHANGED
@@ -20,33 +20,28 @@ st.set_page_config(
20
 
21
  st.title("Codebook Features")
22
 
 
 
 
 
 
 
 
23
  base_cache_dir = "cache/"
24
  dirs = glob.glob(base_cache_dir + "models/*/")
25
  model_name_options = [d.split("/")[-2].split("_")[:-2] for d in dirs]
26
  model_name_options = ["_".join(m) for m in model_name_options]
27
  model_name_options = sorted(set(model_name_options))
28
- def_model_idx = ["attn" in m for m in model_name_options].index(True)
29
 
30
- model_name = st.selectbox(
31
  "Model",
32
- model_name_options,
33
  index=def_model_idx,
34
  key=webapp_utils.persist("model_name"),
35
  )
36
-
37
  model = model_name.split("_")[0].split("#")[0]
38
- model_layers = {
39
- "pythia-410m-deduped": 24,
40
- "pythia-70m-deduped": 6,
41
- "gpt2": 12,
42
- "TinyStories-1Layer-21M": 1,
43
- }
44
- model_heads = {
45
- "pythia-410m-deduped": 16,
46
- "pythia-70m-deduped": 8,
47
- "gpt2": 12,
48
- "TinyStories-1Layer-21M": 16,
49
- }
50
  ccb = model_name.split("_")[1]
51
  ccb = "_ccb" if ccb == "ccb" else ""
52
  cb_at = "_".join(model_name.split("_")[2:])
@@ -59,12 +54,12 @@ dirs.sort(key=os.path.getmtime)
59
 
60
  # session states
61
  is_attn = "attn" in cb_at
62
- num_layers = model_layers[model]
63
- num_heads = model_heads[model]
64
  codes_cache_path = dirs[-1] + "/"
65
 
66
  model_info = code_search_utils.parse_model_info(codes_cache_path)
67
  num_codes = model_info.num_codes
 
 
68
  dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/"
69
 
70
  (
@@ -96,6 +91,12 @@ if not DEPLOY_MODE:
96
  st.write(metrics)
97
 
98
  st.markdown("## Demo Codes")
 
 
 
 
 
 
99
  demo_file_path = codes_cache_path + "demo_codes.txt"
100
 
101
  if st.checkbox("Show Demo Codes"):
 
20
 
21
  st.title("Codebook Features")
22
 
23
+ pretty_model_names = {
24
+ "TinyStories-1Layer-21M#100ksteps_vcb_mlp": "TinyStories-1L-21M-MLP",
25
+ "TinyStories-1Layer-21M_ccb_attn_preproj": "TinyStories-1L-21M-Attn",
26
+ "TinyStories-33M_ccb_attn_preproj": "TinyStories-4L-33M-Attn",
27
+ }
28
+ orig_model_name = {v: k for k, v in pretty_model_names.items()}
29
+
30
  base_cache_dir = "cache/"
31
  dirs = glob.glob(base_cache_dir + "models/*/")
32
  model_name_options = [d.split("/")[-2].split("_")[:-2] for d in dirs]
33
  model_name_options = ["_".join(m) for m in model_name_options]
34
  model_name_options = sorted(set(model_name_options))
35
+ def_model_idx = ["attn" in m.lower() for m in model_name_options].index(True)
36
 
37
+ p_model_name = st.selectbox(
38
  "Model",
39
+ [pretty_model_names.get(m, m) for m in model_name_options],
40
  index=def_model_idx,
41
  key=webapp_utils.persist("model_name"),
42
  )
43
+ model_name = orig_model_name.get(p_model_name, p_model_name)
44
  model = model_name.split("_")[0].split("#")[0]
 
 
 
 
 
 
 
 
 
 
 
 
45
  ccb = model_name.split("_")[1]
46
  ccb = "_ccb" if ccb == "ccb" else ""
47
  cb_at = "_".join(model_name.split("_")[2:])
 
54
 
55
  # session states
56
  is_attn = "attn" in cb_at
 
 
57
  codes_cache_path = dirs[-1] + "/"
58
 
59
  model_info = code_search_utils.parse_model_info(codes_cache_path)
60
  num_codes = model_info.num_codes
61
+ num_layers = model_info.n_layers
62
+ num_heads = model_info.n_heads
63
  dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/"
64
 
65
  (
 
91
  st.write(metrics)
92
 
93
  st.markdown("## Demo Codes")
94
+ demo_codes_desc = (
95
+ "This section contains codes that we've found to be interpretable along "
96
+ "with a description of the feature we think they are capturing. "
97
+ "Click on the 🔍 search button for a code to see the tokens that code activates on."
98
+ )
99
+ st.write(demo_codes_desc)
100
  demo_file_path = codes_cache_path + "demo_codes.txt"
101
 
102
  if st.checkbox("Show Demo Codes"):