import gradio as gr

from mammal.examples.molnet.molnet_infer import (
    create_sample_dict as molnet_create_sample_dict,
    get_predictions,
    process_model_output,
)
from mammal.keys import CLS_PRED, ENCODER_INPUTS_STR, SCORES
from mammal.model import Mammal

from mammal_demo.demo_framework import MammalObjectBroker, MammalTask


class MolnetTask(MammalTask):
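    """Demo task for a MoleculeNet (MOLNET) property-prediction head such as
    BBBP (blood-brain barrier penetration): builds the prompt, runs the MAMMAL
    model, and decodes the predicted label and its score.
    """
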
    def __init__(self, model_dict, task_name="BBBP", name=None):
        if name is None:
            name = f"Molnet: {task_name}"
        super().__init__(name=name, model_dict=model_dict)
        self.description = f"MOLNET {task_name}"
        self.examples = {
            # Default example molecule (melatonin) as a SMILES string.
            "drug_seq": "CC(=O)NCCC1=CNc2c1cc(OC)cc2",
        }
        self.task_name = task_name
        self.markup_text = f"""
# MAMMAL demonstration: MOLNET {task_name}

"""

    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        # Build the task-specific sample dict via the MOLNET inference helper.
        return molnet_create_sample_dict(
            task_name=self.task_name,
            smiles_seq=sample_inputs["drug_seq"],
            tokenizer_op=model_holder.tokenizer_op,
            model=model_holder.model,
        )

    def run_model(self, sample_dict, model: Mammal):
        # Generate the prediction for the prepared sample.
        batch_dict = get_predictions(model=model, sample_dict=sample_dict)
        return batch_dict

    def decode_output(self, batch_dict, model_holder):
        # Decode the predicted class label and its confidence score.
        result = process_model_output(
            tokenizer_op=model_holder.tokenizer_op,
            decoder_output=batch_dict[CLS_PRED][0],
            decoder_output_scores=batch_dict[SCORES][0],
        )
        # Also decode the raw generated token sequence for display.
        generated_output = model_holder.tokenizer_op._tokenizer.decode(
            batch_dict[CLS_PRED][0]
        )
        return generated_output, result["pred"], result["score"]

    def create_and_run_prompt(self, model_name, drug_seq):
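        """Build the prompt for `drug_seq`, run the selected model, and return
        (prompt, generated_output, prediction, score) for the output widgets."""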
        model_holder = self.model_dict[model_name]
        inputs = {
            "drug_seq": drug_seq,
        }
        sample_dict = self.crate_sample_dict(
            sample_inputs=inputs, model_holder=model_holder
        )
        prompt = sample_dict[ENCODER_INPUTS_STR]
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        res = prompt, *self.decode_output(batch_dict, model_holder=model_holder)
        return res

    def create_demo(self, model_name_widget):
        # Build the (initially hidden) Gradio UI block for this task.
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                drug_textbox = gr.Textbox(
                    label="Drug sequence (in SMILES)",
                    interactive=True,
                    lines=3,
                    value=self.examples["drug_seq"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for task",
                    variant="primary",
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded = gr.Textbox(label="Mammal output")
                prediction_box = gr.Textbox(label="Mammal prediction")
                score_box = gr.Number(label="Score")
                run_mammal.click(
                    fn=self.create_and_run_prompt,
                    inputs=[model_name_widget, drug_textbox],
                    outputs=[prompt_box, decoded, prediction_box, score_box],
                )
            # Created hidden; the hosting app is expected to toggle visibility.
            demo.visible = False
            return demo
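

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one way to
# host this task in a standalone Gradio app. The `MammalObjectBroker(
# model_path=...)` call and the checkpoint id below are assumptions about the
# demo framework's interface; adapt them to the actual constructor and the
# models available to you.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_name = "ibm/biomed.omics.bl.sm.ma-ted-458m"  # assumed checkpoint id
    model_dict = {model_name: MammalObjectBroker(model_path=model_name)}

    with gr.Blocks() as app:
        model_widget = gr.Dropdown(
            choices=list(model_dict.keys()), value=model_name, label="Model"
        )
        task = MolnetTask(model_dict=model_dict, task_name="BBBP")
        task_demo = task.create_demo(model_name_widget=model_widget)
        task_demo.visible = True  # create_demo returns a hidden group

    app.launch()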