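# NOTE(review): the three lines below look like web-page residue captured with
# this copy of the file, not Python; kept as comments so the module can parse.
# Spaces:
# Sleeping
# Sleeping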
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.interpolate import interp1d | |
from shiny import render | |
from shiny.express import input, output, ui | |
from utils import ( | |
filter_and_select, | |
plot_2d_comparison, | |
plot_color_square, | |
wens_method_heatmap, | |
plot_fcgr, | |
plot_persistence_homology, | |
plot_distrobutions | |
) | |
import os | |
import matplotlib as mpl | |
# Reset matplotlib rc settings to the library defaults.
mpl.rcParams.update(mpl.rcParamsDefault)
############################################################# Virus Dataset ########################################################
# NOTE: an earlier version loaded this data with load_dataset('Hack90/virus_tiny').
# The plot functions below use its "Organism_Name" and "Sequence" columns.
df = pd.read_parquet('virus_ds.parquet')
# Selectize choices: each organism name maps to itself (value == label).
virus = {name: name for name in df['Organism_Name'].unique()}
# Choice lists for the model-analysis tabs. Parse the training log once
# instead of reading the same CSV three times.
_training_data = pd.read_csv("training_data_5.csv")
loss_typesss = _training_data['loss_type'].unique().tolist()
model_typesss = _training_data['model_type'].unique().tolist()
param_typesss = _training_data['param_type'].unique().tolist()
del _training_data
############################################################# Filter and Select ######################################################## | |
def filter_and_select(group):
    """Return the first three rows of *group*, or ``None`` if it has fewer.

    Used with ``DataFrame.groupby(...).apply``. pandas leaves groups whose
    result is ``None`` out of the combined frame (verify for the installed
    pandas version), so organisms with fewer than three rows are dropped.
    """
    if len(group) < 3:
        return None
    return group.head(3)
############################################################# UI ################################################################# | |
# NOTE(review): indentation was lost in this copy of the file (every line starts
# at column 0) and lines end with a " | |" artefact; both must be repaired, and
# the nesting of the `with` blocks restored, before the app can run.
# Page options plus the card-style tab container holding the nav panels below.
ui.page_opts(fillable=True) | |
with ui.navset_card_tab(id="tab"): | |
# Tab "Viral Macrostructure": choose organisms (choices from the module-level
# `virus` dict) and a plotting method; plot_macro reads both inputs.
with ui.nav_panel("Viral Macrostructure"): | |
ui.panel_title("Do viruses have underlying structure?") | |
with ui.layout_columns(): | |
with ui.card(): | |
# Multi-select with nothing preselected.
ui.input_selectize("virus_selector", "Select your viruses:", virus, multiple=True, selected=None) | |
with ui.card(): | |
# Single-select; these labels must match the strings compared in plot_macro.
ui.input_selectize( | |
"plot_type_macro", | |
"Select your method:", | |
["Chaos Game Representation", "2D Line", "ColorSquare", "Persistant Homology", "Wens Method"], | |
multiple=False, | |
selected=None, | |
) | |
@render.plot
def plot_macro():
    """Draw the plot chosen in the "Viral Macrostructure" tab.

    Reads the organisms picked in ``virus_selector`` and the method picked in
    ``plot_type_macro``, then delegates to the matching helper from ``utils``.

    Returns:
        Whatever the chosen ``utils`` helper returns, or ``None`` (nothing
        rendered) when no organism is selected, when no organism has enough
        sequences for the chosen method, or when the method is unrecognised.
    """
    selected = input.virus_selector()
    if not selected:
        return None
    # Reuse the module-level frame loaded from virus_ds.parquet instead of
    # re-reading the file on every reactive update.
    subset = df[df["Organism_Name"].isin(selected)]
    plot_type = input.plot_type_macro()

    if plot_type == "2D Line":
        grouped = subset.groupby("Organism_Name")["Sequence"].apply(list)
        return plot_2d_comparison(grouped, grouped.index)
    if plot_type == "Wens Method":
        return wens_method_heatmap(subset, subset["Organism_Name"].unique())
    if plot_type not in ("ColorSquare", "Chaos Game Representation", "Persistant Homology"):
        return None

    # The remaining methods use the first three sequences per organism;
    # filter_and_select drops organisms with fewer than three rows.
    filtered_df = (
        subset.groupby("Organism_Name")
        .apply(filter_and_select)
        .reset_index(drop=True)
    )
    if filtered_df.empty:
        return None
    sequences = filtered_df["Sequence"]
    if plot_type == "Persistant Homology":
        return plot_persistence_homology(sequences, filtered_df["Organism_Name"])
    # Take names from the filtered frame so the labels line up with the
    # sequences actually plotted.
    names = filtered_df["Organism_Name"].unique()
    if plot_type == "ColorSquare":
        return plot_color_square(sequences, names)
    return plot_fcgr(sequences, names)
# Tab "Viral Genome Distributions": its own organism selector (id
# "virus_selector_1") and a base-pair slider (min 0, max 1000, initial 15)
# whose value plot_distro forwards to plot_distrobutions.
with ui.nav_panel("Viral Genome Distributions"): | |
ui.panel_title("How does sequence distribution vary across sequence length?") | |
with ui.layout_columns(): | |
with ui.card(): | |
ui.input_selectize("virus_selector_1", "Select your viruses:", virus, multiple=True, selected=None) | |
with ui.card(): | |
ui.input_slider( | |
"basepair","Select basepair",0, 1000, 15 | |
) | |
@render.plot
def plot_distro():
    """Render the "Viral Genome Distributions" tab.

    Groups the sequences of the organisms chosen in this tab's own selector
    (``virus_selector_1``) into one list per organism and passes them, with
    the base-pair slider value, to ``plot_distrobutions``.

    Returns:
        The result of ``plot_distrobutions``, or ``None`` (nothing rendered)
        while no organism is selected.
    """
    # This tab declares its own selector; reading input.virus_selector() here
    # would follow the "Viral Macrostructure" tab instead.
    selected = input.virus_selector_1()
    if not selected:
        return None
    # Reuse the module-level frame instead of re-reading virus_ds.parquet.
    subset = df[df["Organism_Name"].isin(selected)]
    grouped = subset.groupby("Organism_Name")["Sequence"].apply(list)
    return plot_distrobutions(grouped, grouped.index, input.basepair())
# Tab "Viral Microstructure": k-mer length slider (0-10, initial 4), number of
# k-mers to show (0-1000, initial 15) and the metric ("percentage" or "count");
# all three are read by plot_micro.
with ui.nav_panel("Viral Microstructure"): | |
ui.panel_title("Kmer Distribution") | |
with ui.layout_columns(): | |
with ui.card(): | |
ui.input_slider("kmer", "kmer", 0, 10, 4) | |
ui.input_slider("top_k", "top:", 0, 1000, 15) | |
ui.input_selectize("plot_type", "Select metric:", ["percentage", "count"], multiple=False, selected=None) | |
@render.plot
def plot_micro():
    """Bar chart of the most frequent k-mers listed in kmers.csv.

    Uses the "kmer", "k", "count" and "percent" columns.
    - For k > 0, only rows of that length are kept and the first ``top_k``
      are shown. This assumes the CSV is already sorted by frequency
      (TODO confirm).
    - For k == 0, every row is plotted unfiltered.

    Returns:
        The matplotlib Figure.
    """
    kmers = pd.read_csv("kmers.csv")
    k = input.kmer()
    top_k = input.top_k()
    metric = input.plot_type()
    if k > 0:
        kmers = kmers[kmers["k"] == k].head(top_k)

    fig, ax = plt.subplots()
    if metric == "count":
        ax.bar(kmers["kmer"], kmers["count"])
        ax.set_ylabel("Count")
    elif metric == "percentage":
        # Assumes "percent" is stored as a fraction (0-1) — confirm.
        ax.bar(kmers["kmer"], kmers["percent"] * 100)
        ax.set_ylabel("Percentage")
    ax.set_title(f"Most common {k}-mers")
    ax.set_xlabel("K-mer")
    # Rotate the tick labels the bars already created rather than overwriting
    # them with set_xticklabels. Without fixed ticks, that warns and can attach
    # labels to the wrong positions.
    ax.tick_params(axis="x", labelrotation=90)
    return fig
# Tab "Viral Model Training": no inputs; plot_context_size_scaling draws the
# 14M-model loss curves read from 14m.csv.
with ui.nav_panel("Viral Model Training"): | |
ui.panel_title("Does context size matter for a nucleotide model?") | |
def plot_loss_rates(df, model_type):
    """Overlay one loss curve per context window on a normalised x-axis.

    Every column of *df* except ``Step`` is one run. Each run's values are
    coerced to float (non-numeric entries and NaNs are dropped) and linearly
    resampled onto 1000 points in [0, 1], so runs of different length can be
    compared directly. Runs with fewer than two values are skipped.

    Args:
        df: DataFrame with a ``Step`` column plus one column per run.
        model_type: Text shown in the title, e.g. "14M".

    Returns:
        The matplotlib Figure.
    """
    x = np.linspace(0, 1, 1000)
    # Legend labels (context window sizes) for the first six run columns,
    # in column order.
    labels = ["32", "64", "128", "256", "512", "1024"]
    fig, ax = plt.subplots()
    for i, col in enumerate(df.drop(columns=["Step"]).columns):
        y = pd.to_numeric(df[col], errors="coerce").dropna().to_numpy(dtype=float)
        if len(y) < 2:
            # interp1d needs at least two samples.
            continue
        f = interp1d(np.linspace(0, 1, len(y)), y)
        # Columns beyond the hard-coded labels fall back to their own name
        # instead of raising IndexError.
        ax.plot(x, f(x), label=labels[i] if i < len(labels) else str(col))
    ax.legend()
    ax.set_title(f"Loss rates for a {model_type} parameter model across context windows")
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Loss rate")
    return fig
@render.plot
def plot_context_size_scaling():
    """Render the "Viral Model Training" tab: 14M-model curves from 14m.csv."""
    # plot_loss_rates always returns a Figure, so no truthiness check is needed.
    return plot_loss_rates(pd.read_csv("14m.csv"), "14M")
# Tab "Model loss analysis": multi-selects whose choices are the unique values
# loaded from training_data_5.csv at module level; plot_model_scaling passes
# the selections to plot_loss_rates_model.
# NOTE(review): the preselected values below are assumed to exist in that CSV — confirm.
with ui.nav_panel("Model loss analysis"): | |
ui.panel_title("Paper stuff") | |
with ui.card(): | |
# No preselection: no curves are drawn until a param type is chosen.
ui.input_selectize( | |
"param_type", | |
"Select Param Type:", | |
param_typesss, | |
multiple=True, | |
) | |
ui.input_selectize( | |
"model_type", | |
"Select Model Type:", | |
model_typesss, | |
multiple=True, | |
selected=["pythia", "denseformer"], | |
) | |
ui.input_selectize( | |
"loss_type", | |
"Select Loss Type:", | |
loss_typesss, | |
multiple=True, | |
selected=["compliment", "cross_entropy", "headless"], | |
) | |
def plot_loss_rates_model(df, param_types, loss_types, model_types):
    """Overlay loss curves for every (param, loss, model) combination in *df*.

    For each combination, the ``loss_interp`` values (assumed to be in
    training order — TODO confirm) are linearly resampled onto 1000 points in
    [0, 1] so runs of different length share one x-axis. Combinations with
    fewer than two rows are skipped.

    Args:
        df: Training log with ``param_type``, ``loss_type``, ``model_type``
            and ``loss_interp`` columns.
        param_types: Selected parameter types; each is converted with
            ``float()`` before comparison with ``df["param_type"]``.
        loss_types: Selected loss types.
        model_types: Selected model types.

    Returns:
        The matplotlib Figure.
    """
    x = np.linspace(0, 1, 1000)
    fig, ax = plt.subplots()
    for param_type in param_types:
        param_mask = df["param_type"] == float(param_type)
        for loss_type in loss_types:
            loss_mask = param_mask & (df["loss_type"] == loss_type)
            for model_type in model_types:
                y = df.loc[loss_mask & (df["model_type"] == model_type), "loss_interp"].to_numpy()
                # interp1d needs at least two samples, so skip shorter runs.
                if len(y) < 2:
                    continue
                f = interp1d(np.linspace(0, 1, len(y)), y)
                ax.plot(x, f(x), label=f"{param_type}_{loss_type}_{model_type}")
    ax.legend()
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Loss rate")
    return fig
@render.plot
def plot_model_scaling():
    """Render the "Model loss analysis" tab from training_data_5.csv.

    Rows with ``epoch_interp`` <= 0.035 are dropped before plotting
    (presumably to skip the very start of training — TODO confirm).
    """
    min_epoch = 0.035
    df = pd.read_csv("training_data_5.csv")
    df = df[df["epoch_interp"] > min_epoch]
    # plot_loss_rates_model always returns a Figure.
    return plot_loss_rates_model(
        df, input.param_type(), input.loss_type(), input.model_type()
    )
# Tab "Scaling Laws": model types and loss types for plot_big_boy_model.
# Although loss_type_scale is multi-select, only the first selected loss type
# is used by plot_loss_rates_model_scale.
with ui.nav_panel("Scaling Laws"): | |
ui.panel_title("Params & Losses") | |
with ui.card(): | |
ui.input_selectize( | |
"model_type_scale", | |
"Select Model Type:", | |
model_typesss, | |
multiple=True, | |
selected=["evo", "denseformer"], | |
) | |
ui.input_selectize( | |
"loss_type_scale", | |
"Select Loss Type:", | |
loss_typesss, | |
multiple=True, | |
selected=["cross_entropy"], | |
) | |
def plot_loss_rates_model_scale(df, loss_type, model_types):
    """Plot the best (minimum) loss against parameter count per model type.

    Args:
        df: Training log with ``loss_type``, ``model_type``, ``num_params``
            and ``loss_interp`` columns.
        loss_type: Sequence of selected loss types; only the first is used.
        model_types: Model types to draw, one line each.

    Returns:
        The matplotlib Figure, or ``None`` when no loss type is selected.
    """
    if len(loss_type) == 0:
        return None
    df = df[df["loss_type"] == loss_type[0]]
    fig, ax = plt.subplots()
    for model_type in model_types:
        per_model = df[df["model_type"] == model_type]
        # groupby sorts by num_params and skips NaN keys, giving one
        # (params, min loss) point per model size in ascending order.
        best = per_model.groupby("num_params")["loss_interp"].min()
        ax.plot([int(p) for p in best.index], best.to_list(), label=model_type)
    ax.legend()
    ax.set_xlabel("Params")
    ax.set_ylabel("Loss")
    return fig
@render.plot
def plot_big_boy_model():
    """Render the "Scaling Laws" tab: minimum loss versus parameter count."""
    loss_types = input.loss_type_scale()
    if not loss_types:
        # plot_loss_rates_model_scale indexes loss_type[0]; nothing to plot yet.
        return None
    df = pd.read_csv("training_data_5.csv")
    return plot_loss_rates_model_scale(df, loss_types, input.model_type_scale())
# Tab "Logits View": the model sizes and loss names are substituted into the
# virus_pythia_{size}_1024_{loss}_logit_cumsums.npy file names read by
# plot_logits_representation; logits are 1-based ("1".."8").
with ui.nav_panel("Logits View"): | |
ui.panel_title("Logits et all") | |
with ui.card(): | |
ui.input_selectize( | |
"model_bigness", | |
"Select Model size:", | |
["14", "31", "70", "160", "410"], | |
multiple=True, | |
selected=["70", "160"], | |
) | |
ui.input_selectize( | |
"loss_loss_loss", | |
"Select Loss Type:", | |
["compliment", "cross_entropy", "headless", "2d_representation_GaussianPlusCE", "2d_representation_MSEPlusCE"], | |
multiple=True, | |
selected=["cross_entropy"], | |
) | |
ui.input_selectize( | |
"logits_select", | |
"Select logits:", | |
["1", "2", "3", "4", "5", "6", "7", "8"], | |
multiple=True, | |
selected=["6"], | |
) | |
def plot_logits_representation(model_bigness, loss_type, logits):
    """Plot generated vs expected cumulative logit curves, one column per logit.

    For every (size, loss) pair, loads
    ``virus_pythia_{size}_1024_{loss}_logit_cumsums.npy`` (a pickled object
    unwrapped with ``.item()``). Missing files are reported with ``print``
    and skipped.

    For each selected logit:
    - the top row plots ``lm_logits_y_cumsum`` / ``shift_labels_y_cumsum``
      ("Single");
    - the bottom row plots the ``*_y_full_cumsum`` arrays ("Full").

    Args:
        model_bigness: Selected model sizes, e.g. ("70", "160").
        loss_type: Selected loss-type names.
        logits: Selected logits as 1-based numeric strings ("1".."8").

    Returns:
        The matplotlib Figure, or ``None`` when *logits* is empty.
    """
    if len(logits) == 0:
        return None
    num_rows = 2
    num_cols = len(logits)
    # squeeze=False keeps axs two-dimensional (rows x columns) even for one
    # logit, so the same indexing works for any number of selected logits.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 10), squeeze=False)
    for size in model_bigness:
        for loss in loss_type:
            file_name = f"virus_pythia_{size}_1024_{loss}_logit_cumsums.npy"
            if not os.path.exists(file_name):
                print(f"File not found: {file_name}")
                continue
            # allow_pickle=True can execute code embedded in the file: only
            # point this at trusted, app-bundled files.
            data = np.load(file_name, allow_pickle=True).item()
            for col, logit in enumerate(logits):
                logit_index = int(logit) - 1  # 1-based UI value -> 0-based index
                single, full = axs[0, col], axs[1, col]
                # Indexing assumes 3-D arrays, presumably (batch, position, logit) — confirm.
                single.plot(data['lm_logits_y_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
                single.plot(data['shift_labels_y_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
                full.plot(data['lm_logits_y_full_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
                full.plot(data['shift_labels_y_full_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
    for col, logit in enumerate(logits):
        axs[0, col].set_title(f'Logit: {logit}- Single')
        axs[1, col].set_title(f'Logit: {logit} - Full')
        for ax in axs[:, col]:
            # Only add a legend where something was drawn.
            if ax.get_legend_handles_labels()[0]:
                ax.legend()
    fig.tight_layout()
    return fig
@render.plot
def plot_logits_representation_ui():
    """Render the "Logits View" tab from the three selectors above it."""
    logits = input.logits_select()
    if not logits:
        # With no logit selected there are no subplot columns to draw.
        return None
    return plot_logits_representation(
        input.model_bigness(), input.loss_loss_loss(), logits
    )