import pandas as pd
from sklearn.model_selection import train_test_split

from fuson_plm.utils.logging import log_update


def split_clusters_train_test(X, y, benchmark_cluster_reps=None, random_state=1, test_size=0.20):
    """
    Split cluster representatives into train and test sets.

    Args:
        X (list): cluster representative IDs to split (must NOT contain benchmark reps).
        y (list): dummy labels, same length as X (required by train_test_split).
        benchmark_cluster_reps (list, optional): cluster reps for benchmark sequences;
            appended to the test set after splitting. Defaults to no reps.
        random_state (int): seed for a reproducible split.
        test_size (float): fraction of clusters assigned to test.

    Returns:
        dict: {'X_train': [...], 'X_test': [...]}
    """
    # Avoid the shared-mutable-default pitfall, and copy so the caller's
    # list is never mutated by the in-place += below.
    benchmark_cluster_reps = [] if benchmark_cluster_reps is None else list(benchmark_cluster_reps)

    # cluster with random state fixed for reproducible results
    log_update(f"\tPerforming split: all clusters -> train clusters ({round(1-test_size,3)}) and test clusters ({test_size})")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # add benchmark representatives back to X_test
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_test += benchmark_cluster_reps

    # assert no duplicates within the train, test, or val sets (there shouldn't be, if the input data was clean)
    assert len(X_train) == len(set(X_train))
    assert len(X_test) == len(set(X_test))

    return {
        'X_train': X_train,
        'X_test': X_test
    }


def split_clusters_train_val_test(X, y, benchmark_cluster_reps=None, random_state_1=1, random_state_2=1, test_size_1=0.20, test_size_2=0.50):
    """
    Split cluster representatives into train, val, and test sets via two
    consecutive splits: (all -> train + other), then (other -> val + test).

    Args:
        X (list): cluster representative IDs to split (must NOT contain benchmark reps).
        y (list): dummy labels, same length as X (required by train_test_split).
        benchmark_cluster_reps (list, optional): cluster reps for benchmark sequences;
            appended to the test set after splitting. Defaults to no reps.
        random_state_1 (int): seed for the first split.
        random_state_2 (int): seed for the second split.
        test_size_1 (float): fraction held out of train in the first split.
        test_size_2 (float): fraction of that holdout assigned to test in the second split.

    Returns:
        dict: {'X_train': [...], 'X_val': [...], 'X_test': [...]}
    """
    # Avoid the shared-mutable-default pitfall, and copy so the caller's
    # list is never mutated by the in-place += below.
    benchmark_cluster_reps = [] if benchmark_cluster_reps is None else list(benchmark_cluster_reps)

    # cluster with random state fixed for reproducible results
    log_update(f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size_1, random_state=random_state_1)

    log_update(f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})")
    X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=test_size_2, random_state=random_state_2)

    # add benchmark representatives back to X_test
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_test += benchmark_cluster_reps

    # assert no duplicates within the train, test, or val sets (there shouldn't be, if the input data was clean)
    assert len(X_train) == len(set(X_train))
    assert len(X_val) == len(set(X_val))
    assert len(X_test) == len(set(X_test))

    return {
        'X_train': X_train,
        'X_val': X_val,
        'X_test': X_test
    }


def split_clusters(cluster_representatives: list, val_set=True, benchmark_cluster_reps=None, random_state_1=1, random_state_2=1, test_size_1=0.20, test_size_2=0.50):
    """
    Cluster-splitting method amenable to either train-test or train-val-test.
    For train-val-test, there are two splits.

    Args:
        cluster_representatives (list): all cluster representative IDs.
        val_set (bool): if True, produce train/val/test; otherwise train/test.
        benchmark_cluster_reps (list, optional): reps of clusters holding benchmark
            sequences; excluded from the random split and forced into test.
        random_state_1 (int): seed for the (first) split.
        random_state_2 (int): seed for the second split (val/test only).
        test_size_1 (float): holdout fraction of the (first) split.
        test_size_2 (float): test fraction of the second split (val/test only).

    Returns:
        dict: split dict as returned by the chosen helper.
    """
    benchmark_cluster_reps = [] if benchmark_cluster_reps is None else list(benchmark_cluster_reps)

    log_update("\nPerforming splits...")
    # Approx. 80/10/10 split
    # X, for splitting, does NOT include benchmark reps. We'll add these clusters to test.
    X = [x for x in cluster_representatives if not(x in benchmark_cluster_reps)]
    y = [0]*len(X)   # y is a dummy array here; there are no values.

    if val_set:
        split_dict = split_clusters_train_val_test(X, y,
                                                   benchmark_cluster_reps=benchmark_cluster_reps,
                                                   random_state_1=random_state_1, random_state_2=random_state_2,
                                                   test_size_1=test_size_1, test_size_2=test_size_2)
    else:
        split_dict = split_clusters_train_test(X, y,
                                               benchmark_cluster_reps=benchmark_cluster_reps,
                                               random_state=random_state_1, test_size=test_size_1)
    return split_dict


