from __future__ import annotations import glob import io import os import random import struct from contextlib import contextmanager from html import escape import msgpack import streamlit as st import torch import tqdm from huggingface_hub import HfFileSystem from transformers import AutoTokenizer st.set_page_config(layout="wide") MODEL_NAME = os.environ.get("MODEL_NAME", "MonetLLM/monet-vd-1.4B-100BT-hf") CONTEXT_WINDOW = int(os.environ.get("CONTEXT_WINDOW", "12")) CANDIDATE_THRESHOLD = int(os.environ.get("CANDIDATE_THRESHOLD", "50")) HORIZONTAL_STYLE = """""" @st.cache_resource def prepare_routing_resources(): fs = HfFileSystem() for filename in fs.glob(f"datasets/{MODEL_NAME}-viewer-data/*"): if not os.path.exists(os.path.basename(filename)): print(f"[*] Download {filename}...") fs.download(filename, ".") input_tokens = torch.load("inputs.pt") examples_tables = [] for i in tqdm.trange(len(glob.glob("examples-*.msgpack"))): with open(f"examples-{i}.msgpack", "rb") as fp: fp.seek(-4, io.SEEK_END) table_size = struct.unpack(">I", fp.read(4))[0] fp.seek(-(table_size + 4), io.SEEK_END) examples_tables.append(msgpack.Unpacker(fp).unpack()) candidates = [] for i, table in enumerate(tqdm.tqdm(examples_tables)): candidates.append([]) with open(f"examples-{i}.msgpack", "rb") as fp: unpacker = msgpack.Unpacker(fp) for j in range(len(table)): if len(unpacker.unpack()) > CANDIDATE_THRESHOLD: candidates[-1].append(j) routing_tables = [] for i in tqdm.trange(len(examples_tables)): with open(f"routings-{i}.msgpack", "rb") as fp: fp.seek(-4, io.SEEK_END) table_size = struct.unpack(">I", fp.read(4))[0] fp.seek(-(table_size + 4), io.SEEK_END) routing_tables.append(msgpack.Unpacker(fp).unpack()) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) return input_tokens, examples_tables, routing_tables, candidates, tokenizer input_tokens, examples_tables, routing_tables, candidates, tokenizer = ( prepare_routing_resources() ) def render_routing_examples_in_html(router_index: int, expert_id: int) -> str: with open(f"examples-{router_index}.msgpack", "rb") as fp: fp.seek(examples_tables[router_index][expert_id]) examples = msgpack.Unpacker(fp).unpack() with open(f"routings-{router_index}.msgpack", "rb") as fp: table = [] for i, j, _ in examples: start = max(j - CONTEXT_WINDOW, 0) end = min(j + CONTEXT_WINDOW, len(routing_tables[router_index][i])) fp.seek(routing_tables[router_index][i][start]) unpacker = msgpack.Unpacker(fp, strict_map_key=False) activated = [unpacker.unpack().get(expert_id, 0) for _ in range(start, end)] full_text = tokenizer.decode(input_tokens[i]) encodings = tokenizer(full_text, add_special_tokens=False) offset = len(encodings.input_ids) - input_tokens.size(1) spans, lslice = [], None for k in range(start, end): if offset + k >= 0 and (sslice := encodings.token_to_chars(offset + k)): span, score = full_text[slice(*sslice)], activated[k - start] if lslice == sslice: score = max(spans.pop(-1)[1], score) spans.append((escape(span), score)) lslice = sslice spans = [ f"{span}" for span, score in spans ] table.append( f""" {escape(tokenizer.decode(input_tokens[i, j]))} ({activated[j - start] * 100:.2f}%) (...) {"".join(spans)} (...) ({i}, {j}) """ ) return f"""

Activated Examples of Group {router_index} / Expert {expert_id}

{"".join(table)}
""" @contextmanager def st_horizontal(): st.markdown(HORIZONTAL_STYLE, unsafe_allow_html=True) with st.container(): st.markdown( '', unsafe_allow_html=True, ) yield col1, col2 = st.columns(2) with col1: router_groups = [f"Routing Group {i}" for i in range(len(examples_tables))] router_index = st.selectbox("Expert Routing Group", router_groups, index=4) with col2: expert_id = st.number_input("Expert Index", 0, len(examples_tables[0]), 54136) with st_horizontal(): show_btn = st.button("Show") random_btn = st.button("Random") if show_btn or random_btn: router_index = router_groups.index(router_index) if random_btn: expert_id = random.choice(candidates[router_index]) st.html(render_routing_examples_in_html(router_index, expert_id))