"""Gradio demo for the Text+chem-T5 language models.

Inference is run through GT4SD's Hugging Face generation algorithms; the
UI lets the user pick a model, a task, a text prompt and a beam count.
"""

import functools
import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceSeq2SeqGenerator,
    HuggingFaceGenerationAlgorithm,
)
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Task name (as offered in the UI) -> instruction text handed to the generator
# as its `prefix`.
task2prefix = {
    "forward": "Predict the product of the following reaction: ",
    "retrosynthesis": "Predict the reaction that produces the following product: ",
    "paragraph to actions": "Which actions are described in the following paragraph: ",
    "molecular captioning": "Caption the following smile: ",
    "text-conditional de novo generation": "Write in SMILES the described molecule: ",
}


@functools.lru_cache(maxsize=None)
def _load_tokenizer():
    """Return the ``t5-small`` tokenizer, loading it only on first use.

    Only the tokenizer's ``eos_token`` and ``pad_token`` strings are read by
    ``run_inference``, so one shared instance is enough for every request.
    """
    return AutoTokenizer.from_pretrained("t5-small")


def run_inference(
    model_name_or_path: str,
    task: str,
    prompt: str,
    num_beams: int,
) -> str:
    """Generate one sample for ``prompt`` and return the cleaned-up text.

    Args:
        model_name_or_path: Passed to the generator configuration as
            ``algorithm_version``.
        task: Key of ``task2prefix`` selecting the instruction prefix.
        prompt: User text passed to the generator as ``prompt``.
        num_beams: Beam count forwarded to the generator configuration.

    Returns:
        The first generated sample with the echoed ``instruction + prompt``
        removed, truncated at the first EOS token, stripped of pad tokens
        and of surrounding whitespace.

    Raises:
        KeyError: If ``task`` is not a key of ``task2prefix``.
    """
    instruction = task2prefix[task]

    config = HuggingFaceSeq2SeqGenerator(
        algorithm_version=model_name_or_path,
        prefix=instruction,
        prompt=prompt,
        num_beams=num_beams,
    )
    model = HuggingFaceGenerationAlgorithm(config)

    # Cached: avoids re-loading the tokenizer on every request.
    tokenizer = _load_tokenizer()

    # Only the first sample is needed; no reason to materialize a list.
    text = next(iter(model.sample(1)))

    # Drop the echoed input, keep what precedes the first EOS token, then
    # remove any pad tokens left in the text.
    text = text.replace(instruction + prompt, "")
    text = text.split(tokenizer.eos_token)[0]
    text = text.replace(tokenizer.pad_token, "")
    return text.strip()


if __name__ == "__main__":
    models = [
        "text-chem-t5-small-standard",
        "text-chem-t5-small-augm",
        "text-chem-t5-base-standard",
        "text-chem-t5-base-augm",
    ]

    # Model-card assets live in a `model_cards` directory next to this file.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # examples.csv has no header row; empty cells become empty strings.
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )
    print("Examples: ", examples.values.tolist())

    # Explicit UTF-8 so decoding does not depend on the platform locale.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="Text+chem-T5 model",
        inputs=[
            gr.Dropdown(
                models,
                label="Language model",
                value="text-chem-t5-base-augm",
            ),
            gr.Radio(
                choices=[
                    "forward",
                    "retrosynthesis",
                    "paragraph to actions",
                    "molecular captioning",
                    "text-conditional de novo generation",
                ],
                label="Task",
                value="paragraph to actions",
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=50, value=10, label="num_beams", step=1),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)