def check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=None):
    """
    Sanity-check a cluster split: log size breakdowns, verify no sequence
    overlap between sets, and (optionally) verify benchmark sequences appear
    only in the test set.

    Args:
        train_clusters (pd.DataFrame): train-set cluster membership with columns
            'representative seq_id', 'member seq_id', 'member seq'.
        val_clusters (pd.DataFrame): same schema (optional - can pass None if there is no validation set)
        test_clusters (pd.DataFrame): same schema (optional - can pass None)
        benchmark_sequences (iterable, optional): sequences that must be present in
            test and absent from train/val.

    Raises:
        AssertionError: if counts are inconsistent, sets share sequences, or
            benchmark sequences leak outside the test set.
    """
    # Make grouped versions of these DataFrames for size analysis
    train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
    if val_clusters is not None:
        val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
    if test_clusters is not None:
        test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})

    # Calculate stats - clusters
    n_train_clusters = len(train_clustersgb)
    n_val_clusters, n_test_clusters = 0, 0
    if val_clusters is not None:
        n_val_clusters = len(val_clustersgb)
    if test_clusters is not None:
        n_test_clusters = len(test_clustersgb)
    n_clusters = n_train_clusters + n_val_clusters + n_test_clusters

    # Each representative should appear exactly once in its grouped frame.
    assert len(train_clusters['representative seq_id'].unique()) == len(train_clustersgb)
    if val_clusters is not None:
        assert len(val_clusters['representative seq_id'].unique()) == len(val_clustersgb)
    if test_clusters is not None:
        assert len(test_clusters['representative seq_id'].unique()) == len(test_clustersgb)

    train_cluster_pcnt = round(100*n_train_clusters/n_clusters, 2)
    if val_clusters is not None:
        val_cluster_pcnt = round(100*n_val_clusters/n_clusters, 2)
    if test_clusters is not None:
        test_cluster_pcnt = round(100*n_test_clusters/n_clusters, 2)

    # Calculate stats - proteins
    n_train_proteins = len(train_clusters)
    n_val_proteins, n_test_proteins = 0, 0
    if val_clusters is not None:
        n_val_proteins = len(val_clusters)
    if test_clusters is not None:
        n_test_proteins = len(test_clusters)
    n_proteins = n_train_proteins + n_val_proteins + n_test_proteins

    # Row counts must equal the summed per-cluster member counts.
    assert len(train_clusters) == sum(train_clustersgb['member count'])
    if val_clusters is not None:
        assert len(val_clusters) == sum(val_clustersgb['member count'])
    if test_clusters is not None:
        assert len(test_clusters) == sum(test_clustersgb['member count'])

    train_protein_pcnt = round(100*n_train_proteins/n_proteins, 2)
    if val_clusters is not None:
        val_protein_pcnt = round(100*n_val_proteins/n_proteins, 2)
    if test_clusters is not None:
        test_protein_pcnt = round(100*n_test_proteins/n_proteins, 2)

    # Print results
    log_update("\nCluster breakdown...")
    log_update(f"Total clusters = {n_clusters}, total proteins = {n_proteins}")
    log_update(f"\tTrain set:\n\t\tTotal Clusters = {len(train_clustersgb)} ({train_cluster_pcnt}%)\n\t\tTotal Proteins = {len(train_clusters)} ({train_protein_pcnt}%)")
    if val_clusters is not None:
        log_update(f"\tVal set:\n\t\tTotal Clusters = {len(val_clustersgb)} ({val_cluster_pcnt}%)\n\t\tTotal Proteins = {len(val_clusters)} ({val_protein_pcnt}%)")
    if test_clusters is not None:
        log_update(f"\tTest set:\n\t\tTotal Clusters = {len(test_clustersgb)} ({test_cluster_pcnt}%)\n\t\tTotal Proteins = {len(test_clusters)} ({test_protein_pcnt}%)")

    # Check for overlap in both sequence ID and sequence actual
    train_protein_ids = set(train_clusters['member seq_id'])
    train_protein_seqs = set(train_clusters['member seq'])
    if val_clusters is not None:
        val_protein_ids = set(val_clusters['member seq_id'])
        val_protein_seqs = set(val_clusters['member seq'])
    if test_clusters is not None:
        test_protein_ids = set(test_clusters['member seq_id'])
        test_protein_seqs = set(test_clusters['member seq'])

    # Print results
    log_update("\nChecking for overlap...")
    if (val_clusters is not None) and (test_clusters is not None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}\n\t\tVal-Test Overlap: {len(val_protein_ids.intersection(test_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}\n\t\tVal-Test Overlap: {len(val_protein_seqs.intersection(test_protein_seqs))}")
    if (val_clusters is not None) and (test_clusters is None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}")
    if (val_clusters is None) and (test_clusters is not None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}")

    # Assert no sequence overlap
    if val_clusters is not None:
        assert len(train_protein_seqs.intersection(val_protein_seqs)) == 0
    if test_clusters is not None:
        assert len(train_protein_seqs.intersection(test_protein_seqs)) == 0
    if (val_clusters is not None) and (test_clusters is not None):
        assert len(val_protein_seqs.intersection(test_protein_seqs)) == 0

    # Finally, check that there are only benchmark sequences in test - if there are benchmark sequences
    if benchmark_sequences is not None:
        bench_in_train = len(train_clusters.loc[train_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        bench_in_val, bench_in_test = 0, 0
        if val_clusters is not None:
            bench_in_val = len(val_clusters.loc[val_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        if test_clusters is not None:
            bench_in_test = len(test_clusters.loc[test_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())

        # Assert this
        log_update("\nChecking for benchmark sequence presence in test, and absence from train and val...")
        log_update(f"\tTotal benchmark sequences: {len(benchmark_sequences)}")
        log_update(f"\tBenchmark sequences in train: {bench_in_train}")
        if val_clusters is not None:
            log_update(f"\tBenchmark sequences in val: {bench_in_val}")
        if test_clusters is not None:
            log_update(f"\tBenchmark sequences in test: {bench_in_test}")
        assert bench_in_train == bench_in_val == 0
        # NOTE(review): this assumes a test set exists whenever benchmark
        # sequences are provided; with test_clusters=None it fails unless
        # benchmark_sequences is empty — confirm that is the intended contract.
        assert bench_in_test == len(benchmark_sequences)


def check_class_distributions(train_df, val_df, test_df, class_col='class'):
    """
    Checks class distributions within train, val, and test sets.
    Expects input dataframes to have 'sequence' column and 'class' column

    Args:
        train_df (pd.DataFrame): train set with a `class_col` column.
        val_df (pd.DataFrame): val set (optional - can pass None).
        test_df (pd.DataFrame): test set with a `class_col` column.
        class_col (str): name of the class-label column.
    """
    # rename_axis + reset_index(name=...) yields columns [class_col, '<set>_count']
    # on every pandas version (the old DataFrame(...).reset_index().rename(...)
    # chain silently mislabels columns on pandas >= 2.0).
    train_vc = train_df[class_col].value_counts().rename_axis(class_col).reset_index(name='train_count')
    train_vc['train_pct'] = (train_vc['train_count'] / train_vc['train_count'].sum()).round(3)*100
    if val_df is not None:
        val_vc = val_df[class_col].value_counts().rename_axis(class_col).reset_index(name='val_count')
        val_vc['val_pct'] = (val_vc['val_count'] / val_vc['val_count'].sum()).round(3)*100
    test_vc = test_df[class_col].value_counts().rename_axis(class_col).reset_index(name='test_count')
    test_vc['test_pct'] = (test_vc['test_count'] / test_vc['test_count'].sum()).round(3)*100

    # concatenate so I can see them next to each other
    if val_df is not None:
        compare = pd.concat([train_vc, val_vc, test_vc], axis=1)
        compare['train-val diff'] = (compare['train_pct'] - compare['val_pct']).apply(lambda x: abs(x))
        compare['val-test diff'] = (compare['val_pct'] - compare['test_pct']).apply(lambda x: abs(x))
    else:
        compare = pd.concat([train_vc, test_vc], axis=1)
        compare['train-test diff'] = (compare['train_pct'] - compare['test_pct']).apply(lambda x: abs(x))

    compare_str = compare.to_string(index=False)
    compare_str = "\t" + compare_str.replace("\n", "\n\t")
    log_update(f"\nClass distribution:\n{compare_str}")