# Shiny Express app for exploring viral genome datasets and the training
# results of nucleotide language models.
#
# NOTE: this header is deliberately a comment rather than a module docstring so
# that Shiny Express cannot pick it up as a displayable top-level expression.

import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.interpolate import interp1d
from shiny import render, req
from shiny.express import input, output, ui

from utils import (
    filter_and_select,
    plot_2d_comparison,
    plot_color_square,
    wens_method_heatmap,
    plot_fcgr,
    plot_persistence_homology,
    plot_distrobutions,
    process_data_sub_specie,
)

# Reset matplotlib styling after the seaborn import.
mpl.rcParams.update(mpl.rcParamsDefault)

# ----------------------------------------------------------------------------
# Datasets (each file is loaded once at import and reused by the renderers)
# ----------------------------------------------------------------------------
df = pd.read_parquet("virus_ds.parquet")
virus = {name: name for name in df["Organism_Name"].unique()}

MASTER_DF = pd.read_parquet("distro.parquet")
df_new = MASTER_DF["organism_name"].tolist()
virus_new = {name: name for name in df_new}

TRAINING_DF = pd.read_csv("training_data_5.csv")
loss_typesss = TRAINING_DF["loss_type"].unique().tolist()
model_typesss = TRAINING_DF["model_type"].unique().tolist()
param_typesss = TRAINING_DF["param_type"].unique().tolist()

# Legend labels used by ``plot_loss_rates``, applied to the non-"Step" columns
# in order.
_CONTEXT_LABELS = ["32", "64", "128", "256", "512", "1024"]


# ----------------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------------
def filter_and_select(group):
    """Return the first three rows of *group*, or ``None`` if it has fewer.

    NOTE(review): this definition shadows the ``filter_and_select`` imported
    from ``utils`` above; the local version is the one the app uses.
    """
    if len(group) >= 3:
        return group.head(3)
    return None


def _first_three_per_organism(frame):
    """Apply ``filter_and_select`` to each ``Organism_Name`` group of *frame*."""
    return (
        frame.groupby("Organism_Name")
        .apply(filter_and_select)
        .reset_index(drop=True)
    )


