# ref: https://huggingface.co/blog/AmelieSchreiber/esmbind
import gradio as gr
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# import wandb
import numpy as np
import torch
import torch.nn as nn
import pickle
import xml.etree.ElementTree as ET
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    matthews_corrcoef,
)
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
from datasets import Dataset
from accelerate import Accelerator
# Imports specific to the custom PEFT LoRA model
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
from plot_pdb import plot_struc
def suggest(option):
    if option == "Plastic degradation protein":
        suggestion = "MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ"
    elif option == "Default protein":
        # suggestion = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
        suggestion = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"
    elif option == "Antifreeze protein":
        suggestion = "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH"
    elif option == "AI Generated protein":
        suggestion = "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS"
    elif option == "7-bladed propeller fold":
        suggestion = "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK"
    else:
        suggestion = ""
    return suggestion
# Helper Functions and Data Preparation
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]
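# Example: with max_length=1000 (the max_sequence_length set below), a per-residue label list
# longer than 1000 entries is cut down to its first 1000 labels.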
def compute_metrics(p):
    """Compute metrics for evaluation."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    # Remove padding (-100 labels)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()
    # Compute accuracy
    accuracy = accuracy_score(labels, predictions)
    # Compute precision, recall, F1 score, and AUC
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    # Compute MCC
    mcc = matthews_corrcoef(labels, predictions)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
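# Note: AUC is computed here from the hard 0/1 argmax predictions rather than class probabilities,
# so it is a coarser estimate than a probability-based ROC AUC would be.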
def compute_loss(model, inputs):
    """Custom compute_loss function."""
    logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss
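# class_weights is the module-level tensor computed below from the flattened training labels,
# so the minority "Binding site" class contributes more to the loss than non-binding residues.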
# Define Custom Trainer Class
# Because the data are imbalanced (far more non-binding than binding residues), we weight the
# loss with class weights, which requires a custom weighted trainer.
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = compute_loss(model, inputs)
        return (loss, outputs) if return_outputs else loss
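# WeightedTrainer is instantiated in train_function_no_sweeps() below in place of the stock Trainer.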
# Predict binding sites with a finetuned PEFT model
def predict_bind(base_model_path, PEFT_model_path, input_seq):
    # Load the base model and apply the LoRA adapter
    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
    loaded_model = PeftModel.from_pretrained(base_model, PEFT_model_path)
    # Ensure the model is in evaluation mode
    loaded_model.eval()
    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    # Tokenize the sequence
    inputs = tokenizer(input_seq, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
    # Run the model
    with torch.no_grad():
        logits = loaded_model(**inputs).logits
    # Get predictions
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
    predictions = torch.argmax(logits, dim=2)
    binding_site = []
    pos = 0
    # Print the predicted label for each residue and collect predicted binding sites
    for token, prediction in zip(tokens, predictions[0].numpy()):
        if token not in ['<pad>', '<cls>', '<eos>']:
            pos += 1
            print((pos, token, id2label[prediction]))
            if prediction == 1:
                binding_site.append([pos, token, id2label[prediction]])
    return binding_site
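# predict_bind returns a list of [position, residue_token, label] entries for residues predicted
# as binding sites, e.g. [[12, 'K', 'Binding site'], ...] (positions/residues here are illustrative).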
# Fine-tuning function
def train_function_no_sweeps(base_model_path):  # , train_dataset, test_dataset):
    # Set the LoRA config
    config = {
        "lora_alpha": 1,  # try 0.5, 1, 2, ..., 16
        "lora_dropout": 0.2,
        "lr": 5.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 1,  # 3, jw 20240628
        "per_device_train_batch_size": 12,
        "r": 2,
        "weight_decay": 0.2,
        # Add other hyperparameters as needed
    }
    base_model = AutoModelForTokenClassification.from_pretrained(
        base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id
    )
    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)  # ("facebook/esm2_t12_35M_UR50D")
    train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
    test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
    train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
    test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
    # Convert the model into a PeftModel
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=["query", "key", "value"],  # also try "dense_h_to_4h" and "dense_4h_to_h"
        lora_dropout=config["lora_dropout"],
        bias="none",  # or "all" or "lora_only"
    )
    base_model = get_peft_model(base_model, peft_config)
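    # Optionally inspect the adapter size here, e.g. base_model.print_trainable_parameters(),
    # which reports the trainable LoRA parameters versus the frozen base-model parameters.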
    # Use the accelerator
    base_model = accelerator.prepare(base_model)
    train_dataset = accelerator.prepare(train_dataset)
    test_dataset = accelerator.prepare(test_dataset)
    model_name_base = base_model_path.split("/")[1]
    timestamp = datetime.now().strftime('%Y-%m-%d_%H')
    save_path = f"{model_name_base}-lora-binding-sites_{timestamp}"
    # Training setup
    training_args = TrainingArguments(
        output_dir=save_path,  # f"{model_name_base}-lora-binding-sites_{timestamp}"
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=1,
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=True,  # jw 20240701 False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=200,
        save_total_limit=7,
        no_cuda=False,
        seed=8893,
        fp16=True,
        # report_to='wandb'
        report_to=None,
        hub_token=HF_TOKEN,  # jw 20240701
    )
    # Initialize Trainer
    trainer = WeightedTrainer(
        model=base_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )
    # Train and Save Model
    trainer.train()
    return save_path
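# The returned save_path is the local checkpoint directory (and, since push_to_hub=True, by default
# also the Hub repo name); the "Finetune Pre-trained Model" button displays it in the Output box.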
# Constants & Globals
HF_TOKEN = os.environ.get("HF_token")
print("HF_TOKEN:", HF_TOKEN)
MODEL_OPTIONS = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]  # base models users can choose from
PEFT_MODEL_OPTIONS = [
    "wangjin2000/esm2_t6_8M-lora-binding-sites_2024-07-02_09-26-54",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3",
]  # finetuned models (LoRA adapters)
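# Judging by their names, these adapters correspond to the esm2_t6_8M and esm2_t12_35M base models
# above, respectively; they are applied to the chosen base model via PeftModel.from_pretrained.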
# Load the data from pickle files (replace with your local paths)
with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("./datasets/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)
max_sequence_length = 1000
# Directly truncate the entire list of labels
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)
# Compute Class Weights
classes = [0, 1]
flat_train_labels = [label for sublist in train_labels for label in sublist]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
accelerator = Accelerator()
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
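# With class_weight='balanced', sklearn assigns each class a weight of
# n_samples / (n_classes * count(class)), so the rarer "Binding site" class gets the larger weight
# in the weighted cross-entropy loss defined in compute_loss above.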
# Define labels and model
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}
'''
# debug result
dubug_result = saved_path  # predictions # class_weights
'''
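# Quick sanity check (not executed by the app; the names below are the defaults defined above):
#   sites = predict_bind(MODEL_OPTIONS[0], PEFT_MODEL_OPTIONS[0], suggest("Default protein"))
#   print(sites)  # list of [position, residue, "Binding site"] entries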
demo = gr.Blocks(title="DEMO FOR ESM2Bind")
with demo:
    gr.Markdown("# DEMO FOR ESM2Bind")
    # gr.Textbox(dubug_result)
    with gr.Column():
        gr.Markdown("## Select a base model and a corresponding PEFT finetuned model")
        with gr.Row():
            with gr.Column(scale=5, variant="compact"):
                base_model_name = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=MODEL_OPTIONS[0],
                    label="Base Model Name",
                    interactive=True,
                )
                PEFT_model_name = gr.Dropdown(
                    choices=PEFT_MODEL_OPTIONS,
                    value=PEFT_MODEL_OPTIONS[0],
                    label="PEFT Model Name",
                    interactive=True,
                )
            with gr.Column(scale=5, variant="compact"):
                name = gr.Dropdown(
                    label="Choose a Sample Protein",
                    value="Default protein",
                    choices=["Default protein", "Antifreeze protein", "Plastic degradation protein", "AI Generated protein", "7-bladed propeller fold", "custom"],
                )
    gr.Markdown(
        "## Predict binding sites and plot the structure for the selected protein sequence:"
    )
    with gr.Row():
        with gr.Column(variant="compact", scale=8):
            input_seq = gr.Textbox(
                lines=1,
                max_lines=12,
                label="Protein sequence to be predicted:",
                value="MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT",
                placeholder="Paste your protein sequence here...",
                interactive=True,
            )
            text_pos = gr.Textbox(
                lines=1,
                max_lines=12,
                label="Sequence Position:",
                placeholder="012345678911234567892123456789312345678941234567895123456789612345678971234567898123456789912345678901234567891123456789",
                interactive=False,
            )
        with gr.Column(variant="compact", scale=2):
            predict_btn = gr.Button(
                value="Predict binding site",
                interactive=True,
                variant="primary",
            )
            plot_struc_btn = gr.Button(value="Plot ESMFold Predicted Structure", variant="primary")
    with gr.Row():
        with gr.Column(variant="compact", scale=5):
            output_text = gr.Textbox(
                lines=1,
                max_lines=12,
                label="Output",
                placeholder="Output",
            )
        with gr.Column(variant="compact", scale=5):
            finetune_button = gr.Button(
                value="Finetune Pre-trained Model",
                interactive=True,
                variant="primary",
            )
    with gr.Row():
        output_viewer = gr.HTML()
        output_file = gr.File(
            label="Download as Text File",
            file_count="single",
            type="filepath",
            interactive=False,
        )
    # Select a protein sample
    name.change(fn=suggest, inputs=name, outputs=input_seq)
    # "Predict binding site" actions
    predict_btn.click(
        fn=predict_bind,
        inputs=[base_model_name, PEFT_model_name, input_seq],
        outputs=[output_text],
    )
    # "Finetune Pre-trained Model" actions
    finetune_button.click(
        fn=train_function_no_sweeps,
        inputs=[base_model_name],
        outputs=[output_text],
    )
    # Plot the predicted protein structure
    plot_struc_btn.click(fn=plot_struc, inputs=input_seq, outputs=[output_file, output_viewer])

demo.launch()