import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from umap import UMAP
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from transformers import AutoModel, AutoTokenizer

path = "/workspace/sg666/MDpLM/benchmarks/Generation"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
esm_model_path = "facebook/esm2_t33_650M_UR50D"

# Load the ESM-2 model and tokenizer used to embed the sequences
def load_esm2_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    return tokenizer, model

# Embed a single sequence by mean-pooling the final hidden states
# (the pool includes the special CLS/EOS tokens)
def get_latents(model, tokenizer, sequence):
    inputs = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy().tolist()
    return embeddings

# Load a random set of 100 human and reviewed sequences from UniProt
def parse_fasta_file(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()

    sequences = []
    current_seq = []
    current_type = "UniProt"

    for line in lines:
        line = line.strip()
        if line.startswith('>'):
            if current_seq:
                sequences.append(("".join(current_seq), current_type))
                current_seq = []
        else:
            current_seq.append(line)
    if current_seq:
        sequences.append(("".join(current_seq), current_type))

    return pd.DataFrame(sequences, columns=["Sequence", "Sequence Source"]).sample(100).reset_index(drop=True)

# Obtain/clean sequences generated from ProtGPT2 fine-tuned on membrane sequences
protgpt2_sequences = pd.read_csv(path + "/ProtGPT2/protgpt2_generated_sequences.csv")
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('<|ENDOFTEXT|>', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('""', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('\n', '', regex=False)
protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace('X', 'G', regex=False)
protgpt2_sequences.drop(columns=['Perplexity'], inplace=True)
protgpt2_sequences['Sequence Source'] = "ProtGPT2"

# Drop generated sequences containing non-standard residues
nonstandard_residues = {'B', 'U', 'Z', 'O'}
protgpt2_sequences = protgpt2_sequences[
    ~protgpt2_sequences['Sequence'].apply(lambda seq: any(res in nonstandard_residues for res in seq))
]

# Load MeMDLM (MDpLM) generated sequences
memdlm_sequences = pd.read_csv(path + "/mdlm_de-novo_generation_results.csv")
memdlm_sequences.rename(columns={"Generated Sequence": "Sequence"}, inplace=True)
memdlm_sequences.drop(columns=['Perplexity'], inplace=True)
memdlm_sequences['Sequence Source'] = "MeMDLM"
memdlm_sequences.reset_index(drop=True, inplace=True)

# Load UniProt sequences
# fasta_file_path = path + "/uniprot_human_and_reviewed.fasta"
# other_sequences = parse_fasta_file(fasta_file_path)

# Load test set sequences
other_sequences = pd.read_csv("/workspace/sg666/MDpLM/data/membrane/test.csv")
other_sequences['Sequence Source'] = "Test Set"
other_sequences = other_sequences.sample(100)

# Combine all sequences
data = pd.concat([memdlm_sequences, protgpt2_sequences, other_sequences])

# Load ESM model and tokenizer for embeddings (already moved to the device)
tokenizer, model = load_esm2_model(esm_model_path)

# Embed the sequences
data['Embeddings'] = data['Sequence'].apply(lambda sequence: get_latents(model, tokenizer, sequence))
data = data.reset_index(drop=True)

umap_df = pd.DataFrame(data['Embeddings'].tolist())
umap_df.index = data['Sequence Source']
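# Note: the apply() above embeds sequences one at a time, which is slow for
# large sets. A batched variant is sketched below; it is not part of the
# original pipeline, and the batch_size and max_length=1024 truncation are
# assumptions. Unlike get_latents(), it masks out padding before mean-pooling.
def get_latents_batched(model, tokenizer, sequences, batch_size=8):
    all_embeddings = []
    for i in range(0, len(sequences), batch_size):
        batch = sequences[i:i + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True,
                           truncation=True, max_length=1024).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Mask-aware mean pooling so padding positions do not dilute the average
        mask = inputs['attention_mask'].unsqueeze(-1)
        pooled = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
        all_embeddings.extend(pooled.cpu().numpy().tolist())
    return all_embeddings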
# Project the ESM embeddings to 2D with UMAP
umap = UMAP(n_components=2)
umap_features = umap.fit_transform(umap_df)
umap_df['UMAP1'] = umap_features[:, 0]
umap_df['UMAP2'] = umap_features[:, 1]
umap_df = umap_df.reset_index()  # expose 'Sequence Source' as a regular column for hue

# Visualize the UMAP projection
plt.figure(figsize=(8, 5), dpi=300)
sns.scatterplot(x='UMAP1', y='UMAP2', hue='Sequence Source', data=umap_df,
                palette=['#297272', '#ff7477', '#9A77D0'], s=100)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.title('ESM-650M Embeddings of Membrane Protein Sequences')
plt.savefig('esm_umap.png')
plt.show()
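# The PCA and t-SNE imports above are otherwise unused; for comparison, the
# same embeddings can also be projected with PCA. A minimal sketch (the output
# file name 'esm_pca.png' is an assumption, not part of the original script):
pca = PCA(n_components=2)
pca_features = pca.fit_transform(pd.DataFrame(data['Embeddings'].tolist()))
pca_df = pd.DataFrame({'PCA1': pca_features[:, 0],
                       'PCA2': pca_features[:, 1],
                       'Sequence Source': data['Sequence Source'].values})

plt.figure(figsize=(8, 5), dpi=300)
sns.scatterplot(x='PCA1', y='PCA2', hue='Sequence Source', data=pca_df,
                palette=['#297272', '#ff7477', '#9A77D0'], s=100)
plt.xlabel('PCA1')
plt.ylabel('PCA2')
plt.title('PCA of ESM-650M Embeddings of Membrane Protein Sequences')
plt.savefig('esm_pca.png')
plt.show()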