def _plot_per_basepair(stats, organisms, column, title, ylabel):
    """Draw one line per organism from the array stored in ``stats[column]``.

    Args:
        stats: Rows of ``MASTER_DF`` for the selected organisms.
        organisms: Organism names to plot, in legend order.
        column: Name of the column holding one array per organism row.
        title: Axes title.
        ylabel: Y-axis label.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    frames = []
    for organism in organisms:
        rows = stats[stats["organism_name"] == organism]
        if rows.empty:
            continue
        # Only the first matching row is used, as in the original code.
        values = np.asarray(rows[column].values[0]).tolist()
        frames.append(
            pd.DataFrame(
                {"y": values, "x": range(len(values)), "organism": organism}
            )
        )
    # pd.concat raises on an empty list, so stop the render quietly instead.
    req(frames)
    long_form = pd.concat(frames, ignore_index=True)

    _, ax = plt.subplots()
    sns.lineplot(data=long_form, x="x", y="y", hue="organism", ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Basepair")
    ax.set_ylabel(ylabel)
    return ax


def plot_loss_rates(df, model_type):
    """Plot every non-"Step" column of *df* resampled onto 1000 points in [0, 1].

    Args:
        df: Frame with a ``Step`` column plus one loss column per series.
        model_type: Text inserted into the plot title.

    Returns:
        The matplotlib Figure.
    """
    x = np.linspace(0, 1, 1000)
    curves = []
    for i, col in enumerate(df.drop(columns=["Step"]).columns):
        y = df[col].dropna().astype("float", errors="ignore").values
        if len(y) < 2:
            # Fewer than two samples cannot be interpolated across [0, 1].
            continue
        resample = interp1d(np.linspace(0, 1, len(y)), y)
        # Fall back to the column name when there are more series than labels.
        label = _CONTEXT_LABELS[i] if i < len(_CONTEXT_LABELS) else str(col)
        curves.append((label, resample(x)))

    fig, ax = plt.subplots()
    for label, curve in curves:
        ax.plot(x, curve, label=label)
    if curves:
        ax.legend()
    ax.set_title(f"Loss rates for a {model_type} parameter model across context windows")
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Loss rate")
    return fig


def plot_loss_rates_model(df, param_types, loss_types, model_types):
    """Plot ``loss_interp`` for every selected param/loss/model combination.

    Each matching series is resampled onto 1000 points in [0, 1]. Combinations
    with fewer than two rows are skipped.

    Args:
        df: Training frame with ``param_type``, ``loss_type``, ``model_type``
            and ``loss_interp`` columns.
        param_types: Selected parameter types; each is compared as ``float``.
        loss_types: Selected loss types.
        model_types: Selected model types.

    Returns:
        The matplotlib Figure.
    """
    x = np.linspace(0, 1, 1000)
    curves = []
    for param_type in param_types:
        for loss_type in loss_types:
            for model_type in model_types:
                mask = (
                    (df["param_type"] == float(param_type))
                    & (df["loss_type"] == loss_type)
                    & (df["model_type"] == model_type)
                )
                y = df[mask]["loss_interp"].values
                if len(y) < 2:
                    continue
                resample = interp1d(np.linspace(0, 1, len(y)), y)
                curves.append(
                    (f"{param_type}_{loss_type}_{model_type}", resample(x))
                )

    fig, ax = plt.subplots()
    for label, curve in curves:
        ax.plot(x, curve, label=label)
    if curves:
        ax.legend()
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Loss rate")
    return fig


def plot_loss_rates_model_scale(df, loss_type, model_types):
    """Plot the minimum ``loss_interp`` against parameter count per model type.

    Only the first entry of *loss_type* is used.

    Args:
        df: Training frame with ``loss_type``, ``model_type``, ``num_params``
            and ``loss_interp`` columns.
        loss_type: Sequence of selected loss types.
        model_types: Selected model types, one line each.

    Returns:
        The matplotlib Figure, or ``None`` when *loss_type* is empty.
    """
    if not loss_type:
        return None
    of_loss = df[df["loss_type"] == loss_type[0]]

    fig, ax = plt.subplots()
    for model_type in model_types:
        of_model = of_loss[of_loss["model_type"] == model_type]
        # groupby sorts by parameter count, which is the x-axis order we want.
        best = of_model.groupby("num_params")["loss_interp"].min()
        ax.plot([int(p) for p in best.index], best.to_list(), label=model_type)
    if model_types:
        ax.legend()
    ax.set_xlabel("Params")
    ax.set_ylabel("Loss")
    return fig


def plot_logits_representation(model_bigness, loss_type, logits):
    """Plot generated vs expected cumulative-sum curves for each selected logit.

    The grid has two rows and one column per logit. Each size/loss combination
    is read from ``virus_pythia_{size}_1024_{loss}_logit_cumsums.npy``;
    combinations whose file is missing are reported on stdout and skipped.

    Args:
        model_bigness: Selected model sizes (strings).
        loss_type: Selected loss types.
        logits: Selected 1-based logit numbers (strings).

    Returns:
        The matplotlib Figure, or ``None`` when no logit is selected.
    """
    if not logits:
        return None

    # squeeze=False keeps axs two-dimensional even for a single column.
    fig, axs = plt.subplots(2, len(logits), figsize=(20, 10), squeeze=False)

    for size in model_bigness:
        for loss in loss_type:
            file_name = f"virus_pythia_{size}_1024_{loss}_logit_cumsums.npy"
            if not os.path.exists(file_name):
                print(f"File not found: {file_name}")
                continue
            # SECURITY NOTE(review): allow_pickle=True unpickles the file's
            # contents; only load files from a trusted source.
            data = np.load(file_name, allow_pickle=True).item()
            for col, logit in enumerate(logits):
                logit_index = int(logit) - 1
                top, bottom = axs[0, col], axs[1, col]

                top.plot(data['lm_logits_y_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
                top.plot(data['shift_labels_y_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
                top.set_title(f'Logit: {logit}- Single')
                top.legend()

                bottom.plot(data['lm_logits_y_full_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
                bottom.plot(data['shift_labels_y_full_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
                bottom.set_title(f'Logit: {logit} - Full')
                bottom.legend()

    fig.tight_layout()
    return fig


# ----------------------------------------------------------------------------
# UI
# ----------------------------------------------------------------------------
ui.page_opts(fillable=True)

with ui.navset_card_tab(id="tab"):
    with ui.nav_panel("Viral Macrostructure"):
        ui.panel_title("Do viruses have underlying structure?")
        with ui.layout_columns():
            with ui.card():
                ui.input_selectize(
                    "virus_selector",
                    "Select your viruses:",
                    virus,
                    multiple=True,
                    selected=None,
                )
            with ui.card():
                ui.input_selectize(
                    "plot_type_macro",
                    "Select your method:",
                    ["Chaos Game Representation", "2D Line", "ColorSquare", "Persistant Homology", "Wens Method"],
                    multiple=False,
                    selected=None,
                )

        @render.plot()
        def plot_macro():
            """Render the selected macrostructure view for the chosen viruses."""
            selected = input.virus_selector()
            # Nothing to draw until at least one virus is chosen.
            req(selected)
            plot_type = input.plot_type_macro()
            subset = df[df["Organism_Name"].isin(selected)].copy()

            if plot_type == "2D Line":
                grouped = subset.groupby("Organism_Name")["Sequence"].apply(list)
                return plot_2d_comparison(grouped, grouped.index)
            if plot_type == "Wens Method":
                return wens_method_heatmap(subset, subset["Organism_Name"].unique())
            if plot_type not in ("ColorSquare", "Chaos Game Representation", "Persistant Homology"):
                return None

            top3 = _first_three_per_organism(subset)
            # Empty when every selected organism has fewer than three sequences.
            req(not top3.empty)
            if plot_type == "ColorSquare":
                return plot_color_square(top3["Sequence"], top3["Organism_Name"].unique())
            if plot_type == "Chaos Game Representation":
                # Labels come from the same frame as the sequences so that the
                # two stay aligned in order and membership.
                return plot_fcgr(top3["Sequence"], top3["Organism_Name"].unique())
            return plot_persistence_homology(top3["Sequence"], top3["Organism_Name"])

    with ui.nav_panel("Viral Genome Distributions"):
        ui.panel_title("How does sequence distribution vary for a specie?")
        with ui.layout_columns():
            with ui.card():
                ui.input_selectize(
                    "virus_selector_1",
                    "Select your viruses:",
                    virus_new,
                    multiple=True,
                    selected=None,
                )
                ui.input_selectize(
                    "plot_type_distro",
                    "Select your distrobution variance view:",
                    ["Variance across bp", "Standard deviation across bp", "Full Genome Distrobution"],
                    multiple=False,
                    selected="Full Genome Distrobution",
                )

        @render.plot()
        def plot_distro_new():
            """Render the chosen distribution view for the selected organisms."""
            selected = input.virus_selector_1()
            req(selected)
            plot_type = input.plot_type_distro()
            stats = MASTER_DF[MASTER_DF["organism_name"].isin(selected)]

            if plot_type == "Full Genome Distrobution":
                exploded = stats.explode("charts")
                _, ax = plt.subplots()
                sns.histplot(data=exploded, x="charts", hue="organism_name", stat="density", ax=ax)
                ax.set_title("Distribution")
                ax.set_xlabel("Distance from mean")
                ax.set_ylabel("Density")
                return ax
            if plot_type == "Standard deviation across bp":
                return _plot_per_basepair(
                    stats, selected, "std", "Standard Deviation across basepairs", "Std"
                )
            if plot_type == "Variance across bp":
                return _plot_per_basepair(
                    stats, selected, "var", "Variance across basepairs", "Variance"
                )
            return None

    # ------------------------------ 2D Sub-Specie ---------------------------
    with ui.nav_panel("Virus Sub-Specie"):
        ui.panel_title("Can we create sub-specie based on 2d representation? How does their 2D representation vary?")
        with ui.layout_columns():
            with ui.card():
                ui.input_selectize(
                    "virus_selector_2",
                    "Select your viruses:",
                    virus_new,
                    multiple=False,
                    selected='Human mastadenovirus B',
                )
                ui.input_selectize(
                    "plot_type_distro_sub",
                    "Select plot type:",
                    ["2D sub-specie - nominal", "2D sub-specie - detrended"],
                    multiple=False,
                    selected="2D sub-specie - nominal",
                )
                ui.input_slider("variance", "variance from 2d rep", 1, 5, 3)

        @render.plot()
        def plot_sub_specie():
            """Render one 2D line per row, coloured by the row's ``group``."""
            plot_type = input.plot_type_distro_sub()
            if plot_type != "2D sub-specie - nominal":
                # NOTE(review): "2D sub-specie - detrended" is offered in the
                # UI but has no implementation here.
                return None

            data = pd.read_parquet('new_viral_dataset.parquet')
            data = process_data_sub_specie(data, input.virus_selector_2(), input.variance())

            _, ax = plt.subplots()
            # One rainbow colour per distinct group.
            groups = data['group'].unique()
            colors = plt.cm.rainbow(np.linspace(0, 1, len(groups)))
            color_dict = dict(zip(groups, colors))

            for _, row in data.iterrows():
                xs, ys = zip(*row['two_d'])
                ax.plot(xs, ys, c=color_dict[row['group']])

            ax.set_title("Sub-specie")
            return ax

    with ui.nav_panel("Viral Microstructure"):
        ui.panel_title("Kmer Distribution")
        with ui.layout_columns():
            with ui.card():
                ui.input_slider("kmer", "kmer", 0, 10, 4)
                ui.input_slider("top_k", "top:", 0, 1000, 15)
                ui.input_selectize(
                    "plot_type",
                    "Select metric:",
                    ["percentage", "count"],
                    multiple=False,
                    selected=None,
                )

        @render.plot()
        def plot_micro():
            """Render a bar chart of the first ``top_k`` rows for the chosen k."""
            k = input.kmer()
            top_k = input.top_k()
            plot_type = input.plot_type()
            # The slider allows 0, but a 0-mer chart is meaningless.
            req(k > 0)

            kmers = pd.read_csv("kmers.csv")
            kmers = kmers[kmers["k"] == k].head(top_k)

            fig, ax = plt.subplots()
            if plot_type == "count":
                ax.bar(kmers["kmer"], kmers["count"])
                ax.set_ylabel("Count")
            elif plot_type == "percentage":
                ax.bar(kmers["kmer"], kmers["percent"] * 100)
                ax.set_ylabel("Percentage")
            ax.set_title(f"Most common {k}-mers")
            ax.set_xlabel("K-mer")
            # Rotate the existing tick labels instead of overwriting them, so
            # the label count can never disagree with the tick count.
            ax.tick_params(axis="x", labelrotation=90)
            return fig

    with ui.nav_panel("Viral Model Training"):
        ui.panel_title("Does context size matter for a nucleotide model?")

        @render.plot()
        def plot_context_size_scaling():
            """Render loss curves per context window from ``14m.csv``."""
            return plot_loss_rates(pd.read_csv("14m.csv"), "14M")

    with ui.nav_panel("Model loss analysis"):
        ui.panel_title("Paper stuff")
        with ui.card():
            ui.input_selectize(
                "param_type",
                "Select Param Type:",
                param_typesss,
                multiple=True,
            )
            ui.input_selectize(
                "model_type",
                "Select Model Type:",
                model_typesss,
                multiple=True,
                selected=["pythia", "denseformer"],
            )
            ui.input_selectize(
                "loss_type",
                "Select Loss Type:",
                loss_typesss,
                multiple=True,
                selected=["compliment", "cross_entropy", "headless"],
            )

        @render.plot()
        def plot_model_scaling():
            """Render loss curves for the selected param/loss/model combinations."""
            # Rows at or below this epoch fraction are excluded from the plot.
            trimmed = TRAINING_DF[TRAINING_DF["epoch_interp"] > 0.035]
            return plot_loss_rates_model(
                trimmed, input.param_type(), input.loss_type(), input.model_type()
            )

    with ui.nav_panel("Scaling Laws"):
        ui.panel_title("Params & Losses")
        with ui.card():
            ui.input_selectize(
                "model_type_scale",
                "Select Model Type:",
                model_typesss,
                multiple=True,
                selected=["evo", "denseformer"],
            )
            ui.input_selectize(
                "loss_type_scale",
                "Select Loss Type:",
                loss_typesss,
                multiple=True,
                selected=["cross_entropy"],
            )

        @render.plot()
        def plot_big_boy_model():
            """Render best loss versus parameter count for the selected models."""
            return plot_loss_rates_model_scale(
                TRAINING_DF, input.loss_type_scale(), input.model_type_scale()
            )

    with ui.nav_panel("Logits View"):
        ui.panel_title("Logits et all")
        with ui.card():
            ui.input_selectize(
                "model_bigness",
                "Select Model size:",
                ["14", "31", "70", "160", "410"],
                multiple=True,
                selected=["70", "160"],
            )
            ui.input_selectize(
                "loss_loss_loss",
                "Select Loss Type:",
                ["compliment", "cross_entropy", "headless", "2d_representation_GaussianPlusCE", "2d_representation_MSEPlusCE"],
                multiple=True,
                selected=["cross_entropy"],
            )
            ui.input_selectize(
                "logits_select",
                "Select logits:",
                ["1", "2", "3", "4", "5", "6", "7", "8"],
                multiple=True,
                selected=["6"],
            )

        @render.plot()
        def plot_logits_representation_ui():
            """Render the logit cumulative-sum grid for the current selections."""
            return plot_logits_representation(
                input.model_bigness(), input.loss_loss_loss(), input.logits_select()
            )