"""Gradio demo app for GeoDiff molecule generation with GT4SD."""
import logging
import os
import pathlib
import pickle
from typing import Any, Dict

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.diffusion import (
    DiffusersGenerationAlgorithm,
    GeoDiffGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
from rdkit import Chem

from utils import draw_grid_generate

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def _resolve_upload_path(prompt_file: Any) -> Any:
    """Return a filesystem path for a ``gr.File`` upload.

    The upload may be a plain path (``str`` / ``os.PathLike``) or a temp-file
    wrapper that exposes its path as ``.name``; both are accepted. A
    ``pathlib.Path`` must not go through ``.name`` (that is only the basename).
    """
    if isinstance(prompt_file, (str, os.PathLike)):
        return prompt_file
    return prompt_file.name


def _load_prompt(prompt_file: Any, prompt_id: int) -> Any:
    """Load the GeoDiff prompt stored in a pickled upload.

    If the pickle holds a non-empty dict whose keys are all ``int``, it is
    treated as a collection of prompts and ``prompt_id`` selects one entry.
    Anything else is passed through unchanged as a single prompt.

    Args:
        prompt_file: uploaded ``.pkl`` file (path or temp-file wrapper).
        prompt_id: ID of the prompt to pick from a collection of prompts.

    Returns:
        The selected prompt object.

    Raises:
        KeyError: if the file holds a collection of prompts without ``prompt_id``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file,
    # and this demo unpickles user uploads. Only use trusted prompt files.
    with open(_resolve_upload_path(prompt_file), "rb") as f:
        prompts = pickle.load(f)

    # Integer keys mark a collection of prompts. A single prompt is presumably
    # a dict keyed by field names (strings) -- confirm against GeoDiffGenerator.
    is_collection = (
        isinstance(prompts, dict)
        and bool(prompts)
        and all(isinstance(key, int) for key in prompts)
    )
    if not is_collection:
        return prompts
    try:
        return prompts[prompt_id]
    except KeyError:
        raise KeyError(
            f"Prompt ID {prompt_id} not found; available IDs: {sorted(prompts)}"
        ) from None


def run_inference(
    algorithm_version: str,
    prompt_file: str,
    prompt_id: int,
    number_of_samples: int,
):
    """Generate molecules with GeoDiff and render them as a grid.

    Args:
        algorithm_version: GeoDiff model version from the GT4SD registry.
        prompt_file: uploaded ``.pkl`` prompt file (path or temp-file wrapper).
        prompt_id: ID of the prompt to use when the file holds several prompts.
        number_of_samples: number of molecules to sample.

    Returns:
        Output of ``draw_grid_generate`` for the generated SMILES (displayed in
        the app's ``gr.HTML`` output component).
    """
    prompt = _load_prompt(prompt_file, prompt_id)

    config = GeoDiffGenerator(
        algorithm_version=algorithm_version,
        prompt=prompt,
    )
    model = DiffusersGenerationAlgorithm(config)
    # Slider values may arrive as floats; the sample count is presumably
    # expected to be an integer.
    results = list(model.sample(int(number_of_samples)))
    smiles = [Chem.MolToSmiles(mol) for mol in results]
    return draw_grid_generate(samples=smiles, n_cols=5)


if __name__ == "__main__":
    # Preparation: collect the versions of all registered GeoDiff algorithms.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        algo["algorithm_version"]
        for algo in all_algos
        if "GeoDiff" in algo["algorithm_application"]
    ]
    if not algos:
        raise RuntimeError("No GeoDiff algorithms found in the GT4SD registry.")

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    # NOTE(review): this example row gives values for the first three inputs
    # only (version, prompt file, prompt ID); confirm that leaving the
    # sample-count slider at its default is intended.
    examples = [[algos[0], metadata_root.joinpath("mol_dct.pkl"), 2]]
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="GeoDiff",
        inputs=[
            gr.Dropdown(
                algos, label="GeoDiff version", value="fusing/gfn-molecule-gen-drugs"
            ),
            gr.File(file_types=[".pkl"], label="GeoDiff prompt"),
            gr.Number(value=0, label="Prompt ID", precision=0),
            gr.Slider(minimum=1, maximum=5, value=2, label="Number of samples", step=1),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)