import gradio as gr
import torch
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp

from mammal.keys import (
    CLS_PRED,
    ENCODER_INPUTS_ATTENTION_MASK,
    ENCODER_INPUTS_STR,
    ENCODER_INPUTS_TOKENS,
    SCORES,
)
from mammal.model import Mammal

from mammal_demo.demo_framework import MammalObjectBroker, MammalTask


class TcrTask(MammalTask):
    def __init__(self, model_dict):
        super().__init__(
            name="TCRbeta-epitope binding affinity", model_dict=model_dict
        )
        self.description = "TCRbeta-epitope binding affinity (TCR)"
        self.examples = {
            "tcr_beta_seq": "NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT",
            "epitope_seq": "LLQTGIHVRVSQPSL",
        }
        self.markup_text = """
# Mammal based TCRbeta-epitope binding affinity demonstration

Given TCR beta chain and epitope amino acid sequences, estimate the binding affinity score.
"""

    def generate_prompt(self, tcr_beta_seq, epitope_seq):
        """Build the MAMMAL encoder prompt for a TCR beta chain / epitope pair."""
        prompt = (
            "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
            + f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_TCR_BETA_VDJ><SEQUENCE_NATURAL_START>{tcr_beta_seq}<SEQUENCE_NATURAL_END>"
            + f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_EPITOPE><SEQUENCE_NATURAL_START>{epitope_seq}<SEQUENCE_NATURAL_END><EOS>"
        )
        return prompt
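
    # For the example inputs above, generate_prompt yields a single string of
    # the form (wrapped here for readability, sequences elided):
    #   <@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>
    #   <@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_TCR_BETA_VDJ><SEQUENCE_NATURAL_START>NAGVTQ...TRLTVT<SEQUENCE_NATURAL_END>
    #   <@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_EPITOPE><SEQUENCE_NATURAL_START>LLQTGIHVRVSQPSL<SEQUENCE_NATURAL_END><EOS>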

    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
        """Convert sample_inputs into a sample_dict, including building a proper prompt.

        Args:
            sample_inputs (dict): dictionary containing the inputs to the model
            model_holder (MammalObjectBroker): holder of the model and its tokenizer

        Returns:
            dict: sample_dict ready for feeding into the model
        """
        sample_dict = dict()
        prompt = self.generate_prompt(**sample_inputs)
        sample_dict[ENCODER_INPUTS_STR] = prompt

        # Tokenize the prompt string into token ids and an attention mask,
        # then convert both to tensors as the model expects.
        sample_dict = model_holder.tokenizer_op(
            sample_dict=sample_dict,
            key_in=ENCODER_INPUTS_STR,
            key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
            key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
        )
        sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
            sample_dict[ENCODER_INPUTS_TOKENS]
        )
        sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
            sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
        )

        return sample_dict
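
    # At this point sample_dict carries three entries: the raw prompt string
    # (ENCODER_INPUTS_STR), the token-id tensor (ENCODER_INPUTS_TOKENS), and
    # the attention-mask tensor (ENCODER_INPUTS_ATTENTION_MASK).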

    def run_model(self, sample_dict, model: Mammal):
        # Generate from the encoder prompt; output_scores=True is required so
        # that decode_output can turn token scores into a binding score.
        batch_dict = model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )
        return batch_dict
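
    # The returned batch_dict exposes the generated token ids under CLS_PRED
    # and the per-position token scores under SCORES; both keys are consumed
    # by decode_output below.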

    def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
        """Extract the predicted class and its score from the generation output."""

        positive_token_id = self.positive_token_id(tokenizer_op)
        negative_token_id = self.negative_token_id(tokenizer_op)

        label_id_to_name = {
            negative_token_id: "negative",
            positive_token_id: "positive",
        }
        # The class token is generated at position 1 of the decoder output.
        classification_position = 1

        generated_output = tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])

        decoder_output = batch_dict[CLS_PRED][0]
        decoder_output_scores = batch_dict[SCORES][0]

        # The binding score is the score assigned to the positive class token
        # at the classification position.
        if decoder_output_scores is not None:
            score = decoder_output_scores[
                classification_position, positive_token_id
            ].item()
        else:
            score = None

        ans = [
            generated_output,
            label_id_to_name.get(int(decoder_output[classification_position]), -1),
            score,
        ]
        return ans
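
    # Illustrative return value (made-up numbers, not real model output):
    #   ["<decoded generated tokens>", "positive", 0.93]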

    def create_and_run_prompt(self, model_name, tcr_beta_seq, epitope_seq):
        """Build the prompt, run the selected model, and decode its output."""
        model_holder = self.model_dict[model_name]
        inputs = {
            "tcr_beta_seq": tcr_beta_seq,
            "epitope_seq": epitope_seq,
        }
        sample_dict = self.crate_sample_dict(
            sample_inputs=inputs, model_holder=model_holder
        )
        prompt = sample_dict[ENCODER_INPUTS_STR]
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        res = prompt, *self.decode_output(
            batch_dict, tokenizer_op=model_holder.tokenizer_op
        )
        return res

    def create_demo(self, model_name_widget):
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                tcr_textbox = gr.Textbox(
                    label="T-cell receptor beta sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["tcr_beta_seq"],
                )
                epitope_textbox = gr.Textbox(
                    label="Epitope sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["epitope_seq"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for TCR-epitope interaction",
                    variant="primary",
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded = gr.Textbox(label="Mammal output")
                predicted_class = gr.Textbox(label="Mammal prediction")
                binding_score = gr.Number(label="Binding score")
            run_mammal.click(
                fn=self.create_and_run_prompt,
                inputs=[model_name_widget, tcr_textbox, epitope_textbox],
                outputs=[prompt_box, decoded, predicted_class, binding_score],
            )
        # Start hidden; the hosting app is expected to toggle visibility.
        demo.visible = False
        return demo
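

# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal way to mount this task in a standalone Gradio app. It assumes that
# MammalObjectBroker can be constructed from a model path and that the
# checkpoint name below exists; both are assumptions, not confirmed API facts.
if __name__ == "__main__":
    model_path = "ibm/biomed.omics.bl.sm.ma-ted-458m"  # assumed checkpoint name
    model_dict = {model_path: MammalObjectBroker(model_path=model_path)}  # assumed constructor
    task = TcrTask(model_dict=model_dict)
    with gr.Blocks() as app:
        model_name_widget = gr.Dropdown(
            choices=list(model_dict.keys()),
            value=model_path,
            label="Model",
        )
        demo_group = task.create_demo(model_name_widget=model_name_widget)
        demo_group.visible = True  # create_demo hides the group by default
    app.launch()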