"""Gradio demo task for MAMMAL-based protein solubility estimation."""

import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp

from mammal.examples.protein_solubility.task import ProteinSolubilityTask
from mammal.keys import (
    CLS_PRED,
    ENCODER_INPUTS_STR,
    SCORES,
)
from mammal.model import Mammal

from mammal_demo.demo_framework import MammalObjectBroker, MammalTask

# Local aliases for the task's static pre- and post-processing helpers
data_preprocessing = ProteinSolubilityTask.data_preprocessing
process_model_output = ProteinSolubilityTask.process_model_output


class PsTask(MammalTask):
    def __init__(self, model_dict):
        super().__init__(name="Protein Solubility", model_dict=model_dict)
        self.description = "Protein Solubility (PS)"
        self.examples = {
            "protein_seq": "LLQTGIHVRVSQPSL",
        }
        self.markup_text = """
# MAMMAL-based protein solubility estimation

Given a protein sequence, estimate whether it is water-soluble.
"""

    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
        """Convert sample_inputs into a sample_dict, including building a proper prompt.

        Args:
            sample_inputs (dict): dictionary containing the inputs to the model
            model_holder (MammalObjectBroker): holder wrapping the model and its tokenizer op

        Returns:
            dict: sample_dict ready to feed into the model
        """
        sample_dict = dict(sample_inputs)
        # Tokenize the protein sequence and build the task prompt on the model's device
        sample_dict = data_preprocessing(
            sample_dict=sample_dict,
            protein_sequence_key="protein_seq",
            tokenizer_op=model_holder.tokenizer_op,
            device=model_holder.model.device,
        )
        return sample_dict
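
    # Thin wrapper around Mammal.generate; scores are requested so that
    # decode_output can derive the class confidence values.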
    def run_model(self, sample_dict, model: Mammal):
        batch_dict = model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )
        return batch_dict
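
    # Decode the generated tokens and extract the predicted class together
    # with its raw and normalized scores.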
    def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
        """Extract predicted class and scores."""
        ans_dict = process_model_output(
            tokenizer_op=tokenizer_op,
            decoder_output=batch_dict[CLS_PRED][0],
            decoder_output_scores=batch_dict[SCORES][0],
        )
        ans = [
            tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0]),
            ans_dict["pred"],
            ans_dict["not_normalized_scores"].item(),
            ans_dict["normalized_scores"].item(),
        ]
        return ans
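
    # Build the prompt, run the model, and return the prompt followed by the
    # decoded outputs, in the same order as the output widgets in create_demo.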
    def create_and_run_prompt(self, model_name, protein_seq):
        model_holder = self.model_dict[model_name]
        sample_inputs = {
            "protein_seq": protein_seq,
        }
        sample_dict = self.crate_sample_dict(
            sample_inputs=sample_inputs, model_holder=model_holder
        )
        prompt = sample_dict[ENCODER_INPUTS_STR]
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        res = prompt, *self.decode_output(
            batch_dict, tokenizer_op=model_holder.tokenizer_op
        )
        return res
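
    # Build the Gradio UI for this task; the returned group starts hidden and
    # is expected to be made visible by the surrounding demo framework.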
    def create_demo(self, model_name_widget):
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                protein_textbox = gr.Textbox(
                    label="Protein sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_seq"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run MAMMAL prompt for protein solubility estimation",
                    variant="primary",
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="MAMMAL prompt", lines=5)
            with gr.Row():
                decoded = gr.Textbox(label="MAMMAL output")
                predicted_class = gr.Textbox(label="MAMMAL prediction")
                with gr.Column():
                    non_norm_score = gr.Number(label="Non-normalized score")
                    norm_score = gr.Number(label="Normalized score")
            run_mammal.click(
                fn=self.create_and_run_prompt,
                inputs=[model_name_widget, protein_textbox],
                outputs=[
                    prompt_box,
                    decoded,
                    predicted_class,
                    non_norm_score,
                    norm_score,
                ],
            )
        demo.visible = False
        return demo
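

# --- Hypothetical usage sketch (illustration only, not part of the demo framework) ---
# A minimal sketch of how PsTask might be wired into a standalone Gradio app,
# assuming MammalObjectBroker accepts a `model_path` argument and that the
# checkpoint id below exists; both are assumptions, not verified values.
if __name__ == "__main__":
    model_id = "ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility"  # assumed checkpoint id
    model_dict = {model_id: MammalObjectBroker(model_path=model_id)}  # assumed constructor
    task = PsTask(model_dict=model_dict)
    with gr.Blocks() as app:
        # Stand-in for the framework's model-selection widget
        model_name_widget = gr.Dropdown(
            choices=list(model_dict.keys()), value=model_id, label="Model"
        )
        task_demo = task.create_demo(model_name_widget=model_name_widget)
        task_demo.visible = True  # create_demo returns the group hidden by default
    app.launch()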