adding utility files used throughout FusOn-pLM training and benchmarking
- fuson_plm/utils/README.md +3 -0
- fuson_plm/utils/__init__.py +0 -0
- fuson_plm/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/clustering.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/constants.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/data_cleaning.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/embedding.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/logging.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/splitting.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/visualizing.cpython-310.pyc +0 -0
- fuson_plm/utils/clustering.py +139 -0
- fuson_plm/utils/constants.py +108 -0
- fuson_plm/utils/data_cleaning.py +126 -0
- fuson_plm/utils/embedding.py +193 -0
- fuson_plm/utils/logging.py +116 -0
- fuson_plm/utils/splitting.py +206 -0
- fuson_plm/utils/visualizing.py +545 -0
fuson_plm/utils/README.md
ADDED
@@ -0,0 +1,3 @@
This folder contains common functions for data cleaning, clustering, train-test splitting, visualization, embedding, and logging.

The functions in these scripts are used throughout the repository for training the main model, FusOn-pLM, as well as the benchmarks.
fuson_plm/utils/__init__.py
ADDED
File without changes
fuson_plm/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (149 Bytes)
fuson_plm/utils/__pycache__/clustering.cpython-310.pyc
ADDED
Binary file (4.87 kB)
fuson_plm/utils/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (2.48 kB)
fuson_plm/utils/__pycache__/data_cleaning.cpython-310.pyc
ADDED
Binary file (4.45 kB)
fuson_plm/utils/__pycache__/embedding.cpython-310.pyc
ADDED
Binary file (5.13 kB)
fuson_plm/utils/__pycache__/logging.cpython-310.pyc
ADDED
Binary file (3.31 kB)
fuson_plm/utils/__pycache__/splitting.cpython-310.pyc
ADDED
Binary file (6.95 kB)
fuson_plm/utils/__pycache__/visualizing.cpython-310.pyc
ADDED
Binary file (13.4 kB)
fuson_plm/utils/clustering.py
ADDED
@@ -0,0 +1,139 @@
import pandas as pd
import os
import subprocess
import sys
from Bio import SeqIO
import shutil
from fuson_plm.utils.logging import open_logfile, log_update

def ensure_mmseqs_in_path(mmseqs_dir):
    """
    Checks if MMseqs2 is in the PATH. If it is not, adds it. MMseqs2 will not run if this is not done correctly.

    Args:
        mmseqs_dir (str): Directory containing MMseqs2 binaries
    """
    mmseqs_bin = os.path.join(mmseqs_dir, 'mmseqs')

    # Check if mmseqs is already in PATH
    if shutil.which('mmseqs') is None:
        # Export the MMseqs2 directory to PATH
        os.environ['PATH'] = f"{mmseqs_dir}:{os.environ['PATH']}"
        log_update(f"\tAdded {mmseqs_dir} to PATH")

def process_fasta(fasta_path):
    fasta_sequences = SeqIO.parse(open(fasta_path), 'fasta')
    d = {}
    for fasta in fasta_sequences:
        id, sequence = fasta.id, str(fasta.seq)
        d[id] = sequence

    return d

def analyze_clustering_result(input_fasta: str, tsv_path: str):
    """
    Args:
        input_fasta (str): path to input fasta file
        tsv_path (str): path to the tab-separated cluster table produced by MMseqs2
    """
    # Process input fasta
    input_d = process_fasta(input_fasta)

    # Process clusters.tsv
    clusters = pd.read_csv(f'{tsv_path}', sep='\t', header=None)
    clusters = clusters.rename(columns={
        0: 'representative seq_id',
        1: 'member seq_id'
    })

    clusters['representative seq'] = clusters['representative seq_id'].apply(lambda seq_id: input_d[seq_id])
    clusters['member seq'] = clusters['member seq_id'].apply(lambda seq_id: input_d[seq_id])

    # Sort them so that splitting results are reproducible
    clusters = clusters.sort_values(by=['representative seq_id','member seq_id'], ascending=True).reset_index(drop=True)

    return clusters

def make_fasta(sequences: dict, fasta_path: str):
    """
    Makes a fasta file from sequences, where the key is the header and the value is the sequence.

    Args:
        sequences (dict): A dictionary where the key is the header and the value is the sequence.

    Returns:
        str: The path to the fasta file.
    """
    with open(fasta_path, 'w') as f:
        for header, sequence in sequences.items():
            f.write(f'>{header}\n{sequence}\n')

    return fasta_path

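As a quick illustration of the FASTA helpers above, a round trip through make_fasta and process_fasta might look like the sketch below; the sequences and file path are invented for demonstration.

from fuson_plm.utils.clustering import make_fasta, process_fasta

# Hypothetical toy input: two short protein fragments keyed by sequence ID
seqs = {'seq1': 'MKTAYIAKQR', 'seq2': 'GAVLIPF'}

fasta_path = make_fasta(seqs, 'toy.fasta')  # writes one header line and one sequence line per entry
recovered = process_fasta(fasta_path)       # parses the file back into a {header: sequence} dict
assert recovered == seqs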
def run_mmseqs_clustering(input_fasta, output_dir, min_seq_id=0.3, c=0.8, cov_mode=0, cluster_mode=0, path_to_mmseqs='fuson_plm/mmseqs'):
    """
    Runs MMseqs2 clustering using the easy-cluster module.

    Args:
        input_fasta (str): path to input fasta file, formatted >header\nsequence\n>header\nsequence....
        output_dir (str): path to output dir for clustering results
        min_seq_id (float): number [0,1] representing --min-seq-id in cluster command
        c (float): number [0,1] representing -c in cluster command
        cov_mode (int): number 0, 1, 2, or 3 representing --cov-mode in cluster command
        cluster_mode (int): number 0, 1, or 2 representing --cluster-mode in cluster command
        path_to_mmseqs (str): path to the MMseqs2 installation directory
    """
    # Get mmseqs dir
    log_update("\nRunning MMseqs2 clustering...")
    mmseqs_dir = os.path.join(path_to_mmseqs[0:path_to_mmseqs.index('/mmseqs')], 'mmseqs/bin')

    # Ensure MMseqs2 is in the PATH
    ensure_mmseqs_in_path(mmseqs_dir)

    # Define paths for MMseqs2
    mmseqs_bin = "mmseqs"  # Ensure this is in your PATH or provide the full path to the mmseqs binary

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # Run MMseqs2 easy-cluster
    cmd_easy_cluster = [
        mmseqs_bin, "easy-cluster", input_fasta, os.path.join(output_dir, "mmseqs"), output_dir,
        "--min-seq-id", str(min_seq_id),
        "-c", str(c),
        "--cov-mode", str(cov_mode),
        "--cluster-mode", str(cluster_mode),
        "--dbtype", "1"
    ]

    # Write the command to a log file
    log_update("\n\tCommand entered to MMseqs2:")
    log_update("\t" + " ".join(cmd_easy_cluster) + "\n")

    subprocess.run(cmd_easy_cluster, check=True)

    log_update(f"Clustering completed. Results are in {output_dir}")

def cluster_summary(clusters: pd.DataFrame):
    """
    Summarizes how many clusters were formed, how big they are, etc.
    """
    grouped_clusters = clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    assert len(grouped_clusters) == len(clusters['representative seq_id'].unique())  # make sure number of cluster reps = # grouped clusters

    total_seqs = sum(grouped_clusters['member count'])
    log_update(f"Created {len(grouped_clusters)} clusters of {total_seqs} sequences")
    log_update(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']==1])} clusters of size 1")
    csize1_seqs = sum(grouped_clusters[grouped_clusters['member count']==1]['member count'])
    log_update(f"\t\tsequences: {csize1_seqs} ({round(100*csize1_seqs/total_seqs, 2)}%)")

    log_update(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']>1])} clusters of size > 1")
    csizeg1_seqs = sum(grouped_clusters[grouped_clusters['member count']>1]['member count'])
    log_update(f"\t\tsequences: {csizeg1_seqs} ({round(100*csizeg1_seqs/total_seqs, 2)}%)")
    log_update(f"\tlargest cluster: {max(grouped_clusters['member count'])}")

    log_update("\nCluster size breakdown below...")

    value_counts = grouped_clusters['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
    log_update(value_counts.sort_values(by='cluster size (n_members)', ascending=True).to_string(index=False))
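A sketch of how these pieces compose into the clustering pipeline, assuming an input FASTA at data/input.fasta (the paths are illustrative; easy-cluster names its cluster table after the given prefix, which here is "mmseqs" inside output_dir):

from fuson_plm.utils.clustering import run_mmseqs_clustering, analyze_clustering_result, cluster_summary

run_mmseqs_clustering('data/input.fasta', 'clustering_output',
                      min_seq_id=0.3, c=0.8, cov_mode=0, cluster_mode=0)
clusters = analyze_clustering_result('data/input.fasta', 'clustering_output/mmseqs_cluster.tsv')
cluster_summary(clusters)  # logs cluster counts and the size breakdown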
fuson_plm/utils/constants.py
ADDED
@@ -0,0 +1,108 @@
# Data Cleaning Parameters
# TCGA abbreviations for cancer. From https://gdc.cancer.gov/resources-tcga-users/tcga-code-tables/tcga-study-abbreviations
TCGA_CODES = {
    'LAML': 'Acute Myeloid Leukemia',
    'ACC': 'Adrenocortical carcinoma',
    'BLCA': 'Bladder Urothelial Carcinoma',
    'LGG': 'Brain Lower Grade Glioma',
    'BRCA': 'Breast invasive carcinoma',
    'CESC': 'Cervical squamous cell carcinoma and endocervical adenocarcinoma',
    'CHOL': 'Cholangiocarcinoma',
    'LCML': 'Chronic Myelogenous Leukemia',
    'COAD': 'Colon adenocarcinoma',
    'CNTL': 'Controls',
    'ESCA': 'Esophageal carcinoma',
    'FPPP': 'FFPE Pilot Phase II',
    'GBM': 'Glioblastoma multiforme',
    'HNSC': 'Head and Neck squamous cell carcinoma',
    'KICH': 'Kidney Chromophobe',
    'KIRC': 'Kidney renal clear cell carcinoma',
    'KIRP': 'Kidney renal papillary cell carcinoma',
    'LIHC': 'Liver hepatocellular carcinoma',
    'LUAD': 'Lung adenocarcinoma',
    'LUSC': 'Lung squamous cell carcinoma',
    'DLBC': 'Lymphoid Neoplasm Diffuse Large B-cell Lymphoma',
    'MESO': 'Mesothelioma',
    'MISC': 'Miscellaneous',
    'OV': 'Ovarian serous cystadenocarcinoma',
    'PAAD': 'Pancreatic adenocarcinoma',
    'PCPG': 'Pheochromocytoma and Paraganglioma',
    'PRAD': 'Prostate adenocarcinoma',
    'READ': 'Rectum adenocarcinoma',
    'SARC': 'Sarcoma',
    'SKCM': 'Skin Cutaneous Melanoma',
    'STAD': 'Stomach adenocarcinoma',
    'TGCT': 'Testicular Germ Cell Tumors',
    'THYM': 'Thymoma',
    'THCA': 'Thyroid carcinoma',
    'UCS': 'Uterine Carcinosarcoma',
    'UCEC': 'Uterine Corpus Endometrial Carcinoma',
    'UVM': 'Uveal Melanoma'
}

FODB_CODES = {
    'ACC': 'Adenoid cystic carcinoma',
    'ALL': 'Acute Lymphoid Leukemia',
    'AML': 'Acute Myeloid Leukemia',
    'BALL': 'B-cell acute lymphoblastic leukemia',
    'BLCA': 'Bladder Urothelial Carcinoma',
    'BRCA': 'Breast invasive carcinoma',
    'CESC': 'Cervical squamous cell carcinoma and endocervical adenocarcinoma',
    'CHOL': 'Cholangiocarcinoma',
    'EPD': 'Ependymoma',
    'HGG': 'High-grade glioma',
    'HNSC': 'Head and Neck squamous cell carcinoma',
    'KIRC': 'Kidney renal clear cell carcinoma',
    'LGG': 'Low-grade glioma',
    'LUAD': 'Lung adenocarcinoma',
    'LUSC': 'Lung squamous cell carcinoma',
    'MEL': 'Melanoma',
    'MESO': 'Mesothelioma',
    'NBL': 'Neuroblastoma',
    'OS': 'Osteosarcoma',
    'OV': 'Ovarian serous cystadenocarcinoma',
    'PCPG': 'Pheochromocytoma and Paraganglioma',
    'PRAD': 'Prostate adenocarcinoma',
    'READ': 'Rectum adenocarcinoma',
    'RHB': 'Rhabdomyosarcoma',
    'SARC': 'Sarcoma',
    'STAD': 'Stomach adenocarcinoma',
    'TALL': 'T-cell acute lymphoblastic leukemia',
    'THYM': 'Thymoma',
    'UCEC': 'Uterine Corpus Endometrial Carcinoma',
    'UCS': 'Uterine Carcinosarcoma',
    'UVM': 'Uveal Melanoma',
    'WLM': 'Wilms tumor'
}

# The 20 canonical amino acid one-letter codes
VALID_AAS = {'A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
             'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'}

# Delimiters used to detect list-like entries during data cleaning
DELIMITERS = {',', ';', '|', '\t', ' ', ':', '-', '/', '\\', '\n'}
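A brief illustration of how these constants are consumed during data cleaning (the sequence is made up):

from fuson_plm.utils.constants import TCGA_CODES, VALID_AAS

print(TCGA_CODES['BRCA'])    # Breast invasive carcinoma

# Residues outside the 20 canonical amino acids flag a sequence for cleaning
seq = 'MKTAYIAKQRX'
print(set(seq) - VALID_AAS)  # {'X'}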
fuson_plm/utils/data_cleaning.py
ADDED
@@ -0,0 +1,126 @@
import pandas as pd
import numpy as np
from fuson_plm.utils.logging import log_update

def clean_rows_and_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Deletes empty rows and columns

    Args:
        df (pd.DataFrame): input DataFrame to be cleaned

    Returns:
        pd.DataFrame: cleaned DataFrame
    """
    # Delete rows with no data
    log_update(f"\trow cleaning...\n\t\toriginal # rows: {len(df)}")
    log_update("\t\tdropping rows where all entries are np.nan...")
    df = df.dropna(how='all')
    log_update(f"\t\tnew # rows: {len(df)}")

    # Delete columns with no data
    log_update(f"\tcolumn cleaning...\n\t\toriginal # columns: {len(df.columns)}")
    log_update("\t\tdropping columns where all entries are np.nan...")
    df = df.dropna(axis=1, how='all')
    log_update(f"\t\tnew # columns: {len(df.columns)}")
    log_update(f"\t\tcolumn names: {','.join(list(df.columns))}")

    return df

def check_columns_for_listlike(df: pd.DataFrame, cols_of_interest: list, delimiters: set):
    """
    Checks if a column contains any listlike items

    Args:
        df (pd.DataFrame): DataFrame to be investigated
        cols_of_interest (list): columns in df to be investigated for list-containing potential
        delimiters (set): set of potential delimiting strings to search for. A column with any of these strings is considered listlike.

    Returns:
        dict: dictionary containing a set {} of all delimiters found in each column
            e.g., { 'col1': {',',';'},
                    'col2': {'|'} }
    """
    # return the delimiters/listlike things found for each column
    return_dict = {}

    log_update("\tchecking if any of our columns of interest look listlike (contain list objects or delimiters)...")
    for col in cols_of_interest:
        unique_col = list(df[col].value_counts().index)
        listlike = any([check_item_for_listlike(x, delimiters) for x in unique_col])

        if listlike:
            # union the delimiters found across all entries (iterate directly; sets are unhashable, so value_counts cannot group them)
            found_delims = df[col].apply(lambda x: check_item_for_listlike(x, delimiters)).to_list()
            unique_found_delims = set()
            for x in found_delims:
                if isinstance(x, set):
                    unique_found_delims = unique_found_delims.union(x)

            return_dict[col] = unique_found_delims
        else:
            return_dict[col] = False

        # display the return dict
        log_update(f"\t\tcolumn name: {col}\tlistlike: {return_dict[col]}")

    return return_dict

def check_item_for_listlike(x, delimiters: set):
    """
    Checks if an item looks like it contains a list of items, rather than an individual item, based on string delimiters.

    Args:
        x: the item to check. Any dtype.
        delimiters: a set of delimiters to check for. e.g., {',', ';', '|', '\t', ' ', ':', '-', '/', '\\', '\n'}

    Returns:
        If x is a string: the set (may be empty) of delimiters contained in the string
        If x is None or NaN: an empty set (the item is simply empty, not listlike)
        Otherwise: the dtype of x
    """
    if isinstance(x, str):
        return find_delimiters(x, delimiters)
    else:
        if x is None:
            # if it's None, it's not listlike, it's just empty. return an empty set because it has no delimiters.
            return set()
        if type(x) == float:
            # if it's nan, it's not listlike, it's just empty. return an empty set because it has no delimiters.
            if np.isnan(x):
                return set()
        return type(x)

def find_delimiters(seq: str, delimiters: set) -> set:
    """
    Find and return a set of delimiters in a sequence. Helper method for check_item_for_listlike.

    Args:
        seq (str): The sequence you wish to search for delimiters.
        delimiters (set): a set of delimiters to check for. e.g., {',', ';', '|', '\t', ' ', ':', '-', '/', '\\', '\n'}

    Returns:
        set: The set of delimiters found in the sequence (empty if none are present).
    """
    unique_chars = set(seq)  # set of all characters in the sequence; unique_chars = {A, C} for protein="AAACCC"
    overlap = delimiters.intersection(unique_chars)

    return overlap

def find_invalid_chars(seq: str, valid_chars: set) -> set:
    """
    Find and return a set of invalid characters in a sequence.

    Args:
        seq (str): The sequence you wish to search for invalid characters.
        valid_chars (set): A set of valid characters.

    Returns:
        set: A set of characters in the sequence that are not in the set of valid characters (empty if the sequence is clean).
    """
    unique_chars = set(seq)  # set of all characters in the sequence; unique_chars = {A, C} for protein="AAACCC"

    if unique_chars.issubset(valid_chars):  # e.g. unique_chars = {A,C}, and {A,C} is a subset of valid_chars
        return set()
    else:  # e.g. unique_chars = {A,X}. {A,X} is not a subset of valid_chars because X is not in valid_chars
        return unique_chars.difference(valid_chars)  # e.g. {A,X} - valid_chars = {X}
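A toy run of the listlike check and the invalid-character check (the DataFrame below is invented; DELIMITERS and VALID_AAS come from constants.py):

import pandas as pd
from fuson_plm.utils.constants import DELIMITERS, VALID_AAS
from fuson_plm.utils.data_cleaning import check_columns_for_listlike, find_invalid_chars

df = pd.DataFrame({'gene': ['EWSR1', 'FLI1,ERG'], 'seq': ['MKTA', 'GAVX']})
print(check_columns_for_listlike(df, ['gene'], DELIMITERS))  # {'gene': {','}}: the comma marks a list-like entry
print(find_invalid_chars('GAVX', VALID_AAS))                 # {'X'}: X is not a canonical amino acid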
fuson_plm/utils/embedding.py
ADDED
@@ -0,0 +1,193 @@
import pickle
import torch
from transformers import EsmModel, AutoTokenizer
from transformers import T5Tokenizer, T5EncoderModel
import logging
from fuson_plm.utils.logging import log_update


def redump_pickle_dictionary(pickle_path):
    """
    Loads a pickle dictionary and redumps it in its location. This allows a clean reset for a pickle built with 'ab+'.
    """
    entries = {}
    # Load one by one
    with open(pickle_path, 'rb') as f:
        while True:
            try:
                entry = pickle.load(f)
                entries.update(entry)
            except EOFError:
                break  # End of file reached
            except Exception as e:
                print(f"An error occurred: {e}")
                break
    # Redump
    with open(pickle_path, 'wb') as f:
        pickle.dump(entries, f)

def load_esm2_type(esm_type, device=None):
    """
    Loads an ESM-2 model of a specified version (e.g. esm2_t33_650M_UR50D).
    """
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

    model = EsmModel.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")

    model.to(device)
    model.eval()  # disables dropout for deterministic results

    return model, tokenizer, device

def load_prott5():
    # Initialize tokenizer and model
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    if device == torch.device('cpu'):
        model.to(torch.float32)
    model.to(device)
    return model, tokenizer, device

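For instance, loading the 650M-parameter ESM-2 checkpoint referenced elsewhere in the repo (a minimal sketch; the weights are fetched from the Hugging Face hub on first use):

from fuson_plm.utils.embedding import load_esm2_type

model, tokenizer, device = load_esm2_type('esm2_t33_650M_UR50D')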
def get_esm_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Compute ESM embeddings.

    Args:
        model: ESM-2 model (e.g. from load_esm2_type)
        tokenizer: matching tokenizer
        sequences: list of amino acid sequences to embed
        device: torch device to run inference on
        average: if True, the per-residue embeddings are mean-pooled over the length dimension and returned
        savepath: if savepath is not None, the embeddings will be saved there. It must be a pickle.
        max_length: maximum tokenized length; defaults to the longest sequence plus 2 for BOS/EOS
    """
    # Correct save path to pickle if necessary
    if savepath is not None:
        if savepath[-4::] != '.pkl': savepath += '.pkl'

    # If no max length was passed, just set it to the maximum in the dataset
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length = max_seq_len + 2  # +2 for BOS, EOS

    # Initialize an empty dict to store the ESM embeddings
    embedding_dict = {}
    # Iterate through the seqs
    for i in range(len(sequences)):
        sequence = sequences[i]
        # Get the embeddings
        with torch.no_grad():
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            embedding = outputs.last_hidden_state

            # remove extra dimension
            embedding = embedding.squeeze(0)
            # remove BOS and EOS tokens
            embedding = embedding[1:-1, :]

            # Convert embeddings to numpy array (if needed)
            embedding = embedding.cpu().numpy()

        # Average (if necessary)
        if average:
            embedding = embedding.mean(0)

        # Add to dictionary
        embedding_dict[sequence] = embedding

        # Save individual embedding (if necessary)
        if not(savepath is None) and not(save_at_end):
            with open(savepath, 'ab+') as f:
                d = {sequence: embedding}
                pickle.dump(d, f)

        # Print update (if necessary)
        if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")

    # Dump all at once at the end (if necessary)
    if not(savepath is None):
        # If saving for the first time, just dump it
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
        else:
            redump_pickle_dictionary(savepath)

    # Return the dictionary
    return embedding_dict

def get_prott5_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    # Correct save path to pickle if necessary
    if savepath is not None:
        if savepath[-4::] != '.pkl': savepath += '.pkl'

    # If no max length was passed, just set it to the maximum in the dataset
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length = max_seq_len + 2  # +2 for BOS, EOS

    # the ProtT5 tokenizer requires that there are spaces between residues
    spaced_sequences = [' '.join(list(seq)) for seq in sequences]  # Spaces between residues for Prot-T5 tokenizer

    # Store embeddings here
    embedding_dict = {}

    for i in range(0, len(spaced_sequences)):
        spaced_sequence = spaced_sequences[i]  # get current sequence
        seq = spaced_sequence.replace(" ", "")

        with torch.no_grad():
            inputs = tokenizer(spaced_sequence, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=max_length)  # shouldn't have to pad because batch size is 1
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Pass through the model with no gradient to get embeddings
            embedding_repr = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

        # Process the embedding
        seq_length = len(seq)  # length of the sequence after you remove spaces
        embedding = embedding_repr.last_hidden_state.squeeze(0)  # remove batch dimension
        embedding = embedding[0:-1]  # remove EOS token (there is no BOS token)
        embedding = embedding.cpu().numpy()  # put on CPU and numpy
        embedding_log = f"\tembedding shape: {embedding.shape}"
        # MAKE SURE the embedding lengths are right with an assert. We expect embedding dimension 1024, and sequence length to match real sequence length
        assert embedding.shape[1] == 1024
        assert embedding.shape[0] == seq_length

        # Average (if necessary)
        if average:
            dim_before = embedding.shape
            embedding = embedding.mean(0)
            embedding_log = f"\tembedding shape before avg: {dim_before}\tafter avg: {embedding.shape}"

        # Add the embedding to the dictionary
        embedding_dict[seq] = embedding

        # Save individual embedding (if necessary)
        if not(savepath is None) and not(save_at_end):
            with open(savepath, 'ab+') as f:
                d = {seq: embedding}
                pickle.dump(d, f)

        if print_updates: log_update(f"sequence {i+1}: {seq[0:10]}...{embedding_log}\t seq len: {seq_length}")

    # Dump all at once at the end (if necessary)
    if not(savepath is None):
        # If saving for the first time, just dump it
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
        else:
            redump_pickle_dictionary(savepath)

    # Return the dictionary
    return embedding_dict
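A sketch of end-to-end embedding extraction with streaming saves (the sequences and save path are invented; per-sequence pickle appends are consolidated by redump_pickle_dictionary once the loop finishes):

from fuson_plm.utils.embedding import load_esm2_type, get_esm_embeddings

model, tokenizer, device = load_esm2_type('esm2_t33_650M_UR50D')
seqs = ['MKTAYIAKQR', 'GAVLIPFSTW']
embs = get_esm_embeddings(model, tokenizer, seqs, device,
                          average=True, savepath='toy_embeddings.pkl')
print(embs[seqs[0]].shape)  # (1280,) expected for the mean-pooled t33 650M model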
fuson_plm/utils/logging.py
ADDED
@@ -0,0 +1,116 @@
from datetime import datetime
from contextlib import contextmanager
import sys
import pytz
import os

class CustomParams:
    """
    Class for custom parameters where dictionary elements can be accessed as attributes
    """
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def print_config(self, indent=''):
        for attr, value in self.__dict__.items():
            print(f"{indent}{attr}: {value}")

def log_update(text: str):
    """
    Logs input text to an output file (prints to stdout, which open_logfile may redirect to a log file)

    Args:
        text (str): the text to be logged
    """
    print(text)  # print the text
    sys.stdout.flush()  # flush to automatically update the output file

@contextmanager
def open_logfile(log_path, mode='w'):
    """
    Open log-file for real-time logging of the most important updates
    """
    log_file = open(log_path, mode)  # open
    original_stdout = sys.stdout  # save original stdout
    sys.stdout = log_file  # redirect stdout to log_file
    try:
        yield log_file
    finally:
        sys.stdout = original_stdout
        log_file.close()

@contextmanager
def open_errfile(log_path, mode='w'):
    """
    Redirects stderr (error messages) to a separate log file.
    """
    log_file = open(log_path, mode)  # open the error log file for writing
    original_stderr = sys.stderr  # save original stderr
    sys.stderr = log_file  # redirect stderr to log_file
    try:
        yield log_file
    finally:
        sys.stderr = original_stderr  # restore original stderr
        log_file.close()  # close the error log file

def print_configpy(module):
    """
    Prints all the configurations in a config.py file
    """
    log_update("All configurations:")
    # Iterate over attributes
    for attribute in dir(module):
        # Filter out built-in attributes and methods
        if not attribute.startswith("__"):
            value = getattr(module, attribute)
            log_update(f"\t{attribute}: {value}")

def get_local_time(timezone_str='US/Eastern'):
    """
    Get current time in the specified timezone.

    Args:
        timezone_str (str): The timezone to retrieve time for. Defaults to 'US/Eastern'.

    Returns:
        str: The formatted current time in the specified timezone.
    """
    try:
        timezone = pytz.timezone(timezone_str)
    except pytz.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone_str}"

    current_datetime = datetime.now(pytz.utc).astimezone(timezone)
    return current_datetime.strftime('%m-%d-%Y-%H:%M:%S')

def get_local_date_yr(timezone_str='US/Eastern'):
    """
    Get the current date in the specified timezone.

    Args:
        timezone_str (str): The timezone to retrieve the date for. Defaults to 'US/Eastern'.

    Returns:
        str: The formatted current date in the specified timezone.
    """
    try:
        timezone = pytz.timezone(timezone_str)
    except pytz.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone_str}"

    current_datetime = datetime.now(pytz.utc).astimezone(timezone)
    return current_datetime.strftime('%m_%d_%Y')

def find_fuson_plm_directory():
    """
    Constructs a path backwards to the fuson_plm directory so we don't have to use absolute paths (helps for docker containers)
    """
    current_dir = os.path.abspath(os.getcwd())

    while True:
        if 'fuson_plm' in os.listdir(current_dir):
            return os.path.join(current_dir, 'fuson_plm')
        parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
        if parent_dir == current_dir:  # If we've reached the root directory
            raise FileNotFoundError("fuson_plm directory not found.")
        current_dir = parent_dir
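The pattern used throughout the repo pairs open_logfile with log_update so that everything printed lands in a real-time log file (the file name here is illustrative):

from fuson_plm.utils.logging import open_logfile, log_update

with open_logfile('cleaning_log.txt'):
    log_update('starting data cleaning...')  # written to cleaning_log.txt and flushed immediately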
fuson_plm/utils/splitting.py
ADDED
@@ -0,0 +1,206 @@
import pandas as pd
from sklearn.model_selection import train_test_split
from fuson_plm.utils.logging import log_update

def split_clusters_train_test(X, y, benchmark_cluster_reps=[], random_state=1, test_size=0.20):
    # split with random state fixed for reproducible results
    log_update(f"\tPerforming split: all clusters -> train clusters ({round(1-test_size,3)}) and test clusters ({test_size})")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # add benchmark representatives back to X_test
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_test += benchmark_cluster_reps

    # assert no duplicates within the train or test sets (there shouldn't be, if the input data was clean)
    assert len(X_train) == len(set(X_train))
    assert len(X_test) == len(set(X_test))

    return {
        'X_train': X_train,
        'X_test': X_test
    }

def split_clusters_train_val_test(X, y, benchmark_cluster_reps=[], random_state_1=1, random_state_2=1, test_size_1=0.20, test_size_2=0.50):
    # split with random state fixed for reproducible results
    log_update(f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size_1, random_state=random_state_1)
    log_update(f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})")
    X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=test_size_2, random_state=random_state_2)

    # add benchmark representatives back to X_test
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_test += benchmark_cluster_reps

    # assert no duplicates within the train, val, or test sets (there shouldn't be, if the input data was clean)
    assert len(X_train) == len(set(X_train))
    assert len(X_val) == len(set(X_val))
    assert len(X_test) == len(set(X_test))

    return {
        'X_train': X_train,
        'X_val': X_val,
        'X_test': X_test
    }

def split_clusters(cluster_representatives: list, val_set=True, benchmark_cluster_reps=[], random_state_1=1, random_state_2=1, test_size_1=0.20, test_size_2=0.50):
    """
    Cluster-splitting method amenable to either train-test or train-val-test.
    For train-val-test, there are two splits.
    """
    log_update("\nPerforming splits...")
    # Approx. 80/10/10 split
    X = [x for x in cluster_representatives if not(x in benchmark_cluster_reps)]  # X, for splitting, does NOT include benchmark reps. We'll add these clusters to test.
    y = [0]*len(X)  # y is a dummy array here; there are no values.

    split_dict = None
    if val_set:
        split_dict = split_clusters_train_val_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
                                                   random_state_1=random_state_1, random_state_2=random_state_2,
                                                   test_size_1=test_size_1, test_size_2=test_size_2)
    else:
        split_dict = split_clusters_train_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
                                               random_state=random_state_1,
                                               test_size=test_size_1)

    return split_dict

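A sketch of a cluster-level 80/10/10 split with benchmark-containing clusters forced into the test set (the IDs are invented):

from fuson_plm.utils.splitting import split_clusters

reps = [f'rep{i}' for i in range(100)]   # cluster representative seq_ids
bench_reps = ['rep3', 'rep7']            # clusters holding benchmark sequences

splits = split_clusters(reps, val_set=True, benchmark_cluster_reps=bench_reps)
print(len(splits['X_train']), len(splits['X_val']), len(splits['X_test']))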
def check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=None):
    """
    Args:
        train_clusters (pd.DataFrame): cluster assignments for the train set
        val_clusters (pd.DataFrame): cluster assignments for the validation set (optional - can pass None if there is no validation set)
        test_clusters (pd.DataFrame): cluster assignments for the test set
        benchmark_sequences (list): benchmark sequences that must appear only in the test set (optional)
    """

    # Make grouped versions of these DataFrames for size analysis
    train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    if val_clusters is not None:
        val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    if test_clusters is not None:
        test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})

    # Calculate stats - clusters
    n_train_clusters = len(train_clustersgb)
    n_val_clusters, n_test_clusters = 0, 0
    if val_clusters is not None:
        n_val_clusters = len(val_clustersgb)
    if test_clusters is not None:
        n_test_clusters = len(test_clustersgb)
    n_clusters = n_train_clusters + n_val_clusters + n_test_clusters

    assert len(train_clusters['representative seq_id'].unique()) == len(train_clustersgb)
    if val_clusters is not None:
        assert len(val_clusters['representative seq_id'].unique()) == len(val_clustersgb)
    if test_clusters is not None:
        assert len(test_clusters['representative seq_id'].unique()) == len(test_clustersgb)

    train_cluster_pcnt = round(100*n_train_clusters/n_clusters, 2)
    if val_clusters is not None:
        val_cluster_pcnt = round(100*n_val_clusters/n_clusters, 2)
    if test_clusters is not None:
        test_cluster_pcnt = round(100*n_test_clusters/n_clusters, 2)

    # Calculate stats - proteins
    n_train_proteins = len(train_clusters)
    n_val_proteins, n_test_proteins = 0, 0
    if val_clusters is not None:
        n_val_proteins = len(val_clusters)
    if test_clusters is not None:
        n_test_proteins = len(test_clusters)
    n_proteins = n_train_proteins + n_val_proteins + n_test_proteins

    assert len(train_clusters) == sum(train_clustersgb['member count'])
    if val_clusters is not None:
        assert len(val_clusters) == sum(val_clustersgb['member count'])
    if test_clusters is not None:
        assert len(test_clusters) == sum(test_clustersgb['member count'])

    train_protein_pcnt = round(100*n_train_proteins/n_proteins, 2)
    if val_clusters is not None:
        val_protein_pcnt = round(100*n_val_proteins/n_proteins, 2)
    if test_clusters is not None:
        test_protein_pcnt = round(100*n_test_proteins/n_proteins, 2)

    # Print results
    log_update("\nCluster breakdown...")
    log_update(f"Total clusters = {n_clusters}, total proteins = {n_proteins}")
    log_update(f"\tTrain set:\n\t\tTotal Clusters = {len(train_clustersgb)} ({train_cluster_pcnt}%)\n\t\tTotal Proteins = {len(train_clusters)} ({train_protein_pcnt}%)")
    if val_clusters is not None:
        log_update(f"\tVal set:\n\t\tTotal Clusters = {len(val_clustersgb)} ({val_cluster_pcnt}%)\n\t\tTotal Proteins = {len(val_clusters)} ({val_protein_pcnt}%)")
    if test_clusters is not None:
        log_update(f"\tTest set:\n\t\tTotal Clusters = {len(test_clustersgb)} ({test_cluster_pcnt}%)\n\t\tTotal Proteins = {len(test_clusters)} ({test_protein_pcnt}%)")

    # Check for overlap in both sequence ID and actual sequence
    train_protein_ids = set(train_clusters['member seq_id'])
    train_protein_seqs = set(train_clusters['member seq'])
    if val_clusters is not None:
        val_protein_ids = set(val_clusters['member seq_id'])
        val_protein_seqs = set(val_clusters['member seq'])
    if test_clusters is not None:
        test_protein_ids = set(test_clusters['member seq_id'])
        test_protein_seqs = set(test_clusters['member seq'])

    # Print results
    log_update("\nChecking for overlap...")
    if (val_clusters is not None) and (test_clusters is not None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}\n\t\tVal-Test Overlap: {len(val_protein_ids.intersection(test_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}\n\t\tVal-Test Overlap: {len(val_protein_seqs.intersection(test_protein_seqs))}")
    if (val_clusters is not None) and (test_clusters is None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}")
    if (val_clusters is None) and (test_clusters is not None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}")

    # Assert no sequence overlap
    if val_clusters is not None:
        assert len(train_protein_seqs.intersection(val_protein_seqs)) == 0
    if test_clusters is not None:
        assert len(train_protein_seqs.intersection(test_protein_seqs)) == 0
    if (val_clusters is not None) and (test_clusters is not None):
        assert len(val_protein_seqs.intersection(test_protein_seqs)) == 0

    # Finally, check that benchmark sequences appear only in test - if there are benchmark sequences
    if not(benchmark_sequences is None):
        bench_in_train = len(train_clusters.loc[train_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        bench_in_val, bench_in_test = 0, 0
        if val_clusters is not None:
            bench_in_val = len(val_clusters.loc[val_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        if test_clusters is not None:
            bench_in_test = len(test_clusters.loc[test_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())

        # Assert this
        log_update("\nChecking for benchmark sequence presence in test, and absence from train and val...")
        log_update(f"\tTotal benchmark sequences: {len(benchmark_sequences)}")
        log_update(f"\tBenchmark sequences in train: {bench_in_train}")
        if val_clusters is not None:
            log_update(f"\tBenchmark sequences in val: {bench_in_val}")
        if test_clusters is not None:
            log_update(f"\tBenchmark sequences in test: {bench_in_test}")
        assert bench_in_train == bench_in_val == 0
        assert bench_in_test == len(benchmark_sequences)

def check_class_distributions(train_df, val_df, test_df, class_col='class'):
    """
    Checks class distributions within train, val, and test sets.
    Expects input dataframes to have a 'sequence' column and a 'class' column.
    """
    train_vc = pd.DataFrame(train_df[class_col].value_counts()).reset_index().rename(columns={'index':class_col, class_col:'train_count'})
    train_vc['train_pct'] = (train_vc['train_count'] / train_vc['train_count'].sum()).round(3)*100
    if val_df is not None:
        val_vc = pd.DataFrame(val_df[class_col].value_counts()).reset_index().rename(columns={'index':class_col, class_col:'val_count'})
        val_vc['val_pct'] = (val_vc['val_count'] / val_vc['val_count'].sum()).round(3)*100
    test_vc = pd.DataFrame(test_df[class_col].value_counts()).reset_index().rename(columns={'index':class_col, class_col:'test_count'})
    test_vc['test_pct'] = (test_vc['test_count'] / test_vc['test_count'].sum()).round(3)*100
    # concatenate so the distributions can be compared side by side
    if val_df is not None:
        compare = pd.concat([train_vc, val_vc, test_vc], axis=1)
        compare['train-val diff'] = (compare['train_pct'] - compare['val_pct']).apply(lambda x: abs(x))
        compare['val-test diff'] = (compare['val_pct'] - compare['test_pct']).apply(lambda x: abs(x))
    else:
        compare = pd.concat([train_vc, test_vc], axis=1)
        compare['train-test diff'] = (compare['train_pct'] - compare['test_pct']).apply(lambda x: abs(x))

    compare_str = compare.to_string(index=False)
    compare_str = "\t" + compare_str.replace("\n","\n\t")
    log_update(f"\nClass distribution:\n{compare_str}")
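A toy check of class balance across splits (the data is invented; note that the column renaming in check_class_distributions assumes the pandas 1.x behavior of value_counts().reset_index()):

import pandas as pd
from fuson_plm.utils.splitting import check_class_distributions

train_df = pd.DataFrame({'sequence': ['MKTA', 'GAVL', 'PFST'], 'class': [1, 0, 1]})
test_df = pd.DataFrame({'sequence': ['WYVH', 'RNDC'], 'class': [1, 0]})

check_class_distributions(train_df, None, test_df)  # val_df=None triggers the two-way train vs. test comparison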
fuson_plm/utils/visualizing.py
ADDED
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import matplotlib.font_manager as fm
|
3 |
+
from matplotlib.font_manager import FontProperties
|
4 |
+
from scipy.stats import entropy
|
5 |
+
from sklearn.manifold import TSNE
|
6 |
+
import pickle
|
7 |
+
import pandas as pd
|
8 |
+
import os
|
9 |
+
import numpy as np
|
10 |
+
from fuson_plm.utils.logging import log_update, find_fuson_plm_directory
|
11 |
+
|
12 |
+
def set_font():
|
13 |
+
# Load and set the font
|
14 |
+
fuson_plm_dir = find_fuson_plm_directory()
|
15 |
+
|
16 |
+
# Paths for regular, bold, italic fonts
|
17 |
+
regular_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Regular.ttf')
|
18 |
+
bold_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Bold.ttf')
|
19 |
+
italic_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Italic.ttf')
|
20 |
+
bold_italic_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-BoldItalic.ttf')
|
21 |
+
|
22 |
+
# Load the font properties
|
23 |
+
regular_font = FontProperties(fname=regular_font_path)
|
24 |
+
bold_font = FontProperties(fname=bold_font_path)
|
25 |
+
italic_font = FontProperties(fname=italic_font_path)
|
26 |
+
bold_italic_font = FontProperties(fname=bold_italic_font_path)
|
27 |
+
|
28 |
+
# Add the fonts to the font manager
|
29 |
+
fm.fontManager.addfont(regular_font_path)
|
30 |
+
fm.fontManager.addfont(bold_font_path)
|
31 |
+
fm.fontManager.addfont(italic_font_path)
|
32 |
+
fm.fontManager.addfont(bold_italic_font_path)
|
33 |
+
|
34 |
+
# Set the font family globally to Ubuntu
|
35 |
+
plt.rcParams['font.family'] = regular_font.get_name()
|
36 |
+
|
37 |
+
# Set the fonts for math text (like for labels) to use the loaded Ubuntu fonts
|
38 |
+
plt.rcParams['mathtext.fontset'] = 'custom'
|
39 |
+
plt.rcParams['mathtext.rm'] = regular_font.get_name()
|
40 |
+
plt.rcParams['mathtext.it'] = f'{italic_font.get_name()}'
|
41 |
+
plt.rcParams['mathtext.bf'] = f'{bold_font.get_name()}'
|
42 |
+
|
43 |
+
global default_color_map
|
44 |
+
default_color_map = {
|
45 |
+
'train': '#0072B2',
|
46 |
+
'val': '#009E73',
|
47 |
+
'test': '#E69F00'
|
48 |
+
}
|
49 |
+
|
50 |
+
def get_avg_embeddings_for_tsne(train_sequences=None, val_sequences=None, test_sequences=None, embedding_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl'):
|
51 |
+
if train_sequences is None: train_sequences = []
|
52 |
+
if val_sequences is None: val_sequences = []
|
53 |
+
if test_sequences is None: test_sequences = []
|
54 |
+
|
55 |
+
embeddings = {}
|
56 |
+
|
57 |
+
try:
|
58 |
+
with open(embedding_path, 'rb') as f:
|
59 |
+
embeddings = pickle.load(f)
|
60 |
+
|
61 |
+
train_embeddings = [v for k, v in embeddings.items() if k in train_sequences]
|
62 |
+
val_embeddings = [v for k, v in embeddings.items() if k in val_sequences]
|
63 |
+
test_embeddings = [v for k, v in embeddings.items() if k in test_sequences]
|
64 |
+
|
65 |
+
return train_embeddings, val_embeddings, test_embeddings
|
66 |
+
except:
|
67 |
+
print("could not open embeddings")
|
68 |
+
|
69 |
+
|
70 |
+
def calculate_aa_composition(sequences):
|
71 |
+
composition = {}
|
72 |
+
total_length = sum([len(seq) for seq in sequences])
|
73 |
+
|
74 |
+
for seq in sequences:
|
75 |
+
for aa in seq:
|
76 |
+
if aa in composition:
|
77 |
+
composition[aa] += 1
|
78 |
+
else:
|
79 |
+
composition[aa] = 1
|
80 |
+
|
81 |
+
# Convert counts to relative frequency
|
82 |
+
for aa in composition:
|
83 |
+
composition[aa] /= total_length
|
84 |
+
|
85 |
+
return composition
|
86 |
+
|
87 |
+
def calculate_shannon_entropy(sequence):
|
88 |
+
"""
|
89 |
+
Calculate the Shannon entropy for a given sequence.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
sequence (str): A sequence of characters (e.g., amino acids or nucleotides).
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
float: Shannon entropy value.
|
96 |
+
"""
|
97 |
+
bases = set(sequence)
|
98 |
+
counts = [sequence.count(base) for base in bases]
|
99 |
+
return entropy(counts, base=2)
|
100 |
+
|
101 |
+
def visualize_splits_hist(train_lengths=None, val_lengths=None, test_lengths=None, colormap=None, savepath=f'splits/length_distributions.png', axes=None):
|
102 |
+
"""
|
103 |
+
Works to plot train, val, test; train, val; or train, test
|
104 |
+
"""
|
105 |
+
set_font()
|
106 |
+
if colormap is None: colormap=default_color_map
|
107 |
+
|
108 |
+
log_update('\nMaking histogram of length distributions')
|
109 |
+
|
110 |
+
# Get index for test plot
|
111 |
+
val_plot_index, test_plot_index, total_plots = 1, 2, 3
|
112 |
+
if val_lengths is None:
|
113 |
+
val_plot_index = None
|
114 |
+
test_plot_index-= 1
|
115 |
+
total_plots-=1
|
116 |
+
if test_lengths is None:
|
117 |
+
test_plot_index = None
|
118 |
+
total_plots-=1
|
119 |
+
|
120 |
+
# Create a figure and axes with 1 row and 3 columns
|
121 |
+
fig_individual, axes_individual = plt.subplots(1, total_plots, figsize=(6*total_plots, 6))
|
122 |
+
|
123 |
+
# Set axes list
|
124 |
+
axes_list = [axes_individual] if axes is None else [axes_individual, axes]
|
125 |
+
|
126 |
+
# Unpack the labels and titles
|
127 |
+
xlabel, ylabel = ['Sequence Length (AA)', 'Frequency']
|
128 |
+
|
129 |
+
for cur_axes in axes_list:
|
130 |
+
# Plot the first histogram
|
131 |
+
cur_axes[0].hist(train_lengths, bins=20, edgecolor='k',color=colormap['train'])
|
132 |
+
cur_axes[0].set_xlabel(xlabel)
|
133 |
+
cur_axes[0].set_ylabel(ylabel)
|
134 |
+
cur_axes[0].set_title(f'Train Set Length Distribution (n={len(train_lengths)})')
|
135 |
+
cur_axes[0].grid(True)
|
136 |
+
cur_axes[0].set_axisbelow(True)
|
137 |
+
|
138 |
+
# Plot the second histogram
|
139 |
+
if not(val_plot_index is None):
|
140 |
+
cur_axes[val_plot_index].hist(val_lengths, bins=20, edgecolor='k',color=colormap['val'])
|
141 |
+
cur_axes[val_plot_index].set_xlabel(xlabel)
|
142 |
+
cur_axes[val_plot_index].set_ylabel(ylabel)
|
143 |
+
cur_axes[val_plot_index].set_title(f'Validation Set Length Distribution (n={len(val_lengths)})')
|
144 |
+
cur_axes[val_plot_index].grid(True)
|
145 |
+
cur_axes[val_plot_index].set_axisbelow(True)
|
146 |
+
|
147 |
+
# Plot the third histogram
|
148 |
+
if not(test_plot_index is None):
|
149 |
+
cur_axes[test_plot_index].hist(test_lengths, bins=20, edgecolor='k',color=colormap['test'])
|
150 |
+
cur_axes[test_plot_index].set_xlabel(xlabel)
|
151 |
+
cur_axes[test_plot_index].set_ylabel(ylabel)
|
152 |
+
cur_axes[test_plot_index].set_title(f'Test Set Length Distribution (n={len(test_lengths)})')
|
153 |
+
cur_axes[test_plot_index].grid(True)
|
154 |
+
cur_axes[test_plot_index].set_axisbelow(True)
|
155 |
+
|
156 |
+
# Adjust layout
|
157 |
+
fig_individual.set_tight_layout(True)
|
158 |
+
|
159 |
+
# Save the figure
|
160 |
+
fig_individual.savefig(savepath)
|
161 |
+
log_update(f"\tSaved figure to {savepath}")
|
162 |
+
|
163 |
+
def visualize_splits_scatter(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None, colormap=None, savepath='splits/scatterplot.png', axes=None):
    """
    Scatterplot of the cluster-size distribution for each split: for every
    cluster size, the fraction of that split's proteins living in clusters
    of that size.
    """
    set_font()
    if colormap is None: colormap = default_color_map

    # Create a single standalone figure
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))

    # Draw on the standalone axes, and also on caller-provided axes if given
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    log_update("\nMaking scatterplot with distribution of cluster sizes across train, test, and val")
    # Make grouped versions of these DataFrames for size analysis
    train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
    if val_clusters is not None:
        val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
    if test_clusters is not None:
        test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
        total_test_proteins = sum(test_clustersgb['member count'])
        # Isolate benchmark-containing clusters so their contribution can be plotted separately
        if benchmark_cluster_reps is not None:
            test_clustersgb['benchmark cluster'] = test_clustersgb['representative seq_id'].isin(benchmark_cluster_reps)
            benchmark_clustersgb = test_clustersgb.loc[test_clustersgb['benchmark cluster']].reset_index(drop=True)
            test_clustersgb = test_clustersgb.loc[test_clustersgb['benchmark cluster'] == False].reset_index(drop=True)

    # Convert cluster sizes to value counts (how many clusters have each size)
    train_clustersgb = train_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
    if val_clusters is not None:
        val_clustersgb = val_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
    if test_clusters is not None:
        test_clustersgb = test_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
        if benchmark_cluster_reps is not None:
            benchmark_clustersgb = benchmark_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})

    # Get the percentage of each dataset that's made of each cluster size
    train_clustersgb['n_proteins'] = train_clustersgb['cluster size (n_members)']*train_clustersgb['n_clusters']  # proteins per cluster * n clusters = n proteins
    train_clustersgb['percent_proteins'] = train_clustersgb['n_proteins']/sum(train_clustersgb['n_proteins'])
    if val_clusters is not None:
        val_clustersgb['n_proteins'] = val_clustersgb['cluster size (n_members)']*val_clustersgb['n_clusters']
        val_clustersgb['percent_proteins'] = val_clustersgb['n_proteins']/sum(val_clustersgb['n_proteins'])
    if test_clusters is not None:
        test_clustersgb['n_proteins'] = test_clustersgb['cluster size (n_members)']*test_clustersgb['n_clusters']
        test_clustersgb['percent_proteins'] = test_clustersgb['n_proteins']/total_test_proteins
        if benchmark_cluster_reps is not None:
            benchmark_clustersgb['n_proteins'] = benchmark_clustersgb['cluster size (n_members)']*benchmark_clustersgb['n_clusters']
            benchmark_clustersgb['percent_proteins'] = benchmark_clustersgb['n_proteins']/total_test_proteins

    # Plot each split; benchmark clusters are specially marked because they cannot be reallocated
    for ax in axes_list:
        ax.plot(train_clustersgb['cluster size (n_members)'], train_clustersgb['percent_proteins'], linestyle='None', marker='.', color=colormap['train'], label='train')
        if val_clusters is not None:
            ax.plot(val_clustersgb['cluster size (n_members)'], val_clustersgb['percent_proteins'], linestyle='None', marker='.', color=colormap['val'], label='val')
        if test_clusters is not None:
            ax.plot(test_clustersgb['cluster size (n_members)'], test_clustersgb['percent_proteins'], linestyle='None', marker='.', color=colormap['test'], label='test')
            if benchmark_cluster_reps is not None:
                ax.plot(benchmark_clustersgb['cluster size (n_members)'], benchmark_clustersgb['percent_proteins'],
                        marker='o',
                        linestyle='None',
                        markerfacecolor=colormap['test'],  # fill same as test
                        markeredgecolor='black',           # outline black
                        markeredgewidth=1.5,
                        label='benchmark'
                        )
        ax.set(ylabel='Percentage of Proteins in Dataset', xlabel='Cluster Size (n_members)')
        ax.legend()

    # Save the figure
    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

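# Example input schema for visualize_splits_scatter (illustrative sketch, not
# part of the original file). The cluster DataFrames are assumed to follow the
# MMseqs2-style layout used above: one row per member, with
# 'representative seq_id' and 'member seq_id' columns.
def _demo_cluster_scatter():
    os.makedirs('splits', exist_ok=True)
    train_clusters = pd.DataFrame({
        'representative seq_id': ['A', 'A', 'B', 'C', 'C', 'C'],
        'member seq_id':         ['A', 'a1', 'B', 'C', 'c1', 'c2'],
    })
    test_clusters = pd.DataFrame({
        'representative seq_id': ['D', 'D', 'E'],
        'member seq_id':         ['D', 'd1', 'E'],
    })
    visualize_splits_scatter(train_clusters=train_clusters,
                             test_clusters=test_clusters)
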
def visualize_splits_tsne(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, esm_type="esm2_t33_650M_UR50D", embedding_path="fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl", savepath='splits/tsne_plot.png', axes=None):
    """
    Generate a t-SNE plot of embeddings for train, test, and validation.
    """
    set_font()
    if colormap is None: colormap = default_color_map

    log_update('\nMaking t-SNE plot of train, val, and test embeddings')
    # Create a single standalone figure
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))

    # Draw on the standalone axes, and also on caller-provided axes if given
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    # Combine the embeddings into one array, labeled by split
    train_embeddings, val_embeddings, test_embeddings = get_avg_embeddings_for_tsne(train_sequences=train_sequences,
                                                                                    val_sequences=val_sequences,
                                                                                    test_sequences=test_sequences, embedding_path=embedding_path)
    if val_embeddings is not None and test_embeddings is not None:
        embeddings = np.concatenate([train_embeddings, val_embeddings, test_embeddings])
        labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings) + ['test'] * len(test_embeddings)
    if val_embeddings is not None and test_embeddings is None:
        embeddings = np.concatenate([train_embeddings, val_embeddings])
        labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings)
    if val_embeddings is None and test_embeddings is not None:
        embeddings = np.concatenate([train_embeddings, test_embeddings])
        labels = ['train'] * len(train_embeddings) + ['test'] * len(test_embeddings)

    # Perform t-SNE
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(embeddings)

    # Convert t-SNE results into a DataFrame
    tsne_df = pd.DataFrame(data=tsne_results, columns=['TSNE_1', 'TSNE_2'])
    tsne_df['label'] = labels

    for ax in axes_list:
        # Scatter plot for each set
        for label, color in colormap.items():
            subset = tsne_df[tsne_df['label'] == label].reset_index(drop=True)
            ax.scatter(subset['TSNE_1'], subset['TSNE_2'], c=color, label=label.capitalize(), alpha=0.6)

        ax.set_title(f't-SNE of {esm_type} Embeddings')
        ax.set_xlabel('t-SNE Dimension 1')
        ax.set_ylabel('t-SNE Dimension 2')
        ax.legend()
        ax.grid(True)

    # Save the figure
    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

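# The t-SNE plot relies on get_avg_embeddings_for_tsne (defined elsewhere in
# this module) and a pickle of precomputed, mean-pooled per-sequence
# embeddings. A minimal sketch of the assumed lookup, treating the pickle as a
# dict mapping sequence -> embedding vector (the layout is an assumption based
# on how the helper is called above, not confirmed by this file):
def _demo_embedding_lookup(embedding_path, sequences):
    import pickle
    with open(embedding_path, 'rb') as f:
        seq_to_emb = pickle.load(f)  # {sequence: average embedding, e.g. 1280-dim for ESM-2 650M}
    return np.stack([seq_to_emb[s] for s in sequences])
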
def visualize_splits_shannon_entropy(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, savepath='splits/shannon_entropy_plot.png', axes=None):
    """
    Generate Shannon entropy histograms for train, validation, and test sets.
    """
    set_font()
    # Work out which subplot each split occupies; drop panels for missing splits
    val_plot_index, test_plot_index, total_plots = 1, 2, 3
    if val_sequences is None:
        val_plot_index = None
        test_plot_index -= 1
        total_plots -= 1
    if test_sequences is None:
        test_plot_index = None
        total_plots -= 1

    if colormap is None: colormap = default_color_map
    # Create a figure with one row and one column per split
    fig_individual, axes_individual = plt.subplots(1, total_plots, figsize=(6*total_plots, 6))

    # Draw on the standalone axes, and also on caller-provided axes if given
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    log_update('\nMaking histogram of Shannon Entropy distributions')
    train_entropy = [calculate_shannon_entropy(seq) for seq in train_sequences]
    if val_plot_index is not None:
        val_entropy = [calculate_shannon_entropy(seq) for seq in val_sequences]
    if test_plot_index is not None:
        test_entropy = [calculate_shannon_entropy(seq) for seq in test_sequences]

    for ax in axes_list:
        ax[0].hist(train_entropy, bins=20, edgecolor='k', color=colormap['train'])
        ax[0].set_title(f'Train Set (n={len(train_entropy)})')
        ax[0].set_xlabel('Shannon Entropy')
        ax[0].set_ylabel('Frequency')
        ax[0].grid(True)
        ax[0].set_axisbelow(True)

        if val_plot_index is not None:
            ax[val_plot_index].hist(val_entropy, bins=20, edgecolor='k', color=colormap['val'])
            ax[val_plot_index].set_title(f'Validation Set (n={len(val_entropy)})')
            ax[val_plot_index].set_xlabel('Shannon Entropy')
            ax[val_plot_index].grid(True)
            ax[val_plot_index].set_axisbelow(True)

        if test_plot_index is not None:
            ax[test_plot_index].hist(test_entropy, bins=20, edgecolor='k', color=colormap['test'])
            ax[test_plot_index].set_title(f'Test Set (n={len(test_entropy)})')
            ax[test_plot_index].set_xlabel('Shannon Entropy')
            ax[test_plot_index].grid(True)
            ax[test_plot_index].set_axisbelow(True)

    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

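# The `axes` parameter lets each split plot double as a panel row of a larger
# combined figure (see visualize_train_val_test_splits below). A minimal
# sketch of that pattern with the entropy plot (illustrative, not part of the
# original file; sequences are made up):
def _demo_entropy_panel():
    os.makedirs('splits', exist_ok=True)
    fig, axs = plt.subplots(2, 3, figsize=(18, 12))
    visualize_splits_shannon_entropy(train_sequences=['MKTAYIAKQR'] * 5,
                                     val_sequences=['MEEPQSDPSV'] * 5,
                                     test_sequences=['MVLSPADKTN'] * 5,
                                     axes=axs[0])  # fills the top row of 3 panels
    fig.savefig('splits/entropy_panel_demo.png')
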
def visualize_splits_aa_composition(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, savepath='splits/aa_comp.png', axes=None):
    set_font()
    if colormap is None: colormap = default_color_map

    # Create a single standalone figure
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))

    # Draw on the standalone axes, and also on caller-provided axes if given
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    log_update('\nMaking bar plot of AA composition across each set')
    train_comp = calculate_aa_composition(train_sequences)
    if val_sequences is not None:
        val_comp = calculate_aa_composition(val_sequences)
    if test_sequences is not None:
        test_comp = calculate_aa_composition(test_sequences)

    # Create a DataFrame with one column per split
    if val_sequences is not None and test_sequences is not None:
        comp_df = pd.DataFrame([train_comp, val_comp, test_comp], index=['train', 'val', 'test']).T
    if val_sequences is not None and test_sequences is None:
        comp_df = pd.DataFrame([train_comp, val_comp], index=['train', 'val']).T
    if val_sequences is None and test_sequences is not None:
        comp_df = pd.DataFrame([train_comp, test_comp], index=['train', 'test']).T
    colors = [colormap[col] for col in comp_df.columns]

    # Plotting
    for ax in axes_list:
        comp_df.plot(kind='bar', color=colors, ax=ax)
        ax.set_title('Amino Acid Composition Across Datasets')
        ax.set_xlabel('Amino Acid')
        ax.set_ylabel('Relative Frequency')

    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

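# Example usage of visualize_splits_aa_composition (illustrative sketch, not
# part of the original file). calculate_aa_composition lives elsewhere in this
# module; based on how its output is fed to pd.DataFrame above, it is assumed
# to return a dict mapping each amino acid to its relative frequency across
# the given sequences, e.g. {'A': 0.5, 'C': 0.5} for ["AAC", "C"].
def _demo_aa_composition_plot():
    os.makedirs('splits', exist_ok=True)
    visualize_splits_aa_composition(train_sequences=['MKTAYIAKQR', 'MEEPQSDPSV'],
                                    test_sequences=['MVLSPADKTN'])
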
### Outer methods for visualizing splits
def visualize_splits(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None, train_color='#0072B2', val_color='#009E73', test_color='#E69F00', esm_embeddings_path=None, onehot_embeddings_path=None):
    colormap = {
        'train': train_color,
        'val': val_color,
        'test': test_color
    }
    valid_entry = False
    # Dispatch to the combined-figure builder that matches the provided splits,
    # forwarding the embedding paths so the t-SNE panel can be drawn
    if train_clusters is not None and val_clusters is not None and test_clusters is not None:
        visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path, onehot_embeddings_path=onehot_embeddings_path)
        valid_entry = True
    if train_clusters is not None and val_clusters is None and test_clusters is not None:
        visualize_train_test_splits(train_clusters, test_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path, onehot_embeddings_path=onehot_embeddings_path)
        valid_entry = True
    if train_clusters is not None and val_clusters is not None and test_clusters is None:
        visualize_train_val_splits(train_clusters, val_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path, onehot_embeddings_path=onehot_embeddings_path)
        valid_entry = True

    if not valid_entry: raise ValueError("Must pass train clusters and at least one of val or test clusters")

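# Dispatch sketch (illustrative, not part of the original file): visualize_splits
# picks the combined layout from whichever cluster DataFrames are provided, e.g.
#
#   visualize_splits(train_clusters=train_df, val_clusters=val_df, test_clusters=test_df)
#   visualize_splits(train_clusters=train_df, test_clusters=test_df)  # no val set
#   visualize_splits(train_clusters=train_df)  # raises ValueError: need val or test
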
def visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
    if colormap is None: colormap = default_color_map
    # Add length column
    train_clusters['member length'] = train_clusters['member seq'].str.len()
    val_clusters['member length'] = val_clusters['member seq'].str.len()
    test_clusters['member length'] = test_clusters['member seq'].str.len()

    # Prepare lengths and seqs for plotting
    train_lengths = train_clusters['member length'].tolist()
    val_lengths = val_clusters['member length'].tolist()
    test_lengths = test_clusters['member length'].tolist()
    train_sequences = train_clusters['member seq'].tolist()
    val_sequences = val_clusters['member seq'].tolist()
    test_sequences = test_clusters['member seq'].tolist()

    # Create a combined figure with 3 rows and 3 columns
    set_font()
    fig_combined, axs = plt.subplots(3, 3, figsize=(24, 18))

    # Draw each visualization onto its panel(s) of the combined figure
    visualize_splits_hist(train_lengths=train_lengths,
                          val_lengths=val_lengths,
                          test_lengths=test_lengths,
                          colormap=colormap, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences=train_sequences,
                                     val_sequences=val_sequences,
                                     test_sequences=test_sequences,
                                     colormap=colormap, axes=axs[1])
    visualize_splits_scatter(train_clusters=train_clusters,
                             val_clusters=val_clusters,
                             test_clusters=test_clusters,
                             benchmark_cluster_reps=benchmark_cluster_reps,
                             colormap=colormap, axes=axs[2, 0])
    visualize_splits_aa_composition(train_sequences=train_sequences,
                                    val_sequences=val_sequences,
                                    test_sequences=test_sequences,
                                    colormap=colormap, axes=axs[2, 1])
    if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=val_sequences,
                              test_sequences=test_sequences,
                              colormap=colormap, embedding_path=esm_embeddings_path, axes=axs[2, 2])
    else:
        # Leave the last subplot blank
        axs[2, 2].axis('off')

    plt.tight_layout()
    fig_combined.savefig('splits/combined_plot.png')
    log_update("\nSaved combined figure to splits/combined_plot.png")

def visualize_train_test_splits(train_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
    if colormap is None: colormap = default_color_map
    # Add length column
    train_clusters['member length'] = train_clusters['member seq'].str.len()
    test_clusters['member length'] = test_clusters['member seq'].str.len()

    # Prepare lengths and seqs for plotting
    train_lengths = train_clusters['member length'].tolist()
    test_lengths = test_clusters['member length'].tolist()
    train_sequences = train_clusters['member seq'].tolist()
    test_sequences = test_clusters['member seq'].tolist()

    # Create a combined figure: 4 rows x 2 columns if there is a t-SNE plot, 3 x 2 otherwise
    if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
        set_font()
        fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=None,
                              test_sequences=test_sequences,
                              colormap=colormap, embedding_path=esm_embeddings_path, axes=axs[3, 0])
        axs[-1, 1].axis('off')
    else:
        set_font()
        fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))

    # Draw the remaining visualizations onto their panels of the combined figure
    visualize_splits_hist(train_lengths=train_lengths,
                          val_lengths=None,
                          test_lengths=test_lengths,
                          colormap=colormap, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences=train_sequences,
                                     val_sequences=None,
                                     test_sequences=test_sequences,
                                     colormap=colormap, axes=axs[1])
    visualize_splits_scatter(train_clusters=train_clusters,
                             val_clusters=None,
                             test_clusters=test_clusters,
                             benchmark_cluster_reps=benchmark_cluster_reps,
                             colormap=colormap, axes=axs[2, 0])
    visualize_splits_aa_composition(train_sequences=train_sequences,
                                    val_sequences=None,
                                    test_sequences=test_sequences,
                                    colormap=colormap, axes=axs[2, 1])

    plt.tight_layout()
    fig_combined.savefig('splits/combined_plot.png')
    log_update("\nSaved combined figure to splits/combined_plot.png")

def visualize_train_val_splits(train_clusters, val_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
    if colormap is None: colormap = default_color_map
    # Add length column
    train_clusters['member length'] = train_clusters['member seq'].str.len()
    val_clusters['member length'] = val_clusters['member seq'].str.len()

    # Prepare lengths and seqs for plotting
    train_lengths = train_clusters['member length'].tolist()
    val_lengths = val_clusters['member length'].tolist()
    train_sequences = train_clusters['member seq'].tolist()
    val_sequences = val_clusters['member seq'].tolist()

    # Create a combined figure: 4 rows x 2 columns if there is a t-SNE plot, 3 x 2 otherwise
    if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
        set_font()
        fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=val_sequences,
                              test_sequences=None,
                              colormap=colormap, embedding_path=esm_embeddings_path, axes=axs[3, 0])
        axs[-1, 1].axis('off')
    else:
        set_font()
        fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))

    # Draw the remaining visualizations onto their panels of the combined figure
    visualize_splits_hist(train_lengths=train_lengths,
                          val_lengths=val_lengths,
                          test_lengths=None,
                          colormap=colormap, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences=train_sequences,
                                     val_sequences=val_sequences,
                                     test_sequences=None,
                                     colormap=colormap, axes=axs[1])
    visualize_splits_scatter(train_clusters=train_clusters,
                             val_clusters=val_clusters,
                             test_clusters=None,
                             benchmark_cluster_reps=benchmark_cluster_reps,
                             colormap=colormap, axes=axs[2, 0])
    visualize_splits_aa_composition(train_sequences=train_sequences,
                                    val_sequences=val_sequences,
                                    test_sequences=None,
                                    colormap=colormap, axes=axs[2, 1])

    plt.tight_layout()
    fig_combined.savefig('splits/combined_plot.png')
    log_update("\nSaved combined figure to splits/combined_plot.png")
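
# End-to-end sketch (illustrative, not part of the original file): build the
# combined train/val/test figure from tiny hand-made cluster DataFrames. The
# column names ('representative seq_id', 'member seq_id', 'member seq') follow
# the cluster DataFrames used throughout this module; the sequences are made up.
def _demo_visualize_splits():
    os.makedirs('splits', exist_ok=True)  # default savepaths live under splits/
    train_df = pd.DataFrame({'representative seq_id': ['A', 'A', 'B'],
                             'member seq_id': ['A', 'a1', 'B'],
                             'member seq': ['MKTAYIAKQR', 'MKTAYIAKQK', 'MEEPQSDPSV']})
    val_df = pd.DataFrame({'representative seq_id': ['C'],
                           'member seq_id': ['C'],
                           'member seq': ['MVLSPADKTN']})
    test_df = pd.DataFrame({'representative seq_id': ['D', 'D'],
                            'member seq_id': ['D', 'd1'],
                            'member seq': ['MNIFEMLRID', 'MNIFEMLRIE']})
    visualize_splits(train_clusters=train_df, val_clusters=val_df, test_clusters=test_df)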