# virus_explorer / app.py
# Last commit d9d2b55 by Hack90: "Update app.py" (16.7 kB)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from shiny import render
from shiny.express import input, output, ui
from utils import (
filter_and_select,
plot_2d_comparison,
plot_color_square,
wens_method_heatmap,
plot_fcgr,
plot_persistence_homology,
plot_distrobutions
)
import os
import seaborn as sns
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)
############################################################# Virus Dataset ########################################################
# Datasets backing the UI choice lists and the plots, loaded when this module runs.
df = pd.read_parquet('virus_ds.parquet')
# Selectize choices for the macrostructure tab: each organism name maps to itself.
virus = {v: v for v in df['Organism_Name'].unique()}
# distro.parquet is read a single time; the organism list is taken from the
# loaded frame instead of re-reading the file for one column.
MASTER_DF = pd.read_parquet("distro.parquet")
df_new = MASTER_DF['organism_name'].tolist()
virus_new = {v: v for v in df_new}
# training_data_5.csv is parsed once and reused for all three choice lists.
_training_choices = pd.read_csv("training_data_5.csv")
loss_typesss = _training_choices['loss_type'].unique().tolist()
model_typesss = _training_choices['model_type'].unique().tolist()
param_typesss = _training_choices['param_type'].unique().tolist()
############################################################# Filter and Select ########################################################
def filter_and_select(group, n=3):
    """Return the first ``n`` rows of ``group``, or None if it has fewer than ``n`` rows.

    Intended for ``DataFrame.groupby(...).apply``. Groups that come back as
    None are left out of the combined result, so organisms with fewer than
    ``n`` sequences are excluded from the plots that use this filter.

    NOTE: this definition shadows the ``filter_and_select`` imported from
    ``utils`` at the top of this file.

    Args:
        group: the rows of one group.
        n: number of leading rows to keep; also the minimum group size.

    Returns:
        ``group.head(n)`` when ``len(group) >= n``, otherwise None.
    """
    if len(group) < n:
        return None
    return group.head(n)
