File size: 3,074 Bytes
09c907a
 
 
 
a4eba41
 
 
09c907a
 
 
 
 
 
 
 
 
 
 
a4eba41
 
09c907a
a4eba41
09c907a
a4eba41
09c907a
 
a4eba41
 
 
09c907a
a4eba41
 
09c907a
a4eba41
09c907a
bebd86d
09c907a
bebd86d
09c907a
 
 
 
 
 
 
 
b8bc3bf
 
 
09c907a
 
 
 
 
a4eba41
 
 
881b63d
09c907a
 
 
 
 
 
 
 
a4eba41
09c907a
a4eba41
 
 
7d1b3b7
09c907a
a4eba41
09c907a
a4eba41
 
09c907a
 
 
 
 
 
 
 
 
 
a4eba41
09c907a
 
a4eba41
09c907a
 
a4eba41
09c907a
 
 
 
 
343ba2f
09c907a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.controlled_sampling.advanced_manufacturing import (
    CatalystGenerator,
    AdvancedManufacturing,
)
from gt4sd.algorithms.registry import ApplicationsRegistry

from utils import draw_grid_generate

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def run_inference(
    algorithm_version: str,
    target_binding_energy: float,
    primer_smiles: str,
    length: float,
    number_of_points: int,
    number_of_steps: int,
    number_of_samples: int,
):
    """Generate catalyst candidates and render them as a grid.

    Builds a ``CatalystGenerator`` configuration from the UI inputs, wraps it
    in an ``AdvancedManufacturing`` algorithm conditioned on the requested
    target binding energy, draws ``number_of_samples`` samples and renders
    them with ``draw_grid_generate`` (5 columns).

    Args:
        algorithm_version: version of the CatalystGenerator model to use.
        target_binding_energy: target value passed to ``AdvancedManufacturing``.
        primer_smiles: optional primer SMILES. ``None``, empty or
            whitespace-only input means "no primer".
        length: maximal generated sequence length (coerced to ``int``).
        number_of_points: forwarded to the configuration (coerced to ``int``).
        number_of_steps: forwarded to the configuration (coerced to ``int``).
        number_of_samples: number of samples to draw (coerced to ``int``).

    Returns:
        The result of ``draw_grid_generate``; presumably an HTML string, since
        the UI displays it in a ``gr.HTML`` output — confirm in ``utils``.
    """
    # Gradio sliders may hand over integral values as floats (e.g. 100.0);
    # lengths and counts are presumably integer fields in the gt4sd
    # configuration/sampler, so normalise them here.
    length = int(length)
    number_of_points = int(number_of_points)
    number_of_steps = int(number_of_steps)
    number_of_samples = int(number_of_samples)

    # Treat a missing or blank primer as "no primer" rather than forwarding
    # whitespace as a SMILES string.
    primer_smiles = (primer_smiles or "").strip()

    config = CatalystGenerator(
        algorithm_version=algorithm_version,
        number_of_points=number_of_points,
        number_of_steps=number_of_steps,
        generated_length=length,
        primer_smiles=primer_smiles,
    )
    model = AdvancedManufacturing(config, target=target_binding_energy)
    samples = list(model.sample(number_of_samples))
    # The primer (if any) is shown alongside the generated samples.
    seeds = [primer_smiles] if primer_smiles else []

    return draw_grid_generate(samples=samples, n_cols=5, seeds=seeds)


if __name__ == "__main__":

    # Preparation: collect the versions of every registered algorithm whose
    # name contains "AdvancedManufact".
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        x["algorithm_version"]
        for x in all_algos
        if "AdvancedManufact" in x["algorithm_name"]
    ]
    # Prefer "v0" as the default, but never default the dropdown to a version
    # that is not among the discovered choices.
    if "v0" in algos or not algos:
        default_version = "v0"
    else:
        default_version = algos[0]

    # Load metadata shipped next to this script.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"

    examples = pd.read_csv(metadata_root / "examples.csv", header=None).fillna("")
    print("Examples: ", examples.values.tolist())

    # Explicit encoding: the default text encoding is platform-dependent.
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="Advanced Manufacturing",
        inputs=[
            gr.Dropdown(
                algos,
                label="Algorithm version",
                value=default_version,
            ),
            gr.Slider(minimum=1, maximum=100, value=10, label="Target binding energy"),
            gr.Textbox(
                label="Primer SMILES",
                placeholder="FP(F)F.CP(C)c1ccccc1.[Au]",
                lines=1,
            ),
            gr.Slider(
                minimum=5,
                maximum=400,
                value=100,
                label="Maximal sequence length",
                step=1,
            ),
            gr.Slider(
                minimum=16, maximum=128, value=32, label="Number of points", step=1
            ),
            gr.Slider(
                minimum=16, maximum=128, value=50, label="Number of steps", step=1
            ),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        # NOTE(review): examples are loaded and printed above but are not
        # passed to the Interface; confirm why before re-enabling.
        # examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)