"""Gradio demo for RxnScribe: extract chemical reactions from reaction diagrams.

The script downloads the RxnScribe checkpoint from the Hugging Face Hub, loads
it on the CPU, and serves a UI where a user uploads a diagram and receives an
image with the predictions drawn on it plus a table of predicted reactions.
"""
import gradio as gr

import os
import glob
import cv2
import numpy as np
import torch
from rxnscribe import RxnScribe

from huggingface_hub import hf_hub_download

REPO_ID = "yujieq/RxnScribe"
FILENAME = "pix2seq_reaction_full.ckpt"
ckpt_path = hf_hub_download(REPO_ID, FILENAME)

device = torch.device('cpu')
model = RxnScribe(ckpt_path, device)

# Reaction roles, in the same order as the result-table columns below.
ROLES = ('reactants', 'conditions', 'products')

# Markdown hard line break used between entities inside one table cell.
# A bare "\n" is only a soft break in Markdown, so consecutive entities
# would otherwise render on a single line.
LINE_BREAK = '<br>'

# Page header. Kept at column 0 (outside the indented `with` block) because
# Markdown treats lines indented by 4+ spaces as a code block.
DESCRIPTION = """
# RxnScribe

Extract chemical reactions from a diagram. Please upload a reaction diagram; RxnScribe will predict the reaction structures in the diagram.

The predicted reactions are visualized in separate images. Red boxes are reactants. Green boxes are reaction conditions. Blue boxes are products.

It usually takes 5-10 seconds to process a diagram with this demo. Check the options to run [MolScribe](https://huggingface.co/spaces/yujieq/MolScribe) and [OCR](https://huggingface.co/spaces/tomofi/EasyOCR) (it will take a longer time, of course).

Code: https://github.com/thomas0809/RxnScribe

Authors: [Yujie Qian](mailto:yujieq@csail.mit.edu), Jiang Guo, Zhengkai Tu, Connor W. Coley, Regina Barzilay. _MIT CSAIL_.
"""


def _entity_markdown(ent):
    """Render one predicted entity as a Markdown snippet.

    An entity with a 'smiles' key becomes a fenced code block; otherwise one
    with a 'text' key becomes its 'text' items joined by spaces; anything else
    falls back to its 'category' value. Non-SMILES snippets end with a hard
    line break so neighbouring entities stay on separate lines.
    """
    if 'smiles' in ent:
        return "\n```\n" + ent['smiles'] + "\n```\n"
    if 'text' in ent:
        return ' '.join(ent['text']) + LINE_BREAK
    return ent['category'] + LINE_BREAK


def get_markdown(reaction):
    """Format one predicted reaction as Markdown table cells.

    Args:
        reaction: one element of the predictions returned by
            ``model.predict_image``; a mapping with 'reactants', 'conditions'
            and 'products' keys, each holding a list of entity dicts.

    Returns:
        A list of three Markdown strings, one per role, in ``ROLES`` order.
    """
    return [''.join(_entity_markdown(ent) for ent in reaction[role]) for role in ROLES]


def predict(image, molscribe=False, ocr=False):
    """Run RxnScribe on an uploaded diagram and build the UI outputs.

    Args:
        image: the uploaded diagram (a PIL image, since the input component
            uses ``type='pil'``).
        molscribe: also run MolScribe to recognize molecule structures.
        ocr: also run OCR to recognize text.

    Returns:
        ``(pred_image, rows)``. ``pred_image`` is the output of
        ``model.draw_predictions_combined`` for all predictions. ``rows``
        holds one ``[index, reactants, conditions, products]`` row per
        predicted reaction.
    """
    # molscribe/ocr default to False so callers that pass only the image
    # (gr.Examples below lists just the image as its input) still work.
    predictions = model.predict_image(image, molscribe=molscribe, ocr=ocr)
    pred_image = model.draw_predictions_combined(predictions, image=image)
    rows = [[i] + get_markdown(reaction) for i, reaction in enumerate(predictions)]
    return pred_image, rows


# NOTE(review): `.style(...)` is the Gradio 3.x styling API (removed in
# Gradio 4) — confirm against the pinned gradio version before upgrading.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Column():
        with gr.Row():
            image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
        with gr.Row():
            molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
            ocr = gr.Checkbox(label="Run OCR to recognize text")
            btn = gr.Button("Submit").style(full_width=False)
        with gr.Row():
            gallery = gr.Image(label='Predicted reactions', show_label=True).style(height="auto")
        markdown = gr.Dataframe(
            headers=['#', 'reactant', 'condition', 'product'],
            datatype=['number'] + ['markdown'] * 3,
            wrap=False,
        )

    btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])

    # Only show the examples panel when example images are actually present.
    example_paths = sorted(glob.glob('examples/*.png'))
    if example_paths:
        gr.Examples(
            examples=example_paths,
            inputs=[image],
            outputs=[gallery, markdown],
            fn=predict,
        )

# Launch only when run as a script, so importing this module does not start
# a web server.
if __name__ == "__main__":
    demo.launch()