############################################################# UI #################################################################
ui.page_opts(fillable=True)
# Page layout: one tab per analysis. Each tab declares its inputs first, then the
# helper/render functions that read them.
# NOTE(review): indentation of this section was reconstructed from a
# whitespace-stripped copy; render functions are placed at tab level, below the
# inputs -- confirm against the deployed layout.
with ui.navset_card_tab(id="tab"):
    with ui.nav_panel("Viral Macrostructure"):
        ui.panel_title("Do viruses have underlying structure?")
        with ui.layout_columns():
            with ui.card():
                ui.input_selectize("virus_selector", "Select your viruses:", virus, multiple=True, selected=None)
            with ui.card():
                ui.input_selectize(
                    "plot_type_macro",
                    "Select your method:",
                    ["Chaos Game Representation", "2D Line", "ColorSquare", "Persistant Homology", "Wens Method"],
                    multiple=False,
                    selected=None,
                )

        @render.plot()
        def plot_macro():
            """Draw the structure plot picked in ``plot_type_macro``.

            Rows of the module-level ``df`` (virus_ds.parquet) whose
            ``Organism_Name`` is selected in ``virus_selector`` are passed to the
            matching ``utils`` helper, and that helper's result is returned.
            Returns None when no virus is selected, when no organism survives the
            per-organism filter, or when the method is not one of the options.
            """
            selected = input.virus_selector()
            if not selected:
                return None
            # Use the frame loaded at module level rather than re-reading
            # virus_ds.parquet on every re-render.
            virus_df = df[df["Organism_Name"].isin(selected)]
            plot_type = input.plot_type_macro()
            if plot_type == "2D Line":
                grouped = virus_df.groupby("Organism_Name")["Sequence"].apply(list)
                return plot_2d_comparison(grouped, grouped.index)
            if plot_type == "Wens Method":
                return wens_method_heatmap(virus_df, virus_df["Organism_Name"].unique())
            if plot_type not in ("ColorSquare", "Chaos Game Representation", "Persistant Homology"):
                return None
            # At most three sequences per organism; organisms with fewer than three
            # are dropped because filter_and_select returns None for them.
            filtered_df = virus_df.groupby("Organism_Name").apply(filter_and_select).reset_index(drop=True)
            if filtered_df.empty:
                return None
            if plot_type == "ColorSquare":
                return plot_color_square(filtered_df["Sequence"], filtered_df["Organism_Name"].unique())
            if plot_type == "Chaos Game Representation":
                # Label with the organisms that survived filtering, in the same
                # (groupby-sorted) order as the sequences being plotted.
                return plot_fcgr(filtered_df["Sequence"], filtered_df["Organism_Name"].unique())
            return plot_persistence_homology(filtered_df["Sequence"], filtered_df["Organism_Name"])

    with ui.nav_panel("Viral Genome Distributions"):
        ui.panel_title("How does sequence distribution vary for a specie?")
        with ui.layout_columns():
            with ui.card():
                ui.input_selectize("virus_selector_1", "Select your viruses:", virus_new, multiple=True, selected=None)
                ui.input_selectize(
                    "plot_type_distro",
                    "Select your distrobution variance view:",
                    ["Variance across bp", "Standard deviation across bp", "Full Genome Distrobution"],
                    multiple=False,
                    selected="Full Genome Distrobution",
                )

        def _per_basepair_frame(organisms, column):
            """Return a long (y, x, organism) frame of ``MASTER_DF[column]`` per organism.

            For each organism, the ``column`` entry of its first matching
            ``MASTER_DF`` row is expanded to one row per position, with ``x`` the
            0-based position and ``y`` the value.
            """
            pieces = []
            for organism in organisms:
                rows = MASTER_DF[MASTER_DF["organism_name"] == organism]
                values = rows[column].values[0].tolist()
                pieces.append(pd.DataFrame({"y": values, "x": range(len(values)), "organism": organism}))
            long_df = pd.concat(pieces)
            # explode() leaves scalar entries unchanged; it is kept so that
            # list-like entries (if the stored arrays are nested) are unpacked.
            return long_df.explode(column=["x", "y"])

        @render.plot()
        def plot_distro_new():
            """Plot the selected distribution view for the organisms in ``virus_selector_1``.

            "Full Genome Distrobution" draws a density histogram of the exploded
            ``charts`` values; the two per-basepair views draw the ``std`` or
            ``var`` arrays against position. Returns None when nothing is
            selected or the view is unrecognised.
            """
            selected = input.virus_selector_1()
            if not selected:
                # pd.concat would raise on an empty list of per-organism frames.
                return None
            plot_type = input.plot_type_distro()
            if plot_type == "Full Genome Distrobution":
                hist_df = MASTER_DF[MASTER_DF["organism_name"].isin(selected)].explode('charts')
                ax = sns.histplot(data=hist_df, x='charts', hue='organism_name', stat='density')
                ax.set_title("Distribution")
                ax.set_xlabel("Distance from mean")
                ax.set_ylabel("Density")
                return ax
            if plot_type == "Standard deviation across bp":
                column, title, ylabel = "std", "Standard Deviation across basepairs", "Std"
            elif plot_type == "Variance across bp":
                column, title, ylabel = "var", "Variance across basepairs", "Variance"
            else:
                return None
            ax = sns.lineplot(data=_per_basepair_frame(selected, column), x='x', y='y', hue='organism')
            ax.set_title(title)
            ax.set_xlabel("Basepair")
            ax.set_ylabel(ylabel)
            return ax

    with ui.nav_panel("Viral Microstructure"):
        ui.panel_title("Kmer Distribution")
        with ui.layout_columns():
            with ui.card():
                ui.input_slider("kmer", "kmer", 0, 10, 4)
                ui.input_slider("top_k", "top:", 0, 1000, 15)
                ui.input_selectize("plot_type", "Select metric:", ["percentage", "count"], multiple=False, selected=None)

        @render.plot()
        def plot_micro():
            """Bar chart of k-mer frequencies read from kmers.csv.

            With the ``kmer`` slider above 0, only rows whose ``k`` equals it are
            kept, truncated to the first ``top_k`` rows; at 0 every row of the file
            is plotted. ``plot_type`` selects the ``count`` column or ``percent``
            scaled by 100.
            """
            kmer_df = pd.read_csv("kmers.csv")
            k = input.kmer()
            top_k = input.top_k()
            plot_type = input.plot_type()
            if k > 0:
                kmer_df = kmer_df[kmer_df["k"] == k].head(top_k)
            fig, ax = plt.subplots()
            if plot_type == "count":
                ax.bar(kmer_df["kmer"], kmer_df["count"])
                ax.set_ylabel("Count")
            elif plot_type == "percentage":
                ax.bar(kmer_df["kmer"], kmer_df["percent"] * 100)
                ax.set_ylabel("Percentage")
            ax.set_title(f"Most common {k}-mers")
            ax.set_xlabel("K-mer")
            # Rotate the labels the categorical axis already has; calling
            # set_xticklabels() without set_xticks() makes matplotlib warn.
            ax.tick_params(axis="x", labelrotation=90)
            return fig

    with ui.nav_panel("Viral Model Training"):
        ui.panel_title("Does context size matter for a nucleotide model?")

        def plot_loss_rates(df, model_type, labels=("32", "64", "128", "256", "512", "1024")):
            """Plot one loss curve per non-``Step`` column of ``df`` on a shared 0..1 axis.

            Each column (NaNs dropped) is linearly interpolated onto 1000 evenly
            spaced points so runs of different lengths line up. ``labels`` names
            the curves in column order (context window sizes by default); columns
            beyond ``len(labels)`` are labelled with their column name.
            """
            x = np.linspace(0, 1, 1000)
            curves = []
            for i, col in enumerate(df.drop(columns=["Step"]).columns):
                y = df[col].dropna().astype("float", errors="ignore").values
                f = interp1d(np.linspace(0, 1, len(y)), y)
                curves.append((labels[i] if i < len(labels) else col, f(x)))
            fig, ax = plt.subplots()
            for label, loss_rate in curves:
                ax.plot(x, loss_rate, label=label)
            if curves:
                ax.legend()
            ax.set_title(f"Loss rates for a {model_type} parameter model across context windows")
            ax.set_xlabel("Training steps")
            ax.set_ylabel("Loss rate")
            return fig

        @render.plot()
        def plot_context_size_scaling():
            """Render the context-window loss curves of the 14M model from 14m.csv."""
            return plot_loss_rates(pd.read_csv("14m.csv"), "14M")

    with ui.nav_panel("Model loss analysis"):
        ui.panel_title("Paper stuff")
        with ui.card():
            ui.input_selectize(
                "param_type",
                "Select Param Type:",
                param_typesss,
                multiple=True,
            )
            ui.input_selectize(
                "model_type",
                "Select Model Type:",
                model_typesss,
                multiple=True,
                selected=["pythia", "denseformer"],
            )
            ui.input_selectize(
                "loss_type",
                "Select Loss Type:",
                loss_typesss,
                multiple=True,
                selected=["compliment", "cross_entropy", "headless"],
            )

        def plot_loss_rates_model(df, param_types, loss_types, model_types):
            """Plot ``loss_interp`` for every (param, loss, model) combination present in ``df``.

            Each combination's values are linearly interpolated onto 1000 evenly
            spaced points in 0..1 and labelled ``"{param}_{loss}_{model}"``.
            ``param_types`` entries are compared as floats against the
            ``param_type`` column. Combinations with fewer than two rows are
            skipped, since interpolation needs at least two points.
            """
            x = np.linspace(0, 1, 1000)
            curves = []
            for param_type in param_types:
                for loss_type in loss_types:
                    for model_type in model_types:
                        y = df[
                            (df["param_type"] == float(param_type))
                            & (df["loss_type"] == loss_type)
                            & (df["model_type"] == model_type)
                        ]["loss_interp"].values
                        if len(y) > 1:
                            f = interp1d(np.linspace(0, 1, len(y)), y)
                            curves.append((f"{param_type}_{loss_type}_{model_type}", f(x)))
            fig, ax = plt.subplots()
            for label, loss_rate in curves:
                ax.plot(x, loss_rate, label=label)
            if curves:
                ax.legend()
            ax.set_xlabel("Training steps")
            ax.set_ylabel("Loss rate")
            return fig

        @render.plot()
        def plot_model_scaling():
            """Render loss curves for the selected param/loss/model types from training_data_5.csv."""
            training_df = pd.read_csv("training_data_5.csv")
            # Keep only rows with epoch_interp > 0.035 (presumably skipping the
            # very start of training -- confirm).
            training_df = training_df[training_df["epoch_interp"] > 0.035]
            return plot_loss_rates_model(
                training_df, input.param_type(), input.loss_type(), input.model_type()
            )

    with ui.nav_panel("Scaling Laws"):
        ui.panel_title("Params & Losses")
        with ui.card():
            ui.input_selectize(
                "model_type_scale",
                "Select Model Type:",
                model_typesss,
                multiple=True,
                selected=["evo", "denseformer"],
            )
            ui.input_selectize(
                "loss_type_scale",
                "Select Loss Type:",
                loss_typesss,
                multiple=True,
                selected=["cross_entropy"],
            )

        def plot_loss_rates_model_scale(df, loss_type, model_types):
            """Plot the lowest loss reached at each model size, one line per model type.

            Only rows whose ``loss_type`` equals the FIRST entry of ``loss_type``
            are used. For each model type, the minimum ``loss_interp`` per distinct
            ``num_params`` is plotted against ``int(num_params)`` in ascending
            order. An empty ``loss_type`` selection yields an empty figure.
            """
            fig, ax = plt.subplots()
            if loss_type:
                # NOTE(review): the input allows several loss types but only the
                # first one is plotted.
                df = df[df["loss_type"] == loss_type[0]]
                for model_type in model_types:
                    model_df = df[df["model_type"] == model_type]
                    # groupby yields num_params keys in ascending order.
                    best_loss = model_df.groupby("num_params")["loss_interp"].min()
                    ax.plot([int(p) for p in best_loss.index], best_loss.to_list(), label=model_type)
            if ax.lines:
                ax.legend()
            ax.set_xlabel("Params")
            ax.set_ylabel("Loss")
            return fig

        @render.plot()
        def plot_big_boy_model():
            """Render best-loss-vs-params curves for the selected model and loss types."""
            return plot_loss_rates_model_scale(
                pd.read_csv("training_data_5.csv"), input.loss_type_scale(), input.model_type_scale()
            )

    with ui.nav_panel("Logits View"):
        ui.panel_title("Logits et all")
        with ui.card():
            ui.input_selectize(
                "model_bigness",
                "Select Model size:",
                ["14", "31", "70", "160", "410"],
                multiple=True,
                selected=["70", "160"],
            )
            ui.input_selectize(
                "loss_loss_loss",
                "Select Loss Type:",
                ["compliment", "cross_entropy", "headless", "2d_representation_GaussianPlusCE", "2d_representation_MSEPlusCE"],
                multiple=True,
                selected=["cross_entropy"],
            )
            ui.input_selectize(
                "logits_select",
                "Select logits:",
                ["1", "2", "3", "4", "5", "6", "7", "8"],
                multiple=True,
                selected=["6"],
            )

        def plot_logits_representation(model_bigness, loss_type, logits):
            """Plot generated vs. expected cumulative logit curves.

            The grid has two rows and one column per entry of ``logits``: row 0
            plots the ``*_y_cumsum`` arrays ("Single"), row 1 the
            ``*_y_full_cumsum`` arrays ("Full"). Every (size, loss) pair whose
            ``virus_pythia_{size}_1024_{loss}_logit_cumsums.npy`` file exists adds
            a Generated and an Expected line to each panel; missing files are
            reported with ``print`` and skipped. Each array is indexed as
            ``[0, :, int(logit) - 1]``. Returns None when no logit is selected.
            """
            if not logits:
                # plt.subplots rejects a grid with zero columns.
                return None
            num_rows = 2
            num_cols = len(logits)
            # squeeze=False keeps ``axs`` 2-D (rows x columns) even for a single
            # logit, so every selected logit is drawn in its own column.
            fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 10), squeeze=False)
            for size in model_bigness:
                for loss in loss_type:
                    file_name = f"virus_pythia_{size}_1024_{loss}_logit_cumsums.npy"
                    if not os.path.exists(file_name):
                        print(f"File not found: {file_name}")
                        continue
                    # The .npy holds a pickled object unwrapped with .item();
                    # allow_pickle=True must only be used on trusted local files.
                    data = np.load(file_name, allow_pickle=True).item()
                    for col, logit in enumerate(logits):
                        logit_index = int(logit) - 1
                        single_ax = axs[0, col]
                        full_ax = axs[1, col]
                        single_ax.plot(data['lm_logits_y_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
                        single_ax.plot(data['shift_labels_y_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
                        single_ax.set_title(f'Logit: {logit}- Single')
                        single_ax.legend()
                        full_ax.plot(data['lm_logits_y_full_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
                        full_ax.plot(data['shift_labels_y_full_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
                        full_ax.set_title(f'Logit: {logit} - Full')
                        full_ax.legend()
            fig.tight_layout()
            return fig

        @render.plot()
        def plot_logits_representation_ui():
            """Render the logits grid for the selected model sizes, loss types and logits."""
            return plot_logits_representation(
                input.model_bigness(), input.loss_loss_loss(), input.logits_select()
            )