import gradio as gr
import torch
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.keys import (
    CLS_PRED,
    ENCODER_INPUTS_ATTENTION_MASK,
    ENCODER_INPUTS_STR,
    ENCODER_INPUTS_TOKENS,
    SCORES,
)
from mammal.model import Mammal
from mammal_demo.demo_framework import MammalObjectBroker, MammalTask


class TcrTask(MammalTask):
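    """Demo task: TCRbeta-epitope binding affinity prediction with a MAMMAL model."""
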
    def __init__(self, model_dict):
        super().__init__(
            name="TCRbeta-epitope binding affinity", model_dict=model_dict
        )
        self.description = "TCRbeta-epitope binding affinity (TCR)"
        self.examples = {
            "tcr_beta_seq": "NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT",
            "epitope_seq": "LLQTGIHVRVSQPSL",
        }
        self.markup_text = """
# MAMMAL-based TCRbeta-epitope binding affinity demonstration

Given TCR beta chain and epitope amino acid sequences, estimate the binding affinity score.
"""

    def generate_prompt(self, tcr_beta_seq, epitope_seq):
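        """Build the MAMMAL modular-tokenizer prompt for a TCRbeta/epitope pair.

        The <@TOKENIZER-TYPE=AA> directives select amino-acid tokenization for
        the spans that follow them, and <SENTINEL_ID_0> marks the slot the
        decoder fills with the predicted binding-class token.
        """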
        prompt = (
            "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
            + f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_TCR_BETA_VDJ><SEQUENCE_NATURAL_START>{tcr_beta_seq}<SEQUENCE_NATURAL_END>"
            + f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_EPITOPE><SEQUENCE_NATURAL_START>{epitope_seq}<SEQUENCE_NATURAL_END><EOS>"
        )
        return prompt

    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
        """Convert sample_inputs into a sample_dict, building and tokenizing the prompt.

        Args:
            sample_inputs (dict): dictionary containing the inputs to the model
            model_holder (MammalObjectBroker): holder of the model and its tokenizer

        Returns:
            dict: sample_dict ready for feeding into the model
        """
        sample_dict = dict()
        prompt = self.generate_prompt(**sample_inputs)
        sample_dict[ENCODER_INPUTS_STR] = prompt

        # Tokenize the prompt into token ids and an attention mask
        sample_dict = model_holder.tokenizer_op(
            sample_dict=sample_dict,
            key_in=ENCODER_INPUTS_STR,
            key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
            key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
        )
        # The tokenizer op emits plain Python lists; convert them to tensors
        sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
            sample_dict[ENCODER_INPUTS_TOKENS]
        )
        sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
            sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
        )
        return sample_dict

    def run_model(self, sample_dict, model: Mammal):
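        """Run generation on a single-sample batch and return the raw batch_dict."""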
        # Generate prediction; output_scores=True retains the per-step token
        # scores that decode_output uses to compute the binding score
        batch_dict = model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )
        return batch_dict

    def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
        """Extract the predicted class and the positive-class score from batch_dict."""
        positive_token_id = self.positive_token_id(tokenizer_op)
        negative_token_id = self.negative_token_id(tokenizer_op)
        # The class tokens are "<0>" (negative) and "<1>" (positive)
        label_id_to_str = {
            negative_token_id: "negative",
            positive_token_id: "positive",
        }
        # The class token is expected at position 1 of the decoded sequence
        # (position 0 is typically the decoder start token)
        classification_position = 1

        # Get output
        generated_output = tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
        decoder_output = batch_dict[CLS_PRED][0]
        decoder_output_scores = batch_dict[SCORES][0]
        if decoder_output_scores is not None:
            # scores[step, token_id]: score assigned to the positive class token
            score = decoder_output_scores[
                classification_position, positive_token_id
            ].item()
        else:
            score = None
        ans = [
            generated_output,
            label_id_to_str.get(int(decoder_output[classification_position]), -1),
            score,
        ]
        return ans

    def create_and_run_prompt(self, model_name, tcr_beta_seq, epitope_seq):
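        """Gradio callback used by create_demo.

        Returns (prompt, decoded output, predicted class, binding score),
        matching the four output widgets wired to the Run button.
        """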
        model_holder = self.model_dict[model_name]
        inputs = {
            "tcr_beta_seq": tcr_beta_seq,
            "epitope_seq": epitope_seq,
        }
        sample_dict = self.crate_sample_dict(
            sample_inputs=inputs, model_holder=model_holder
        )
        prompt = sample_dict[ENCODER_INPUTS_STR]
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        res = prompt, *self.decode_output(
            batch_dict, tokenizer_op=model_holder.tokenizer_op
        )
        return res

    def create_demo(self, model_name_widget):
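        """Build the Gradio UI group for this task.

        The group is created hidden; the surrounding demo framework is expected
        to toggle its visibility when the task is selected.
        """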
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                tcr_textbox = gr.Textbox(
                    label="T-cell receptor beta sequence",
                    # info="standard",
                    interactive=True,
                    lines=3,
                    value=self.examples["tcr_beta_seq"],
                )
                epitope_textbox = gr.Textbox(
                    label="Epitope sequence",
                    # info="standard",
                    interactive=True,
                    lines=3,
                    value=self.examples["epitope_seq"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for TCR-Epitope interaction",
                    variant="primary",
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
            with gr.Row():
                decoded = gr.Textbox(label="Mammal output")
                predicted_class = gr.Textbox(label="Mammal prediction")
                binding_score = gr.Number(label="Binding score")
            run_mammal.click(
                fn=self.create_and_run_prompt,
                inputs=[model_name_widget, tcr_textbox, epitope_textbox],
                outputs=[prompt_box, decoded, predicted_class, binding_score],
            )
        demo.visible = False
        return demo
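

# ---------------------------------------------------------------------------
# Minimal standalone usage sketch (not part of the demo framework). The model
# path and the MammalObjectBroker constructor argument below are assumptions;
# check mammal_demo.demo_framework for the actual signature.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_path = "ibm/biomed.omics.bl.sm.ma-ted-458m"  # assumed checkpoint id
    model_dict = {model_path: MammalObjectBroker(model_path=model_path)}  # assumed kwarg

    task = TcrTask(model_dict=model_dict)
    with gr.Blocks() as app:
        model_name_widget = gr.Dropdown(
            choices=list(model_dict.keys()),
            value=model_path,
            label="Model",
        )
        demo = task.create_demo(model_name_widget=model_name_widget)
        demo.visible = True  # create_demo hides the group by default
    app.launch()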