svincoff committed on
Commit ffaff91 · 1 Parent(s): 1e6a1f0

adding utility files used throughout FusOn-pLM training and benchmarking

fuson_plm/utils/README.md ADDED
@@ -0,0 +1,3 @@
+ This folder contains common functions for data cleaning, clustering, train-test splitting, visualization, embedding, and logging.
+
+ The functions in these scripts are used throughout the repository for training the main model, FusOn-pLM, as well as for benchmarks.
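
For orientation, a minimal sketch of how these modules are pulled in elsewhere in the repository (the log path is illustrative, not part of this commit):

from fuson_plm.utils.logging import open_logfile, log_update
from fuson_plm.utils.constants import VALID_AAS

# Redirect stdout to a log file while working, as the training and benchmarking scripts do
with open_logfile("example.log"):
    log_update(f"Loaded {len(VALID_AAS)} valid amino acid characters")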
fuson_plm/utils/__init__.py ADDED
File without changes
fuson_plm/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes)
fuson_plm/utils/__pycache__/clustering.cpython-310.pyc ADDED
Binary file (4.87 kB)
fuson_plm/utils/__pycache__/constants.cpython-310.pyc ADDED
Binary file (2.48 kB)
fuson_plm/utils/__pycache__/data_cleaning.cpython-310.pyc ADDED
Binary file (4.45 kB)
fuson_plm/utils/__pycache__/embedding.cpython-310.pyc ADDED
Binary file (5.13 kB)
fuson_plm/utils/__pycache__/logging.cpython-310.pyc ADDED
Binary file (3.31 kB)
fuson_plm/utils/__pycache__/splitting.cpython-310.pyc ADDED
Binary file (6.95 kB)
fuson_plm/utils/__pycache__/visualizing.cpython-310.pyc ADDED
Binary file (13.4 kB)
fuson_plm/utils/clustering.py ADDED
@@ -0,0 +1,139 @@
+ import pandas as pd
+ import os
+ import subprocess
+ import sys
+ from Bio import SeqIO
+ import shutil
+ from fuson_plm.utils.logging import open_logfile, log_update
+ 
+ def ensure_mmseqs_in_path(mmseqs_dir):
+     """
+     Checks if MMseqs2 is in the PATH. If it's not, adds it. MMseqs2 will not run if this is not done correctly.
+ 
+     Args:
+         mmseqs_dir (str): Directory containing MMseqs2 binaries
+     """
+     mmseqs_bin = os.path.join(mmseqs_dir, 'mmseqs')
+ 
+     # Check if mmseqs is already in PATH
+     if shutil.which('mmseqs') is None:
+         # Export the MMseqs2 directory to PATH
+         os.environ['PATH'] = f"{mmseqs_dir}:{os.environ['PATH']}"
+         log_update(f"\tAdded {mmseqs_dir} to PATH")
+ 
+ def process_fasta(fasta_path):
+     """Reads a fasta file into a dictionary mapping sequence ID to sequence."""
+     fasta_sequences = SeqIO.parse(open(fasta_path), 'fasta')
+     d = {}
+     for fasta in fasta_sequences:
+         id, sequence = fasta.id, str(fasta.seq)
+         d[id] = sequence
+     return d
+ 
+ def analyze_clustering_result(input_fasta: str, tsv_path: str):
+     """
+     Joins MMseqs2 cluster assignments back to their sequences.
+ 
+     Args:
+         input_fasta (str): path to input fasta file
+         tsv_path (str): path to the cluster .tsv file produced by MMseqs2
+     """
+     # Process input fasta
+     input_d = process_fasta(input_fasta)
+ 
+     # Process clusters.tsv
+     clusters = pd.read_csv(f'{tsv_path}', sep='\t', header=None)
+     clusters = clusters.rename(columns={
+         0: 'representative seq_id',
+         1: 'member seq_id'
+     })
+ 
+     clusters['representative seq'] = clusters['representative seq_id'].apply(lambda seq_id: input_d[seq_id])
+     clusters['member seq'] = clusters['member seq_id'].apply(lambda seq_id: input_d[seq_id])
+ 
+     # Sort them so that splitting results are reproducible
+     clusters = clusters.sort_values(by=['representative seq_id', 'member seq_id'], ascending=True).reset_index(drop=True)
+ 
+     return clusters
+ 
+ def make_fasta(sequences: dict, fasta_path: str):
+     """
+     Makes a fasta file from sequences, where the key is the header and the value is the sequence.
+ 
+     Args:
+         sequences (dict): A dictionary where the key is the header and the value is the sequence.
+         fasta_path (str): The path where the fasta file should be written.
+ 
+     Returns:
+         str: The path to the fasta file.
+     """
+     with open(fasta_path, 'w') as f:
+         for header, sequence in sequences.items():
+             f.write(f'>{header}\n{sequence}\n')
+ 
+     return fasta_path
+ 
+ def run_mmseqs_clustering(input_fasta, output_dir, min_seq_id=0.3, c=0.8, cov_mode=0, cluster_mode=0, path_to_mmseqs='fuson_plm/mmseqs'):
+     """
+     Runs MMseqs2 clustering using the easy-cluster module.
+ 
+     Args:
+         input_fasta (str): path to input fasta file, formatted >header\nsequence\n>header\nsequence...
+         output_dir (str): path to output dir for clustering results
+         min_seq_id (float): number in [0,1] representing --min-seq-id in the cluster command
+         c (float): number in [0,1] representing -c in the cluster command
+         cov_mode (int): number 0, 1, 2, or 3 representing --cov-mode in the cluster command
+         cluster_mode (int): number 0, 1, or 2 representing --cluster-mode in the cluster command
+     """
+     # Get mmseqs dir
+     log_update("\nRunning MMseqs clustering...")
+     mmseqs_dir = os.path.join(path_to_mmseqs[0:path_to_mmseqs.index('/mmseqs')], 'mmseqs/bin')
+ 
+     # Ensure MMseqs2 is in the PATH
+     ensure_mmseqs_in_path(mmseqs_dir)
+ 
+     # Define paths for MMseqs2
+     mmseqs_bin = "mmseqs"  # Ensure this is in your PATH or provide the full path to the mmseqs binary
+ 
+     # Create the output directory
+     os.makedirs(output_dir, exist_ok=True)
+ 
+     # Run MMseqs2 easy-cluster
+     cmd_easy_cluster = [
+         mmseqs_bin, "easy-cluster", input_fasta, os.path.join(output_dir, "mmseqs"), output_dir,
+         "--min-seq-id", str(min_seq_id),
+         "-c", str(c),
+         "--cov-mode", str(cov_mode),
+         "--cluster-mode", str(cluster_mode),
+         "--dbtype", "1"
+     ]
+ 
+     # Write the command to a log file
+     log_update("\n\tCommand entered to MMseqs2:")
+     log_update("\t" + " ".join(cmd_easy_cluster) + "\n")
+ 
+     subprocess.run(cmd_easy_cluster, check=True)
+ 
+     log_update(f"Clustering completed. Results are in {output_dir}")
+ 
+ def cluster_summary(clusters: pd.DataFrame):
+     """
+     Summarizes how many clusters were formed, how big they are, etc.
+     """
+     grouped_clusters = clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
+     assert len(grouped_clusters) == len(clusters['representative seq_id'].unique())  # make sure number of cluster reps = # grouped clusters
+ 
+     total_seqs = sum(grouped_clusters['member count'])
+     log_update(f"Created {len(grouped_clusters)} clusters of {total_seqs} sequences")
+     log_update(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']==1])} clusters of size 1")
+     csize1_seqs = sum(grouped_clusters[grouped_clusters['member count']==1]['member count'])
+     log_update(f"\t\tsequences: {csize1_seqs} ({round(100*csize1_seqs/total_seqs, 2)}%)")
+ 
+     log_update(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']>1])} clusters of size > 1")
+     csizeg1_seqs = sum(grouped_clusters[grouped_clusters['member count']>1]['member count'])
+     log_update(f"\t\tsequences: {csizeg1_seqs} ({round(100*csizeg1_seqs/total_seqs, 2)}%)")
+     log_update(f"\tlargest cluster: {max(grouped_clusters['member count'])}")
+ 
+     log_update("\nCluster size breakdown below...")
+ 
+     value_counts = grouped_clusters['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
+     log_update(value_counts.sort_values(by='cluster size (n_members)', ascending=True).to_string(index=False))
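
A minimal end-to-end sketch of how these helpers compose (the sequences, paths, and threshold below are illustrative; run_mmseqs_clustering assumes the repository's bundled MMseqs2 binaries under fuson_plm/mmseqs/bin and is run from the repository root). MMseqs2's easy-cluster writes its cluster table to <prefix>_cluster.tsv, hence the mmseqs_cluster.tsv path:

from fuson_plm.utils.clustering import (make_fasta, run_mmseqs_clustering,
                                        analyze_clustering_result, cluster_summary)

# Toy input: two near-identical sequences and one unrelated sequence
seqs = {"seq1": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "seq2": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVA",
        "seq3": "GDVEKGKKIFIMKCSQCHTVEKGGKHKTGPNLH"}
fasta_path = make_fasta(seqs, "toy_input.fasta")

run_mmseqs_clustering(fasta_path, "toy_clusters", min_seq_id=0.3)
clusters = analyze_clustering_result(fasta_path, "toy_clusters/mmseqs_cluster.tsv")
cluster_summary(clusters)  # expect seq1/seq2 clustered together, seq3 alone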
fuson_plm/utils/constants.py ADDED
@@ -0,0 +1,108 @@
+ # Data Cleaning Parameters
+ # TCGA abbreviations for cancer. From https://gdc.cancer.gov/resources-tcga-users/tcga-code-tables/tcga-study-abbreviations
+ TCGA_CODES = {
+     'LAML': 'Acute Myeloid Leukemia',
+     'ACC': 'Adrenocortical carcinoma',
+     'BLCA': 'Bladder Urothelial Carcinoma',
+     'LGG': 'Brain Lower Grade Glioma',
+     'BRCA': 'Breast invasive carcinoma',
+     'CESC': 'Cervical squamous cell carcinoma and endocervical adenocarcinoma',
+     'CHOL': 'Cholangiocarcinoma',
+     'LCML': 'Chronic Myelogenous Leukemia',
+     'COAD': 'Colon adenocarcinoma',
+     'CNTL': 'Controls',
+     'ESCA': 'Esophageal carcinoma',
+     'FPPP': 'FFPE Pilot Phase II',
+     'GBM': 'Glioblastoma multiforme',
+     'HNSC': 'Head and Neck squamous cell carcinoma',
+     'KICH': 'Kidney Chromophobe',
+     'KIRC': 'Kidney renal clear cell carcinoma',
+     'KIRP': 'Kidney renal papillary cell carcinoma',
+     'LIHC': 'Liver hepatocellular carcinoma',
+     'LUAD': 'Lung adenocarcinoma',
+     'LUSC': 'Lung squamous cell carcinoma',
+     'DLBC': 'Lymphoid Neoplasm Diffuse Large B-cell Lymphoma',
+     'MESO': 'Mesothelioma',
+     'MISC': 'Miscellaneous',
+     'OV': 'Ovarian serous cystadenocarcinoma',
+     'PAAD': 'Pancreatic adenocarcinoma',
+     'PCPG': 'Pheochromocytoma and Paraganglioma',
+     'PRAD': 'Prostate adenocarcinoma',
+     'READ': 'Rectum adenocarcinoma',
+     'SARC': 'Sarcoma',
+     'SKCM': 'Skin Cutaneous Melanoma',
+     'STAD': 'Stomach adenocarcinoma',
+     'TGCT': 'Testicular Germ Cell Tumors',
+     'THYM': 'Thymoma',
+     'THCA': 'Thyroid carcinoma',
+     'UCS': 'Uterine Carcinosarcoma',
+     'UCEC': 'Uterine Corpus Endometrial Carcinoma',
+     'UVM': 'Uveal Melanoma'
+ }
+ 
+ FODB_CODES = {
+     'ACC': 'Adenoid cystic carcinoma',
+     'ALL': 'Acute Lymphoid Leukemia',
+     'AML': 'Acute Myeloid Leukemia',
+     'BALL': 'B-cell acute lymphoblastic leukemia',
+     'BLCA': 'Bladder Urothelial Carcinoma',
+     'BRCA': 'Breast invasive carcinoma',
+     'CESC': 'Cervical squamous cell carcinoma and endocervical adenocarcinoma',
+     'CHOL': 'Cholangiocarcinoma',
+     'EPD': 'Ependymoma',
+     'HGG': 'High-grade glioma',
+     'HNSC': 'Head and Neck squamous cell carcinoma',
+     'KIRC': 'Kidney renal clear cell carcinoma',
+     'LGG': 'Low-grade glioma',
+     'LUAD': 'Lung adenocarcinoma',
+     'LUSC': 'Lung squamous cell carcinoma',
+     'MEL': 'Melanoma',
+     'MESO': 'Mesothelioma',
+     'NBL': 'Neuroblastoma',
+     'OS': 'Osteosarcoma',
+     'OV': 'Ovarian serous cystadenocarcinoma',
+     'PCPG': 'Pheochromocytoma and Paraganglioma',
+     'PRAD': 'Prostate adenocarcinoma',
+     'READ': 'Rectum adenocarcinoma',
+     'RHB': 'Rhabdomyosarcoma',
+     'SARC': 'Sarcoma',
+     'STAD': 'Stomach adenocarcinoma',
+     'TALL': 'T-cell acute lymphoblastic leukemia',
+     'THYM': 'Thymoma',
+     'UCEC': 'Uterine Corpus Endometrial Carcinoma',
+     'UCS': 'Uterine Carcinosarcoma',
+     'UVM': 'Uveal Melanoma',
+     'WLM': 'Wilms tumor'
+ }
+ 
+ # The 20 canonical amino acid one-letter codes
+ VALID_AAS = {'A', 'R', 'N', 'D', 'C',
+              'E', 'Q', 'G', 'H', 'I',
+              'L', 'K', 'M', 'F', 'P',
+              'S', 'T', 'W', 'Y', 'V'}
+ 
+ # Characters that suggest a cell holds a delimited list rather than a single value
+ DELIMITERS = {',', ';', '|', '\t', ' ', ':', '-', '/', '\\', '\n'}
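
A quick illustrative check of how these constants are consumed downstream (values printed in the comments follow directly from the tables above):

from fuson_plm.utils.constants import TCGA_CODES, FODB_CODES, DELIMITERS

print(TCGA_CODES['LAML'])  # 'Acute Myeloid Leukemia'
print(FODB_CODES['AML'])   # 'Acute Myeloid Leukemia' - same disease, different code table
print(',' in DELIMITERS)   # True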
fuson_plm/utils/data_cleaning.py ADDED
@@ -0,0 +1,126 @@
+ import pandas as pd
+ import numpy as np
+ from fuson_plm.utils.logging import log_update
+ 
+ def clean_rows_and_cols(df: pd.DataFrame) -> pd.DataFrame:
+     """
+     Deletes empty rows and columns
+ 
+     Args:
+         df (pd.DataFrame): input DataFrame to be cleaned
+ 
+     Returns:
+         pd.DataFrame: cleaned DataFrame
+     """
+     # Delete rows with no data
+     log_update(f"\trow cleaning...\n\t\toriginal # rows: {len(df)}")
+     log_update("\t\tdropping rows where all entries are np.nan...")
+     df = df.dropna(how='all')
+     log_update(f"\t\tnew # rows: {len(df)}")
+ 
+     # Delete columns with no data
+     log_update(f"\tcolumn cleaning...\n\t\toriginal # columns: {len(df.columns)}")
+     log_update("\t\tdropping columns where all entries are np.nan...")
+     df = df.dropna(axis=1, how='all')
+     log_update(f"\t\tnew # columns: {len(df.columns)}")
+     log_update(f"\t\tcolumn names: {','.join(list(df.columns))}")
+ 
+     return df
+ 
+ def check_columns_for_listlike(df: pd.DataFrame, cols_of_interest: list, delimiters: set):
+     """
+     Checks if a column contains any listlike items
+ 
+     Args:
+         df (pd.DataFrame): DataFrame to be investigated
+         cols_of_interest (list): columns in df to be investigated for list-containing potential
+         delimiters (set): set of potential delimiting strings to search for. A column with any of these strings is considered listlike.
+ 
+     Returns:
+         dict: dictionary containing a set {} of all delimiters found in each column
+             e.g., { 'col1': {',',';'},
+                     'col2': {'|'} }
+     """
+     # return the delimiters/listlike things found for each column
+     return_dict = {}
+ 
+     log_update("\tchecking if any of our columns of interest look listlike (contain list objects or delimiters)...")
+     for col in cols_of_interest:
+         unique_col = list(df[col].value_counts().index)
+         listlike = any([check_item_for_listlike(x, delimiters) for x in unique_col])
+ 
+         if listlike:
+             # collect the delimiter sets found per entry; non-string entries return a dtype rather than a set, so only union the sets
+             found_delims = df[col].apply(lambda x: check_item_for_listlike(x, delimiters)).to_list()
+             unique_found_delims = set()
+             for x in found_delims:
+                 if isinstance(x, set):
+                     unique_found_delims = unique_found_delims.union(x)
+ 
+             return_dict[col] = unique_found_delims
+         else:
+             return_dict[col] = False
+ 
+         # display the return dict
+         log_update(f"\t\tcolumn name: {col}\tlistlike: {return_dict[col]}")
+ 
+     return return_dict
+ 
+ def check_item_for_listlike(x, delimiters: set):
+     """
+     Checks if an item looks like it contains a list of items, rather than an individual item, based on string delimiters.
+ 
+     Args:
+         x: the item to check. Any dtype.
+         delimiters: a set of delimiters to check for. e.g., {',', ';', '|', '\t', ' ', ':', '-', '/', '\\', '\n'}
+ 
+     Returns:
+         If x is a string: the set (may be empty) of delimiters contained in the string
+         If x is not a string: the dtype of x
+     """
+     if isinstance(x, str):
+         return find_delimiters(x, delimiters)
+     else:
+         if x is None:
+             # if it's None, it's not listlike, it's just empty. return an empty set because it has no delimiters.
+             return set()
+         if type(x) == float:
+             # if it's nan, it's not listlike, it's just empty. return an empty set because it has no delimiters.
+             if np.isnan(x):
+                 return set()
+         return type(x)
+ 
+ def find_delimiters(seq: str, delimiters: set) -> set:
+     """
+     Find and return a set of delimiters in a sequence. Helper method for check_item_for_listlike.
+ 
+     Args:
+         seq (str): The sequence you wish to search for delimiters.
+         delimiters (set): a set of delimiters to check for. e.g., {',', ';', '|', '\t', ' ', ':', '-', '/', '\\', '\n'}
+ 
+     Returns:
+         set: A set of characters in the sequence that are also in the set of delimiters (empty if none are found).
+     """
+     unique_chars = set(seq)  # set of all characters in the sequence; unique_chars = {A, C} for protein="AAACCC"
+     overlap = delimiters.intersection(unique_chars)
+ 
+     return overlap  # empty set if no delimiters were found
+ 
+ def find_invalid_chars(seq: str, valid_chars: set) -> set:
+     """
+     Find and return a set of invalid characters in a sequence.
+ 
+     Args:
+         seq (str): The sequence you wish to search for invalid characters.
+         valid_chars (set): A set of valid characters.
+ 
+     Returns:
+         set: A set of characters in the sequence that are not in the set of valid characters.
+     """
+     unique_chars = set(seq)  # set of all characters in the sequence; unique_chars = {A, C} for protein="AAACCC"
+ 
+     if unique_chars.issubset(valid_chars):  # e.g. unique_chars = {A,C}, and {A,C} is a subset of valid_chars
+         return set()
+     else:  # e.g. unique_chars = {A,X}. {A,X} is not a subset of valid_chars because X is not in valid_chars
+         return unique_chars.difference(valid_chars)  # e.g. {A,X} - valid_chars = {X}
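
A small illustrative run on a toy frame (the column names and values are made up; DELIMITERS and VALID_AAS come from constants.py above):

import pandas as pd
from fuson_plm.utils.constants import DELIMITERS, VALID_AAS
from fuson_plm.utils.data_cleaning import check_columns_for_listlike, find_invalid_chars

# One clean sequence column and one delimiter-containing ID column
df = pd.DataFrame({'seq': ['MKTA', 'GSHM'], 'ids': ['P1,P2', 'P3']})
check_columns_for_listlike(df, ['seq', 'ids'], DELIMITERS)  # flags 'ids' with {','}, 'seq' as False
print(find_invalid_chars('MKTAX', VALID_AAS))               # {'X'}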
fuson_plm/utils/embedding.py ADDED
@@ -0,0 +1,193 @@
+ import pickle
+ import torch
+ from transformers import EsmModel, AutoTokenizer
+ from transformers import T5Tokenizer, T5EncoderModel
+ import logging
+ from fuson_plm.utils.logging import log_update
+ 
+ 
+ def redump_pickle_dictionary(pickle_path):
+     """
+     Loads a pickle dictionary and redumps it in its location. This allows a clean reset for a pickle built with 'ab+'.
+     """
+     entries = {}
+     # Load one by one
+     with open(pickle_path, 'rb') as f:
+         while True:
+             try:
+                 entry = pickle.load(f)
+                 entries.update(entry)
+             except EOFError:
+                 break  # End of file reached
+             except Exception as e:
+                 print(f"An error occurred: {e}")
+                 break
+     # Redump
+     with open(pickle_path, 'wb') as f:
+         pickle.dump(entries, f)
+ 
+ def load_esm2_type(esm_type, device=None):
+     """
+     Loads ESM-2 of a specified version (e.g. esm2_t33_650M_UR50D)
+     """
+     # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
+     logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
+ 
+     if device is None:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(f"Using device: {device}")
+ 
+     model = EsmModel.from_pretrained(f"facebook/{esm_type}")
+     tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")
+ 
+     model.to(device)
+     model.eval()  # disables dropout for deterministic results
+ 
+     return model, tokenizer, device
+ 
+ def load_prott5():
+     # Initialize tokenizer and model
+     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+     tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
+     model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
+     if device == torch.device('cpu'):
+         model.to(torch.float32)
+     model.to(device)
+     return model, tokenizer, device
+ 
+ def get_esm_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
+     """
+     Compute ESM embeddings.
+ 
+     Args:
+         model: ESM-2 model to embed with
+         tokenizer: matching ESM-2 tokenizer
+         sequences: list of amino acid sequences to embed
+         device: torch device to run inference on
+         average: if True, the per-residue embeddings are averaged into one vector per sequence
+         print_updates: if True, log progress after each sequence
+         savepath: if savepath is not None, the embeddings will be saved there. It must be a pickle (.pkl)
+         save_at_end: if True, dump the whole dictionary once at the end instead of appending per sequence
+         max_length: tokenizer truncation length; defaults to the longest sequence + 2 (BOS, EOS)
+     """
+     # Correct save path to pickle if necessary
+     if savepath is not None:
+         if savepath[-4::] != '.pkl': savepath += '.pkl'
+ 
+     # If no max length was passed, just set it to the maximum in the dataset
+     max_seq_len = max([len(s) for s in sequences])
+     if max_length is None: max_length = max_seq_len + 2  # +2 for BOS, EOS
+ 
+     # Initialize an empty dict to store the ESM embeddings
+     embedding_dict = {}
+     # Iterate through the seqs
+     for i in range(len(sequences)):
+         sequence = sequences[i]
+         # Get the embeddings
+         with torch.no_grad():
+             inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
+             inputs = {k: v.to(device) for k, v in inputs.items()}
+ 
+             outputs = model(**inputs)
+             embedding = outputs.last_hidden_state
+ 
+         # remove extra dimension
+         embedding = embedding.squeeze(0)
+         # remove BOS and EOS tokens
+         embedding = embedding[1:-1, :]
+ 
+         # Convert embeddings to numpy array
+         embedding = embedding.cpu().numpy()
+ 
+         # Average (if necessary)
+         if average:
+             embedding = embedding.mean(0)
+ 
+         # Add to dictionary
+         embedding_dict[sequence] = embedding
+ 
+         # Save individual embedding (if necessary)
+         if not(savepath is None) and not(save_at_end):
+             with open(savepath, 'ab+') as f:
+                 d = {sequence: embedding}
+                 pickle.dump(d, f)
+ 
+         # Print update (if necessary)
+         if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")
+ 
+     # Dump all at once at the end (if necessary)
+     if not(savepath is None):
+         # If saving for the first time, just dump it
+         if save_at_end:
+             with open(savepath, 'wb') as f:
+                 pickle.dump(embedding_dict, f)
+         # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
+         else:
+             redump_pickle_dictionary(savepath)
+ 
+     # Return the dictionary
+     return embedding_dict
+ 
+ def get_prott5_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
+     # Correct save path to pickle if necessary
+     if savepath is not None:
+         if savepath[-4::] != '.pkl': savepath += '.pkl'
+ 
+     # If no max length was passed, just set it to the maximum in the dataset
+     max_seq_len = max([len(s) for s in sequences])
+     if max_length is None: max_length = max_seq_len + 2  # +2 for BOS, EOS
+ 
+     # the ProtT5 tokenizer requires that there are spaces between residues
+     spaced_sequences = [' '.join(list(seq)) for seq in sequences]  # Spaces between residues for Prot-T5 tokenizer
+ 
+     # Store embeddings here
+     embedding_dict = {}
+ 
+     for i in range(0, len(spaced_sequences)):
+         spaced_sequence = spaced_sequences[i]  # get current sequence
+         seq = spaced_sequence.replace(" ", "")
+ 
+         with torch.no_grad():
+             inputs = tokenizer(spaced_sequence, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=max_length)  # shouldn't have to pad because batch size is 1
+             inputs = {k: v.to(device) for k, v in inputs.items()}
+ 
+             # Pass through the model with no gradient to get embeddings
+             embedding_repr = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
+ 
+         # Process the embedding
+         seq_length = len(seq)  # length of the sequence after the spaces are removed
+         embedding = embedding_repr.last_hidden_state.squeeze(0)  # remove batch dimension
+         embedding = embedding[0:-1]  # remove EOS token (there is no BOS token)
+         embedding = embedding.cpu().numpy()  # put on CPU and numpy
+         embedding_log = f"\tembedding shape: {embedding.shape}"
+         # MAKE SURE the embedding lengths are right with an assert. We expect embedding dimension 1024, and sequence length to match the real sequence length
+         assert embedding.shape[1] == 1024
+         assert embedding.shape[0] == seq_length
+ 
+         # Average (if necessary)
+         if average:
+             dim_before = embedding.shape
+             embedding = embedding.mean(0)
+             embedding_log = f"\tembedding shape before avg: {dim_before}\tafter avg: {embedding.shape}"
+ 
+         # Add the embedding to the dictionary
+         embedding_dict[seq] = embedding
+ 
+         # Save individual embedding (if necessary)
+         if not(savepath is None) and not(save_at_end):
+             with open(savepath, 'ab+') as f:
+                 d = {seq: embedding}
+                 pickle.dump(d, f)
+ 
+         if print_updates: log_update(f"sequence {i+1}: {seq[0:10]}...{embedding_log}\t seq len: {seq_length}")
+ 
+     # Dump all at once at the end (if necessary)
+     if not(savepath is None):
+         # If saving for the first time, just dump it
+         if save_at_end:
+             with open(savepath, 'wb') as f:
+                 pickle.dump(embedding_dict, f)
+         # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
+         else:
+             redump_pickle_dictionary(savepath)
+ 
+     # Return the dictionary
+     return embedding_dict
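
A minimal sketch of the intended call pattern (the sequences are toy examples; the first call downloads facebook/esm2_t33_650M_UR50D from the Hugging Face hub, so network access and the model weights are assumed):

from fuson_plm.utils.embedding import load_esm2_type, get_esm_embeddings

model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")
seqs = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "GDVEKGKKIFIMKCSQCHTVEKGGKHKTGPNLH"]
embs = get_esm_embeddings(model, tokenizer, seqs, device, average=True,
                          savepath="toy_embeddings.pkl", save_at_end=True)
print(embs[seqs[0]].shape)  # (1280,): one averaged vector per sequence for the 650M model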
fuson_plm/utils/logging.py ADDED
@@ -0,0 +1,116 @@
+ from datetime import datetime
+ from contextlib import contextmanager
+ import sys
+ import pytz
+ import os
+ 
+ class CustomParams:
+     """
+     Class for custom parameters where dictionary elements can be accessed as attributes
+     """
+     def __init__(self, **kwargs):
+         self.__dict__.update(kwargs)
+ 
+     def print_config(self, indent=''):
+         for attr, value in self.__dict__.items():
+             print(f"{indent}{attr}: {value}")
+ 
+ def log_update(text: str):
+     """
+     Logs input text by printing it; if stdout has been redirected with open_logfile, the text goes to that file.
+ 
+     Args:
+         text (str): the text to be logged
+     """
+     print(text)  # print the text
+     sys.stdout.flush()  # flush to automatically update the output file
+ 
+ @contextmanager
+ def open_logfile(log_path, mode='w'):
+     """
+     Opens a log file for real-time logging of the most important updates; stdout is redirected to it for the duration of the block.
+     """
+     log_file = open(log_path, mode)  # open
+     original_stdout = sys.stdout  # save original stdout
+     sys.stdout = log_file  # redirect stdout to log_file
+     try:
+         yield log_file
+     finally:
+         sys.stdout = original_stdout
+         log_file.close()
+ 
+ @contextmanager
+ def open_errfile(log_path, mode='w'):
+     """
+     Redirects stderr (error messages) to a separate log file.
+     """
+     log_file = open(log_path, mode)  # open the error log file for writing
+     original_stderr = sys.stderr  # save original stderr
+     sys.stderr = log_file  # redirect stderr to log_file
+     try:
+         yield log_file
+     finally:
+         sys.stderr = original_stderr  # restore original stderr
+         log_file.close()  # close the error log file
+ 
+ def print_configpy(module):
+     """
+     Prints all the configurations in a config.py file
+     """
+     log_update("All configurations:")
+     # Iterate over attributes
+     for attribute in dir(module):
+         # Filter out built-in attributes and methods
+         if not attribute.startswith("__"):
+             value = getattr(module, attribute)
+             log_update(f"\t{attribute}: {value}")
+ 
+ def get_local_time(timezone_str='US/Eastern'):
+     """
+     Get the current time in the specified timezone.
+ 
+     Args:
+         timezone_str (str): The timezone to retrieve time for. Defaults to 'US/Eastern'.
+ 
+     Returns:
+         str: The formatted current time in the specified timezone.
+     """
+     try:
+         timezone = pytz.timezone(timezone_str)
+     except pytz.UnknownTimeZoneError:
+         return f"Unknown timezone: {timezone_str}"
+ 
+     current_datetime = datetime.now(pytz.utc).astimezone(timezone)
+     return current_datetime.strftime('%m-%d-%Y-%H:%M:%S')
+ 
+ def get_local_date_yr(timezone_str='US/Eastern'):
+     """
+     Get the current date in the specified timezone.
+ 
+     Args:
+         timezone_str (str): The timezone to retrieve the date for. Defaults to 'US/Eastern'.
+ 
+     Returns:
+         str: The formatted current date (month_day_year) in the specified timezone.
+     """
+     try:
+         timezone = pytz.timezone(timezone_str)
+     except pytz.UnknownTimeZoneError:
+         return f"Unknown timezone: {timezone_str}"
+ 
+     current_datetime = datetime.now(pytz.utc).astimezone(timezone)
+     return current_datetime.strftime('%m_%d_%Y')
+ 
+ def find_fuson_plm_directory():
+     """
+     Constructs a path backwards to the fuson_plm directory so we don't have to use absolute paths (helps for docker containers)
+     """
+     current_dir = os.path.abspath(os.getcwd())
+ 
+     while True:
+         if 'fuson_plm' in os.listdir(current_dir):
+             return os.path.join(current_dir, 'fuson_plm')
+         parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
+         if parent_dir == current_dir:  # If we've reached the root directory
+             raise FileNotFoundError("fuson_plm directory not found.")
+         current_dir = parent_dir
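
A short sketch of the logging pattern (the filename and config values are illustrative):

from fuson_plm.utils.logging import open_logfile, log_update, CustomParams, get_local_time

config = CustomParams(batch_size=8, learning_rate=3e-4)  # hypothetical training settings

# Everything printed inside the block lands in run.log, not the console
with open_logfile("run.log"):
    log_update(f"Run started at {get_local_time()}")
    config.print_config(indent='\t')  # print_config uses print(), so it is redirected too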
fuson_plm/utils/splitting.py ADDED
@@ -0,0 +1,206 @@
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ from fuson_plm.utils.logging import log_update
+ 
+ def split_clusters_train_test(X, y, benchmark_cluster_reps=[], random_state=1, test_size=0.20):
+     # split with random state fixed for reproducible results
+     log_update(f"\tPerforming split: all clusters -> train clusters ({round(1-test_size,3)}) and test clusters ({test_size})")
+     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
+ 
+     # add benchmark representatives back to X_test
+     log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
+     X_test += benchmark_cluster_reps
+ 
+     # assert no duplicates within the train or test sets (there shouldn't be, if the input data was clean)
+     assert len(X_train) == len(set(X_train))
+     assert len(X_test) == len(set(X_test))
+ 
+     return {
+         'X_train': X_train,
+         'X_test': X_test
+     }
+ 
+ def split_clusters_train_val_test(X, y, benchmark_cluster_reps=[], random_state_1=1, random_state_2=1, test_size_1=0.20, test_size_2=0.50):
+     # split with random states fixed for reproducible results
+     log_update(f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})")
+     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size_1, random_state=random_state_1)
+     log_update(f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})")
+     X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=test_size_2, random_state=random_state_2)
+ 
+     # add benchmark representatives back to X_test
+     log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
+     X_test += benchmark_cluster_reps
+ 
+     # assert no duplicates within the train, val, or test sets (there shouldn't be, if the input data was clean)
+     assert len(X_train) == len(set(X_train))
+     assert len(X_val) == len(set(X_val))
+     assert len(X_test) == len(set(X_test))
+ 
+     return {
+         'X_train': X_train,
+         'X_val': X_val,
+         'X_test': X_test
+     }
+ 
+ def split_clusters(cluster_representatives: list, val_set=True, benchmark_cluster_reps=[], random_state_1=1, random_state_2=1, test_size_1=0.20, test_size_2=0.50):
+     """
+     Cluster-splitting method amenable to either train-test or train-val-test.
+     For train-val-test, there are two splits.
+     """
+     log_update("\nPerforming splits...")
+     # Approx. 80/10/10 split with the default sizes
+     X = [x for x in cluster_representatives if not(x in benchmark_cluster_reps)]  # X, for splitting, does NOT include benchmark reps. We'll add these clusters to test.
+     y = [0]*len(X)  # y is a dummy array here; there are no values.
+ 
+     split_dict = None
+     if val_set:
+         split_dict = split_clusters_train_val_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
+                                                    random_state_1=random_state_1, random_state_2=random_state_2,
+                                                    test_size_1=test_size_1, test_size_2=test_size_2)
+     else:
+         split_dict = split_clusters_train_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
+                                                random_state=random_state_1,
+                                                test_size=test_size_1)
+ 
+     return split_dict
+ 
+ def check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=None):
+     """
+     Args:
+         train_clusters (pd.DataFrame): cluster membership table for the train set
+         val_clusters (pd.DataFrame): cluster membership table for the validation set (optional - can pass None if there is no validation set)
+         test_clusters (pd.DataFrame): cluster membership table for the test set
+         benchmark_sequences (list): sequences that must appear only in the test set (optional)
+     """
+ 
+     # Make grouped versions of these DataFrames for size analysis
+     train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
+     if val_clusters is not None:
+         val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
+     if test_clusters is not None:
+         test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
+ 
+     # Calculate stats - clusters
+     n_train_clusters = len(train_clustersgb)
+     n_val_clusters, n_test_clusters = 0, 0
+     if val_clusters is not None:
+         n_val_clusters = len(val_clustersgb)
+     if test_clusters is not None:
+         n_test_clusters = len(test_clustersgb)
+     n_clusters = n_train_clusters + n_val_clusters + n_test_clusters
+ 
+     assert len(train_clusters['representative seq_id'].unique()) == len(train_clustersgb)
+     if val_clusters is not None:
+         assert len(val_clusters['representative seq_id'].unique()) == len(val_clustersgb)
+     if test_clusters is not None:
+         assert len(test_clusters['representative seq_id'].unique()) == len(test_clustersgb)
+ 
+     train_cluster_pcnt = round(100*n_train_clusters/n_clusters, 2)
+     if val_clusters is not None:
+         val_cluster_pcnt = round(100*n_val_clusters/n_clusters, 2)
+     if test_clusters is not None:
+         test_cluster_pcnt = round(100*n_test_clusters/n_clusters, 2)
+ 
+     # Calculate stats - proteins
+     n_train_proteins = len(train_clusters)
+     n_val_proteins, n_test_proteins = 0, 0
+     if val_clusters is not None:
+         n_val_proteins = len(val_clusters)
+     if test_clusters is not None:
+         n_test_proteins = len(test_clusters)
+     n_proteins = n_train_proteins + n_val_proteins + n_test_proteins
+ 
+     assert len(train_clusters) == sum(train_clustersgb['member count'])
+     if val_clusters is not None:
+         assert len(val_clusters) == sum(val_clustersgb['member count'])
+     if test_clusters is not None:
+         assert len(test_clusters) == sum(test_clustersgb['member count'])
+ 
+     train_protein_pcnt = round(100*n_train_proteins/n_proteins, 2)
+     if val_clusters is not None:
+         val_protein_pcnt = round(100*n_val_proteins/n_proteins, 2)
+     if test_clusters is not None:
+         test_protein_pcnt = round(100*n_test_proteins/n_proteins, 2)
+ 
+     # Print results
+     log_update("\nCluster breakdown...")
+     log_update(f"Total clusters = {n_clusters}, total proteins = {n_proteins}")
+     log_update(f"\tTrain set:\n\t\tTotal Clusters = {len(train_clustersgb)} ({train_cluster_pcnt}%)\n\t\tTotal Proteins = {len(train_clusters)} ({train_protein_pcnt}%)")
+     if val_clusters is not None:
+         log_update(f"\tVal set:\n\t\tTotal Clusters = {len(val_clustersgb)} ({val_cluster_pcnt}%)\n\t\tTotal Proteins = {len(val_clusters)} ({val_protein_pcnt}%)")
+     if test_clusters is not None:
+         log_update(f"\tTest set:\n\t\tTotal Clusters = {len(test_clustersgb)} ({test_cluster_pcnt}%)\n\t\tTotal Proteins = {len(test_clusters)} ({test_protein_pcnt}%)")
+ 
+     # Check for overlap in both sequence ID and actual sequence
+     train_protein_ids = set(train_clusters['member seq_id'])
+     train_protein_seqs = set(train_clusters['member seq'])
+     if val_clusters is not None:
+         val_protein_ids = set(val_clusters['member seq_id'])
+         val_protein_seqs = set(val_clusters['member seq'])
+     if test_clusters is not None:
+         test_protein_ids = set(test_clusters['member seq_id'])
+         test_protein_seqs = set(test_clusters['member seq'])
+ 
+     # Print results
+     log_update("\nChecking for overlap...")
+     if (val_clusters is not None) and (test_clusters is not None):
+         log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}\n\t\tVal-Test Overlap: {len(val_protein_ids.intersection(test_protein_ids))}")
+         log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}\n\t\tVal-Test Overlap: {len(val_protein_seqs.intersection(test_protein_seqs))}")
+     if (val_clusters is not None) and (test_clusters is None):
+         log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}")
+         log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}")
+     if (val_clusters is None) and (test_clusters is not None):
+         log_update(f"\tSequence IDs...\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}")
+         log_update(f"\tSequences...\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}")
+ 
+     # Assert no sequence overlap
+     if val_clusters is not None:
+         assert len(train_protein_seqs.intersection(val_protein_seqs)) == 0
+     if test_clusters is not None:
+         assert len(train_protein_seqs.intersection(test_protein_seqs)) == 0
+     if (val_clusters is not None) and (test_clusters is not None):
+         assert len(val_protein_seqs.intersection(test_protein_seqs)) == 0
+ 
+     # Finally, check that benchmark sequences appear only in test - if there are benchmark sequences
+     if not(benchmark_sequences is None):
+         bench_in_train = len(train_clusters.loc[train_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
+         bench_in_val, bench_in_test = 0, 0
+         if val_clusters is not None:
+             bench_in_val = len(val_clusters.loc[val_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
+         if test_clusters is not None:
+             bench_in_test = len(test_clusters.loc[test_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
+ 
+         # Assert this
+         log_update("\nChecking for benchmark sequence presence in test, and absence from train and val...")
+         log_update(f"\tTotal benchmark sequences: {len(benchmark_sequences)}")
+         log_update(f"\tBenchmark sequences in train: {bench_in_train}")
+         if val_clusters is not None:
+             log_update(f"\tBenchmark sequences in val: {bench_in_val}")
+         if test_clusters is not None:
+             log_update(f"\tBenchmark sequences in test: {bench_in_test}")
+         assert bench_in_train == bench_in_val == 0
+         assert bench_in_test == len(benchmark_sequences)
+ 
+ def check_class_distributions(train_df, val_df, test_df, class_col='class'):
+     """
+     Checks class distributions within train, val, and test sets.
+     Expects input dataframes to have a 'sequence' column and a 'class' column
+     """
+     train_vc = pd.DataFrame(train_df[class_col].value_counts()).reset_index().rename(columns={'index': class_col, class_col: 'train_count'})
+     train_vc['train_pct'] = (train_vc['train_count'] / train_vc['train_count'].sum()).round(3)*100
+     if val_df is not None:
+         val_vc = pd.DataFrame(val_df[class_col].value_counts()).reset_index().rename(columns={'index': class_col, class_col: 'val_count'})
+         val_vc['val_pct'] = (val_vc['val_count'] / val_vc['val_count'].sum()).round(3)*100
+     test_vc = pd.DataFrame(test_df[class_col].value_counts()).reset_index().rename(columns={'index': class_col, class_col: 'test_count'})
+     test_vc['test_pct'] = (test_vc['test_count'] / test_vc['test_count'].sum()).round(3)*100
+     # concatenate so the distributions can be compared side by side
+     if val_df is not None:
+         compare = pd.concat([train_vc, val_vc, test_vc], axis=1)
+         compare['train-val diff'] = (compare['train_pct'] - compare['val_pct']).apply(lambda x: abs(x))
+         compare['val-test diff'] = (compare['val_pct'] - compare['test_pct']).apply(lambda x: abs(x))
+     else:
+         compare = pd.concat([train_vc, test_vc], axis=1)
+         compare['train-test diff'] = (compare['train_pct'] - compare['test_pct']).apply(lambda x: abs(x))
+ 
+     compare_str = compare.to_string(index=False)
+     compare_str = "\t" + compare_str.replace("\n", "\n\t")
+     log_update(f"\nClass distribution:\n{compare_str}")
fuson_plm/utils/visualizing.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import matplotlib.font_manager as fm
3
+ from matplotlib.font_manager import FontProperties
4
+ from scipy.stats import entropy
5
+ from sklearn.manifold import TSNE
6
+ import pickle
7
+ import pandas as pd
8
+ import os
9
+ import numpy as np
10
+ from fuson_plm.utils.logging import log_update, find_fuson_plm_directory
11
+
12
+ def set_font():
13
+ # Load and set the font
14
+ fuson_plm_dir = find_fuson_plm_directory()
15
+
16
+ # Paths for regular, bold, italic fonts
17
+ regular_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Regular.ttf')
18
+ bold_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Bold.ttf')
19
+ italic_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Italic.ttf')
20
+ bold_italic_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-BoldItalic.ttf')
21
+
22
+ # Load the font properties
23
+ regular_font = FontProperties(fname=regular_font_path)
24
+ bold_font = FontProperties(fname=bold_font_path)
25
+ italic_font = FontProperties(fname=italic_font_path)
26
+ bold_italic_font = FontProperties(fname=bold_italic_font_path)
27
+
28
+ # Add the fonts to the font manager
29
+ fm.fontManager.addfont(regular_font_path)
30
+ fm.fontManager.addfont(bold_font_path)
31
+ fm.fontManager.addfont(italic_font_path)
32
+ fm.fontManager.addfont(bold_italic_font_path)
33
+
34
+ # Set the font family globally to Ubuntu
35
+ plt.rcParams['font.family'] = regular_font.get_name()
36
+
37
+ # Set the fonts for math text (like for labels) to use the loaded Ubuntu fonts
38
+ plt.rcParams['mathtext.fontset'] = 'custom'
39
+ plt.rcParams['mathtext.rm'] = regular_font.get_name()
40
+ plt.rcParams['mathtext.it'] = f'{italic_font.get_name()}'
41
+ plt.rcParams['mathtext.bf'] = f'{bold_font.get_name()}'
42
+
43
+ global default_color_map
44
+ default_color_map = {
45
+ 'train': '#0072B2',
46
+ 'val': '#009E73',
47
+ 'test': '#E69F00'
48
+ }
49
+
50
+ def get_avg_embeddings_for_tsne(train_sequences=None, val_sequences=None, test_sequences=None, embedding_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl'):
51
+ if train_sequences is None: train_sequences = []
52
+ if val_sequences is None: val_sequences = []
53
+ if test_sequences is None: test_sequences = []
54
+
55
+ embeddings = {}
56
+
57
+ try:
58
+ with open(embedding_path, 'rb') as f:
59
+ embeddings = pickle.load(f)
60
+
61
+ train_embeddings = [v for k, v in embeddings.items() if k in train_sequences]
62
+ val_embeddings = [v for k, v in embeddings.items() if k in val_sequences]
63
+ test_embeddings = [v for k, v in embeddings.items() if k in test_sequences]
64
+
65
+ return train_embeddings, val_embeddings, test_embeddings
66
+ except:
67
+ print("could not open embeddings")
68
+
69
+
70
+ def calculate_aa_composition(sequences):
71
+ composition = {}
72
+ total_length = sum([len(seq) for seq in sequences])
73
+
74
+ for seq in sequences:
75
+ for aa in seq:
76
+ if aa in composition:
77
+ composition[aa] += 1
78
+ else:
79
+ composition[aa] = 1
80
+
81
+ # Convert counts to relative frequency
82
+ for aa in composition:
83
+ composition[aa] /= total_length
84
+
85
+ return composition
86
+
87
+ def calculate_shannon_entropy(sequence):
88
+ """
89
+ Calculate the Shannon entropy for a given sequence.
90
+
91
+ Args:
92
+ sequence (str): A sequence of characters (e.g., amino acids or nucleotides).
93
+
94
+ Returns:
95
+ float: Shannon entropy value.
96
+ """
97
+ bases = set(sequence)
98
+ counts = [sequence.count(base) for base in bases]
99
+ return entropy(counts, base=2)
100
+
101
+ def visualize_splits_hist(train_lengths=None, val_lengths=None, test_lengths=None, colormap=None, savepath=f'splits/length_distributions.png', axes=None):
102
+ """
103
+ Works to plot train, val, test; train, val; or train, test
104
+ """
105
+ set_font()
106
+ if colormap is None: colormap=default_color_map
107
+
108
+ log_update('\nMaking histogram of length distributions')
109
+
110
+ # Get index for test plot
111
+ val_plot_index, test_plot_index, total_plots = 1, 2, 3
112
+ if val_lengths is None:
113
+ val_plot_index = None
114
+ test_plot_index-= 1
115
+ total_plots-=1
116
+ if test_lengths is None:
117
+ test_plot_index = None
118
+ total_plots-=1
119
+
120
+ # Create a figure and axes with 1 row and 3 columns
121
+ fig_individual, axes_individual = plt.subplots(1, total_plots, figsize=(6*total_plots, 6))
122
+
123
+ # Set axes list
124
+ axes_list = [axes_individual] if axes is None else [axes_individual, axes]
125
+
126
+ # Unpack the labels and titles
127
+ xlabel, ylabel = ['Sequence Length (AA)', 'Frequency']
128
+
129
+ for cur_axes in axes_list:
130
+ # Plot the first histogram
131
+ cur_axes[0].hist(train_lengths, bins=20, edgecolor='k',color=colormap['train'])
132
+ cur_axes[0].set_xlabel(xlabel)
133
+ cur_axes[0].set_ylabel(ylabel)
134
+ cur_axes[0].set_title(f'Train Set Length Distribution (n={len(train_lengths)})')
135
+ cur_axes[0].grid(True)
136
+ cur_axes[0].set_axisbelow(True)
137
+
138
+ # Plot the second histogram
139
+ if not(val_plot_index is None):
140
+ cur_axes[val_plot_index].hist(val_lengths, bins=20, edgecolor='k',color=colormap['val'])
141
+ cur_axes[val_plot_index].set_xlabel(xlabel)
142
+ cur_axes[val_plot_index].set_ylabel(ylabel)
143
+ cur_axes[val_plot_index].set_title(f'Validation Set Length Distribution (n={len(val_lengths)})')
144
+ cur_axes[val_plot_index].grid(True)
145
+ cur_axes[val_plot_index].set_axisbelow(True)
146
+
147
+ # Plot the third histogram
148
+ if not(test_plot_index is None):
149
+ cur_axes[test_plot_index].hist(test_lengths, bins=20, edgecolor='k',color=colormap['test'])
150
+ cur_axes[test_plot_index].set_xlabel(xlabel)
151
+ cur_axes[test_plot_index].set_ylabel(ylabel)
152
+ cur_axes[test_plot_index].set_title(f'Test Set Length Distribution (n={len(test_lengths)})')
153
+ cur_axes[test_plot_index].grid(True)
154
+ cur_axes[test_plot_index].set_axisbelow(True)
155
+
156
+ # Adjust layout
157
+ fig_individual.set_tight_layout(True)
158
+
159
+ # Save the figure
160
+ fig_individual.savefig(savepath)
161
+ log_update(f"\tSaved figure to {savepath}")
162
+
163
+ def visualize_splits_scatter(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None, colormap=None, savepath='splits/scatterplot.png', axes=None):
164
+ set_font()
165
+ if colormap is None: colormap=default_color_map
166
+
167
+ # Create a figure and axes with 1 row and 3 columns
168
+ fig_individual, axes_individual = plt.subplots(figsize=(18, 6))
169
+
170
+ # Set axes list
171
+ axes_list = [axes_individual] if axes is None else [axes_individual, axes]
172
+
173
+ log_update("\nMaking scatterplot with distribution of cluster sizes across train, test, and val")
174
+ # Make grouped versions of these DataFrames for size analysis
175
+ train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
176
+ if not(val_clusters is None):
177
+ val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
178
+ if not(test_clusters is None):
179
+ test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
180
+ # Isolate benchmark-containing clusters so their contribution can be plotted separately
181
+ total_test_proteins = sum(test_clustersgb['member count'])
182
+ if not(benchmark_cluster_reps is None):
183
+ test_clustersgb['benchmark cluster'] = test_clustersgb['representative seq_id'].isin(benchmark_cluster_reps)
184
+ benchmark_clustersgb = test_clustersgb.loc[test_clustersgb['benchmark cluster']].reset_index(drop=True)
185
+ test_clustersgb = test_clustersgb.loc[test_clustersgb['benchmark cluster']==False].reset_index(drop=True)
186
+
187
+ # Convert them to value counts
188
+ train_clustersgb = train_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
189
+ if not(val_clusters is None):
190
+ val_clustersgb = val_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
191
+ if not(test_clusters is None):
192
+ test_clustersgb = test_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
193
+ if not(benchmark_cluster_reps is None):
194
+ benchmark_clustersgb = benchmark_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
195
+
196
+ # Get the percentage of each dataset that's made of each cluster size
197
+ train_clustersgb['n_proteins'] = train_clustersgb['cluster size (n_members)']*train_clustersgb['n_clusters'] # proteins per cluster * n clusters = # proteins
198
+ train_clustersgb['percent_proteins'] = train_clustersgb['n_proteins']/sum(train_clustersgb['n_proteins'])
199
+ if not(val_clusters is None):
200
+ val_clustersgb['n_proteins'] = val_clustersgb['cluster size (n_members)']*val_clustersgb['n_clusters']
201
+ val_clustersgb['percent_proteins'] = val_clustersgb['n_proteins']/sum(val_clustersgb['n_proteins'])
202
+ if not(test_clusters is None):
203
+ test_clustersgb['n_proteins'] = test_clustersgb['cluster size (n_members)']*test_clustersgb['n_clusters']
204
+ test_clustersgb['percent_proteins'] = test_clustersgb['n_proteins']/total_test_proteins
205
+ if not(benchmark_cluster_reps is None):
206
+ benchmark_clustersgb['n_proteins'] = benchmark_clustersgb['cluster size (n_members)']*benchmark_clustersgb['n_clusters']
207
+ benchmark_clustersgb['percent_proteins'] = benchmark_clustersgb['n_proteins']/total_test_proteins
208
+
209
+ # Specially mark the benchmark clusters because these can't be reallocated
210
+ for ax in axes_list:
211
+ ax.plot(train_clustersgb['cluster size (n_members)'],train_clustersgb['percent_proteins'],linestyle='None',marker='.',color=colormap['train'],label='train')
212
+ if not(val_clusters is None):
213
+ ax.plot(val_clustersgb['cluster size (n_members)'],val_clustersgb['percent_proteins'],linestyle='None',marker='.',color=colormap['val'],label='val')
214
+ if not(test_clusters is None):
215
+ ax.plot(test_clustersgb['cluster size (n_members)'],test_clustersgb['percent_proteins'],linestyle='None',marker='.',color=colormap['test'],label='test')
216
+ if not(benchmark_cluster_reps is None):
217
+ ax.plot(benchmark_clustersgb['cluster size (n_members)'],benchmark_clustersgb['percent_proteins'],
218
+ marker='o',
219
+ linestyle='None',
220
+ markerfacecolor=colormap['test'], # fill same as test
221
+ markeredgecolor='black', # outline black
222
+ markeredgewidth=1.5,
223
+ label='benchmark'
224
+ )
225
+ ax.set(ylabel='Percentage of Proteins in Dataset',xlabel='cluster_size')
226
+ ax.legend()
227
+
228
+ # save the figure
229
+ fig_individual.set_tight_layout(True)
230
+ fig_individual.savefig(savepath)
231
+ log_update(f"\tSaved figure to {savepath}")
232
+
233
+
234
+ def visualize_splits_tsne(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, esm_type="esm2_t33_650M_UR50D", embedding_path="fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl", savepath='splits/tsne_plot.png',axes=None):
235
+ set_font()
236
+
237
+ if colormap is None: colormap=default_color_map
238
+
239
+ """
240
+ Generate a t-SNE plot of embeddings for train, test, and validation.
241
+ """
242
+ log_update('\nMaking t-SNE plot of train, val, and test embeddings')
243
+ # Create a figure and axes with 1 row and 3 columns
244
+ fig_individual, axes_individual = plt.subplots(figsize=(18, 6))
245
+
246
+ # Set axes list
247
+ axes_list = [axes_individual] if axes is None else [axes_individual, axes]
248
+
249
+ # Combine the embeddings into one array
250
+ train_embeddings, val_embeddings, test_embeddings = get_avg_embeddings_for_tsne(train_sequences=train_sequences,
251
+ val_sequences=val_sequences,
252
+ test_sequences=test_sequences, embedding_path=embedding_path)
253
+ if not(val_embeddings is None) and not(test_embeddings is None):
254
+ embeddings = np.concatenate([train_embeddings, val_embeddings, test_embeddings])
255
+ labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings) + ['test'] * len(test_embeddings)
256
+ if not(val_embeddings is None) and (test_embeddings is None):
257
+ embeddings = np.concatenate([train_embeddings, val_embeddings])
258
+ labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings)
259
+ if (val_embeddings is None) and not(test_embeddings is None):
260
+ embeddings = np.concatenate([train_embeddings, test_embeddings])
261
+ labels = ['train'] * len(train_embeddings) + ['test'] * len(test_embeddings)
262
+
263
+ # Perform t-SNE
264
+ tsne = TSNE(n_components=2, random_state=42)
265
+ tsne_results = tsne.fit_transform(embeddings)
266
+
267
+ # Convert t-SNE results into a DataFrame
268
+ tsne_df = pd.DataFrame(data=tsne_results, columns=['TSNE_1', 'TSNE_2'])
269
+ tsne_df['label'] = labels
270
+
271
+ for ax in axes_list:
272
+ # Scatter plot for each set
273
+ for label, color in colormap.items():
274
+ subset = tsne_df[tsne_df['label'] == label].reset_index(drop=True)
275
+ ax.scatter(subset['TSNE_1'], subset['TSNE_2'], c=color, label=label.capitalize(), alpha=0.6)
276
+
277
+ ax.set_title(f't-SNE of {esm_type} Embeddings')
278
+ ax.set_xlabel('t-SNE Dimension 1')
279
+ ax.set_ylabel('t-SNE Dimension 2')
280
+ ax.legend()
281
+ ax.grid(True)
282
+
283
+ # Save the figure if savepath is provided
284
+ fig_individual.set_tight_layout(True)
285
+ fig_individual.savefig(savepath)
286
+ log_update(f"\tSaved figure to {savepath}")
287
+
288
+ def visualize_splits_shannon_entropy(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, savepath='splits/shannon_entropy_plot.png', axes=None):
+     """
+     Generate Shannon entropy histograms for the train, validation, and test sets.
+     """
+     set_font()
+     # Get subplot indices for the val and test plots; drop them if those sets are absent
+     val_plot_index, test_plot_index, total_plots = 1, 2, 3
+     if val_sequences is None:
+         val_plot_index = None
+         test_plot_index -= 1
+         total_plots -= 1
+     if test_sequences is None:
+         test_plot_index = None
+         total_plots -= 1
+ 
+     if colormap is None: colormap = default_color_map
+     # Create a standalone figure with 1 row and total_plots columns
+     fig_individual, axes_individual = plt.subplots(1, total_plots, figsize=(6*total_plots, 6))
+ 
+     # Plot on the standalone axes, and also on the caller-provided axes if given
+     axes_list = [axes_individual] if axes is None else [axes_individual, axes]
+ 
+     log_update('\nMaking histogram of Shannon Entropy distributions')
+     train_entropy = [calculate_shannon_entropy(seq) for seq in train_sequences]
+     if val_plot_index is not None:
+         val_entropy = [calculate_shannon_entropy(seq) for seq in val_sequences]
+     if test_plot_index is not None:
+         test_entropy = [calculate_shannon_entropy(seq) for seq in test_sequences]
+ 
+     for ax in axes_list:
+         ax[0].hist(train_entropy, bins=20, edgecolor='k', color=colormap['train'])
+         ax[0].set_title(f'Train Set (n={len(train_entropy)})')
+         ax[0].set_xlabel('Shannon Entropy')
+         ax[0].set_ylabel('Frequency')
+         ax[0].grid(True)
+         ax[0].set_axisbelow(True)
+ 
+         if val_plot_index is not None:
+             ax[val_plot_index].hist(val_entropy, bins=20, edgecolor='k', color=colormap['val'])
+             ax[val_plot_index].set_title(f'Validation Set (n={len(val_entropy)})')
+             ax[val_plot_index].set_xlabel('Shannon Entropy')
+             ax[val_plot_index].grid(True)
+             ax[val_plot_index].set_axisbelow(True)
+ 
+         if test_plot_index is not None:
+             ax[test_plot_index].hist(test_entropy, bins=20, edgecolor='k', color=colormap['test'])
+             ax[test_plot_index].set_title(f'Test Set (n={len(test_entropy)})')
+             ax[test_plot_index].set_xlabel('Shannon Entropy')
+             ax[test_plot_index].grid(True)
+             ax[test_plot_index].set_axisbelow(True)
+ 
+     fig_individual.set_tight_layout(True)
+     fig_individual.savefig(savepath)
+     log_update(f"\tSaved figure to {savepath}")
+ 
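`calculate_shannon_entropy` is imported from elsewhere in `fuson_plm/utils` and is not part of this diff; as a reference, a minimal sketch of the standard formulation it presumably follows (entropy in bits over the residue frequency distribution of one sequence):

```python
# Hypothetical reference sketch -- the real implementation lives elsewhere in fuson_plm/utils
from collections import Counter
import math

def shannon_entropy_sketch(sequence: str) -> float:
    """Shannon entropy (bits) of the amino acid distribution in one sequence."""
    counts = Counter(sequence)
    n = len(sequence)
    return -sum((c / n) * math.log2(c / n) for c in counts.values())
```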
+ def visualize_splits_aa_composition(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, savepath='splits/aa_comp.png', axes=None):
+     """
+     Generate a bar plot of amino acid composition across the train, validation, and test sets.
+     """
+     set_font()
+     if colormap is None: colormap = default_color_map
+ 
+     # Create a standalone figure with a single axes
+     fig_individual, axes_individual = plt.subplots(figsize=(18, 6))
+ 
+     # Plot on the standalone axes, and also on the caller-provided axes if given
+     axes_list = [axes_individual] if axes is None else [axes_individual, axes]
+ 
+     log_update('\nMaking bar plot of AA composition across each set')
+     train_comp = calculate_aa_composition(train_sequences)
+     if val_sequences is not None:
+         val_comp = calculate_aa_composition(val_sequences)
+     if test_sequences is not None:
+         test_comp = calculate_aa_composition(test_sequences)
+ 
+     # Create DataFrame; at least one of val/test must be provided
+     if val_sequences is not None and test_sequences is not None:
+         comp_df = pd.DataFrame([train_comp, val_comp, test_comp], index=['train', 'val', 'test']).T
+     elif val_sequences is not None:
+         comp_df = pd.DataFrame([train_comp, val_comp], index=['train', 'val']).T
+     elif test_sequences is not None:
+         comp_df = pd.DataFrame([train_comp, test_comp], index=['train', 'test']).T
+     else:
+         raise ValueError("Must provide at least one of val_sequences or test_sequences")
+     colors = [colormap[col] for col in comp_df.columns]
+ 
+     # Plotting
+     for ax in axes_list:
+         comp_df.plot(kind='bar', color=colors, ax=ax)
+         ax.set_title('Amino Acid Composition Across Datasets')
+         ax.set_xlabel('Amino Acid')
+         ax.set_ylabel('Relative Frequency')
+ 
+     fig_individual.set_tight_layout(True)
+     fig_individual.savefig(savepath)
+     log_update(f"\tSaved figure to {savepath}")
+ 
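`calculate_aa_composition` is likewise defined elsewhere in `fuson_plm/utils`; the plot above implies it returns a mapping from amino acid to relative frequency over a whole set of sequences, along the lines of this hedged sketch:

```python
# Hypothetical reference sketch -- the real implementation lives elsewhere in fuson_plm/utils
from collections import Counter

def aa_composition_sketch(sequences: list[str]) -> dict[str, float]:
    """Relative frequency of each amino acid across all sequences in a set."""
    counts = Counter("".join(sequences))
    total = sum(counts.values())
    return {aa: count / total for aa, count in sorted(counts.items())}
```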
+ ### Outer methods for visualizing splits
+ def visualize_splits(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None, train_color='#0072B2', val_color='#009E73', test_color='#E69F00', esm_embeddings_path=None, onehot_embeddings_path=None):
+     colormap = {
+         'train': train_color,
+         'val': val_color,
+         'test': test_color
+     }
+     valid_entry = False
+     # Dispatch to the appropriate plotting routine based on which sets were provided,
+     # forwarding the embedding paths so the t-SNE panel can be drawn when available
+     if train_clusters is not None and val_clusters is not None and test_clusters is not None:
+         visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path, onehot_embeddings_path=onehot_embeddings_path)
+         valid_entry = True
+     if train_clusters is not None and val_clusters is None and test_clusters is not None:
+         visualize_train_test_splits(train_clusters, test_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path, onehot_embeddings_path=onehot_embeddings_path)
+         valid_entry = True
+     if train_clusters is not None and val_clusters is not None and test_clusters is None:
+         visualize_train_val_splits(train_clusters, val_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path, onehot_embeddings_path=onehot_embeddings_path)
+         valid_entry = True
+ 
+     if not valid_entry: raise ValueError("Must pass train clusters and at least one of val or test clusters")
+ 
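A hedged usage sketch of the dispatcher (the CSV paths are illustrative; the cluster DataFrames are expected to carry `member seq` columns, as produced by the clustering utilities earlier in `fuson_plm/utils`):

```python
import pandas as pd
from fuson_plm.utils.visualizing import visualize_splits

# Hypothetical paths; real cluster tables come from the MMseqs2 clustering + splitting steps
train_clusters = pd.read_csv('splits/train_clusters.csv')
test_clusters = pd.read_csv('splits/test_clusters.csv')

# Omitting val_clusters dispatches to visualize_train_test_splits;
# passing an existing embeddings pickle enables the t-SNE panel
visualize_splits(
    train_clusters=train_clusters,
    test_clusters=test_clusters,
    esm_embeddings_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl',
)
```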
+ def visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
+     if colormap is None: colormap = default_color_map
+     # Add length column
+     train_clusters['member length'] = train_clusters['member seq'].str.len()
+     val_clusters['member length'] = val_clusters['member seq'].str.len()
+     test_clusters['member length'] = test_clusters['member seq'].str.len()
+ 
+     # Prepare lengths and seqs for plotting
+     train_lengths = train_clusters['member length'].tolist()
+     val_lengths = val_clusters['member length'].tolist()
+     test_lengths = test_clusters['member length'].tolist()
+     train_sequences = train_clusters['member seq'].tolist()
+     val_sequences = val_clusters['member seq'].tolist()
+     test_sequences = test_clusters['member seq'].tolist()
+ 
+     # Create a combined figure with 3 rows and 3 columns
+     set_font()
+     fig_combined, axs = plt.subplots(3, 3, figsize=(24, 18))
+ 
+     # Make the visualization plots for saving TOGETHER on the combined figure
+     visualize_splits_hist(train_lengths=train_lengths,
+                           val_lengths=val_lengths,
+                           test_lengths=test_lengths,
+                           colormap=colormap, axes=axs[0])
+     visualize_splits_shannon_entropy(train_sequences=train_sequences,
+                                      val_sequences=val_sequences,
+                                      test_sequences=test_sequences,
+                                      colormap=colormap, axes=axs[1])
+     visualize_splits_scatter(train_clusters=train_clusters,
+                              val_clusters=val_clusters,
+                              test_clusters=test_clusters,
+                              benchmark_cluster_reps=benchmark_cluster_reps,
+                              colormap=colormap, axes=axs[2, 0])
+     visualize_splits_aa_composition(train_sequences=train_sequences,
+                                     val_sequences=val_sequences,
+                                     test_sequences=test_sequences,
+                                     colormap=colormap, axes=axs[2, 1])
+     if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
+         visualize_splits_tsne(train_sequences=train_sequences,
+                               val_sequences=val_sequences,
+                               test_sequences=test_sequences,
+                               embedding_path=esm_embeddings_path,
+                               colormap=colormap, axes=axs[2, 2])
+     else:
+         # Leave the last subplot blank
+         axs[2, 2].axis('off')
+ 
+     plt.tight_layout()
+     fig_combined.savefig('splits/combined_plot.png')
+     log_update("\nSaved combined figure to splits/combined_plot.png")
+ 
+ def visualize_train_test_splits(train_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
+     if colormap is None: colormap = default_color_map
+     # Add length column
+     train_clusters['member length'] = train_clusters['member seq'].str.len()
+     test_clusters['member length'] = test_clusters['member seq'].str.len()
+ 
+     # Prepare lengths and seqs for plotting
+     train_lengths = train_clusters['member length'].tolist()
+     test_lengths = test_clusters['member length'].tolist()
+     train_sequences = train_clusters['member seq'].tolist()
+     test_sequences = test_clusters['member seq'].tolist()
+ 
+     # Create a combined figure: 4 rows x 2 columns if there is a t-SNE plot, 3 x 2 otherwise
+     if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
+         set_font()
+         fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
+         visualize_splits_tsne(train_sequences=train_sequences,
+                               val_sequences=None,
+                               test_sequences=test_sequences,
+                               embedding_path=esm_embeddings_path,
+                               colormap=colormap, axes=axs[3, 0])
+         axs[-1, 1].axis('off')
+     else:
+         set_font()
+         fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))
+ 
+     # Make the remaining visualization plots for saving TOGETHER on the combined figure
+     visualize_splits_hist(train_lengths=train_lengths,
+                           val_lengths=None,
+                           test_lengths=test_lengths,
+                           colormap=colormap, axes=axs[0])
+     visualize_splits_shannon_entropy(train_sequences=train_sequences,
+                                      val_sequences=None,
+                                      test_sequences=test_sequences,
+                                      colormap=colormap, axes=axs[1])
+     visualize_splits_scatter(train_clusters=train_clusters,
+                              val_clusters=None,
+                              test_clusters=test_clusters,
+                              benchmark_cluster_reps=benchmark_cluster_reps,
+                              colormap=colormap, axes=axs[2, 0])
+     visualize_splits_aa_composition(train_sequences=train_sequences,
+                                     val_sequences=None,
+                                     test_sequences=test_sequences,
+                                     colormap=colormap, axes=axs[2, 1])
+ 
+     plt.tight_layout()
+     fig_combined.savefig('splits/combined_plot.png')
+     log_update("\nSaved combined figure to splits/combined_plot.png")
+ 
+ def visualize_train_val_splits(train_clusters, val_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
+     if colormap is None: colormap = default_color_map
+     # Add length column
+     train_clusters['member length'] = train_clusters['member seq'].str.len()
+     val_clusters['member length'] = val_clusters['member seq'].str.len()
+ 
+     # Prepare lengths and seqs for plotting
+     train_lengths = train_clusters['member length'].tolist()
+     val_lengths = val_clusters['member length'].tolist()
+     train_sequences = train_clusters['member seq'].tolist()
+     val_sequences = val_clusters['member seq'].tolist()
+ 
+     # Create a combined figure: 4 rows x 2 columns if there is a t-SNE plot, 3 x 2 otherwise
+     if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
+         set_font()
+         fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
+         visualize_splits_tsne(train_sequences=train_sequences,
+                               val_sequences=val_sequences,
+                               test_sequences=None,
+                               embedding_path=esm_embeddings_path,
+                               colormap=colormap, axes=axs[3, 0])
+         axs[-1, 1].axis('off')
+     else:
+         set_font()
+         fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))
+ 
+     # Make the remaining visualization plots for saving TOGETHER on the combined figure
+     visualize_splits_hist(train_lengths=train_lengths,
+                           val_lengths=val_lengths,
+                           test_lengths=None,
+                           colormap=colormap, axes=axs[0])
+     visualize_splits_shannon_entropy(train_sequences=train_sequences,
+                                      val_sequences=val_sequences,
+                                      test_sequences=None,
+                                      colormap=colormap, axes=axs[1])
+     visualize_splits_scatter(train_clusters=train_clusters,
+                              val_clusters=val_clusters,
+                              test_clusters=None,
+                              benchmark_cluster_reps=benchmark_cluster_reps,
+                              colormap=colormap, axes=axs[2, 0])
+     visualize_splits_aa_composition(train_sequences=train_sequences,
+                                     val_sequences=val_sequences,
+                                     test_sequences=None,
+                                     colormap=colormap, axes=axs[2, 1])
+ 
+     plt.tight_layout()
+     fig_combined.savefig('splits/combined_plot.png')
+     log_update("\nSaved combined figure to splits/combined_plot.png")