# Hugging Face Spaces app (Space status when captured: Sleeping)
import streamlit as st | |
import torch | |
from transformers import AutoModelForCausalLM | |
def get_model_structure(model_id):
    """Load a causal-LM checkpoint and describe its parameters by shape.

    The model identified by ``model_id`` is loaded on the CPU in bfloat16
    via ``AutoModelForCausalLM.from_pretrained``; only the shape of each
    state-dict entry is kept.

    Args:
        model_id: Identifier passed straight to ``from_pretrained``.

    Returns:
        A dict mapping every state-dict entry name to that tensor's shape.
    """
    loaded = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    shapes = {}
    for name, tensor in loaded.state_dict().items():
        shapes[name] = tensor.shape
    return shapes
def compare_structures(struct1, struct2):
    """Return the entries whose shapes differ between two structures.

    Args:
        struct1: Mapping of parameter name to shape for the first model.
        struct2: Mapping of parameter name to shape for the second model.

    Returns:
        A list of ``(key, shape1, shape2)`` tuples, one per key whose shape
        differs. A shape is ``None`` when the key is absent from that
        structure. Entries follow the key order of ``struct1``, followed by
        keys that only appear in ``struct2``, so the result is deterministic.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order; a plain
    # set union would yield the keys in an arbitrary, run-dependent order.
    ordered_keys = dict.fromkeys([*struct1, *struct2])
    diff = []
    for key in ordered_keys:
        shape1 = struct1.get(key)
        shape2 = struct2.get(key)
        if shape1 != shape2:
            diff.append((key, shape1, shape2))
    return diff
def display_diff(diff):
    """Render a structure diff as two parallel HTML fragments.

    Args:
        diff: Iterable of ``(key, shape1, shape2)`` tuples, as produced by
            ``compare_structures``.

    Returns:
        A ``(left_html, right_html)`` tuple. Each fragment has one
        ``"key: shape"`` line per diff entry, joined with ``<br>``; the left
        fragment uses ``shape1`` and the right fragment uses ``shape2``.
    """
    import html

    left_lines = []
    right_lines = []
    for key, shape1, shape2 in diff:
        # The fragments are rendered with unsafe_allow_html=True, and the
        # parameter names come from arbitrary checkpoints, so escape the
        # text to keep it from being interpreted as markup.
        left_lines.append(html.escape(f"{key}: {shape1}"))
        right_lines.append(html.escape(f"{key}: {shape2}"))
    return "<br>".join(left_lines), "<br>".join(right_lines)
# --- Streamlit page: ask for two model IDs and show where their shapes differ.
st.title("Model Structure Comparison Tool")
model_id1 = st.text_input("Enter the first HuggingFace Model ID")
model_id2 = st.text_input("Enter the second HuggingFace Model ID")

# Nothing is loaded until both inputs are non-empty.
if model_id1 and model_id2:
    try:
        with st.spinner("Loading models..."):
            struct1 = get_model_structure(model_id1)
            struct2 = get_model_structure(model_id2)
    except (OSError, ValueError) as err:
        # Report a failed load (e.g. unknown or unsupported model ID) in
        # the page instead of letting a raw traceback reach the user.
        st.error(f"Failed to load model structures: {err}")
    else:
        diff = compare_structures(struct1, struct2)
        st.write("### Comparison Result")
        if not diff:
            # Two empty columns would be indistinguishable from a failure.
            st.info("No differences found: both models have identical parameter names and shapes.")
        else:
            left_html, right_html = display_diff(diff)
            col1, col2 = st.columns(2)
            with col1:
                st.write("### Model 1")
                st.markdown(left_html, unsafe_allow_html=True)
            with col2:
                st.write("### Model 2")
                st.markdown(right_html, unsafe_allow_html=True)