import gradio as gr
import os
import glob
import cv2
import numpy as np
import torch
from rxnscribe import RxnScribe
from huggingface_hub import hf_hub_download
# Inference runs on the CPU; the demo text warns it takes 5-10 s per diagram.
device = torch.device("cpu")

# Pretrained RxnScribe weights, fetched from the Hugging Face Hub.
REPO_ID = "yujieq/RxnScribe"
FILENAME = "pix2seq_reaction_full.ckpt"
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# Single shared model instance used by every request handled by this app.
model = RxnScribe(ckpt_path, device)
def get_markdown(reaction):
output = []
for x in ['reactants', 'conditions', 'products']:
s = ''
for ent in reaction[x]:
if 'smiles' in ent:
s += "\n```\n" + ent['smiles'] + "\n```\n"
elif 'text' in ent:
s += ' '.join(ent['text']) + '
'
else:
s += ent['category']
output.append(s)
return output
def predict(image, molscribe, ocr):
predictions = model.predict_image(image, molscribe=molscribe, ocr=ocr)
pred_image = model.draw_predictions_combined(predictions, image=image)
markdown = [[i] + get_markdown(reaction) for i, reaction in enumerate(predictions)]
return pred_image, markdown
# Gradio UI:
# - a description
# - an upload box and two option checkboxes
# - a Submit button
# - the annotated prediction image and a per-reaction table
# - an examples gallery built from examples/*.png
with gr.Blocks() as demo:
    gr.Markdown("""
RxnScribe
Extract chemical reactions from a diagram. Please upload a reaction diagram, RxnScribe will predict the reaction structures in the diagram.
The predicted reactions are visualized in separate images.
Red boxes are reactants.
Green boxes are reaction conditions.
Blue boxes are products.
It usually takes 5-10 seconds to process a diagram with this demo.
Check the options to run [MolScribe](https://huggingface.co/spaces/yujieq/MolScribe) and [OCR](https://huggingface.co/spaces/tomofi/EasyOCR) (it will take a longer time, of course).
Code: https://github.com/thomas0809/RxnScribe
Authors: [Yujie Qian](mailto:yujieq@csail.mit.edu), Jiang Guo, Zhengkai Tu, Connor W. Coley, Regina Barzilay. _MIT CSAIL_.
""")
    # NOTE(review): `.style(...)` is the Gradio 3.x styling API and is gone
    # in Gradio 4; confirm the pinned gradio version before upgrading.
    with gr.Column():
        with gr.Row():
            image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
        with gr.Row():
            molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
            ocr = gr.Checkbox(label="Run OCR to recognize text")
        btn = gr.Button("Submit").style(full_width=False)
        with gr.Row():
            gallery = gr.Image(label='Predicted reactions', show_label=True).style(height="auto")
        markdown = gr.Dataframe(
            headers=['#', 'reactant', 'condition', 'product'],
            datatype=['number'] + ['markdown'] * 3,
            wrap=False
        )

    btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])

    def _predict_example(example_image):
        """Run an example image with MolScribe and OCR both disabled."""
        # The examples only feed `image`, but predict() takes three inputs;
        # supply the two flags explicitly so the call arity matches.
        return predict(example_image, False, False)

    gr.Examples(
        examples=sorted(glob.glob('examples/*.png')),
        inputs=[image],
        outputs=[gallery, markdown],
        fn=_predict_example,
    )

demo.launch()