# MolScribe / app.py — Hugging Face Space source
# Author: thomas0809 · commit d30c1c3 ("add copy button") · 2.23 kB
import gradio as gr
import os
import glob
import cv2
import numpy as np
import torch
from molscribe import MolScribe
from indigo import Indigo
from indigo.renderer import IndigoRenderer
from huggingface_hub import hf_hub_download
# Location of the pretrained MolScribe checkpoint on the Hugging Face Hub.
REPO_ID = "yujieq/MolScribe"
FILENAME = "swin_base_char_aux_1m.pth"

# Fetch the checkpoint (hf_hub_download returns a local file path) and build
# the model once, at import time, so every request reuses the same instance.
# Inference is pinned to the CPU.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
device = torch.device("cpu")
model = MolScribe(ckpt_path, device)
def generate_mol_image(molblock):
    """Render a Molfile (molblock text) to an RGB image using Indigo.

    Args:
        molblock: Molecule in Molfile text form, e.g. the ``'molfile'`` entry
            of a MolScribe prediction.

    Returns:
        A ``uint8`` numpy array with channels in RGB order, ready to hand to a
        ``gr.Image`` output, or ``None`` if OpenCV could not decode the PNG
        that Indigo produced.

    Raises:
        Whatever ``Indigo.loadMolecule`` / ``IndigoRenderer.renderToBuffer``
        raise when the molblock cannot be parsed or rendered.
    """
    # A fresh Indigo session per call. Presumably this keeps one session from
    # being shared across requests; confirm before hoisting it to module level.
    indigo = Indigo()
    render = IndigoRenderer(indigo)
    indigo.setOption('render-output-format', 'png')
    indigo.setOption('render-background-color', '1,1,1')  # white background
    indigo.setOption('render-stereo-style', 'none')
    indigo.setOption('render-label-mode', 'hetero')
    mol = indigo.loadMolecule(molblock)
    buf = render.renderToBuffer(mol)
    img = cv2.imdecode(np.asarray(bytearray(buf), dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # Decoding failed; keep the original behaviour of returning None.
        return None
    # OpenCV decodes to BGR, but Gradio treats numpy images as RGB. Without
    # this swap, colored atom labels would show with red/blue exchanged.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def predict(image):
    """Recognize the molecule in one uploaded image.

    Args:
        image: The upload delivered by the ``gr.Image`` input (presumably a
            numpy array, Gradio 3.x's default image type; confirm), or
            ``None`` when the user submits without uploading anything.

    Returns:
        Tuple ``(rendered_image, smiles, molfile)`` matching the three output
        components. ``rendered_image`` is ``None`` if the predicted molfile
        could not be rendered. The SMILES and molfile are still returned in
        that case, so the user does not lose the prediction.
    """
    import logging

    if image is None:
        # Nothing uploaded: clear the outputs instead of crashing the model.
        return None, "", ""

    prediction = model.predict_image(image)
    smiles = prediction['smiles']
    molfile = prediction['molfile']

    try:
        rendered = generate_mol_image(molfile)
    except Exception:
        # Rendering is best-effort at the UI boundary: a molfile that Indigo
        # rejects should not hide the text results the model did produce.
        logging.getLogger(__name__).exception("Failed to render predicted molfile")
        rendered = None
    return rendered, smiles, molfile
# Web UI: one image in; rendered structure, SMILES, and Molfile out.
# NOTE(review): `.style(...)` is the Gradio 3.x component-styling API. It was
# deprecated late in 3.x and removed in Gradio 4, where `height=` and
# `show_copy_button=` become constructor arguments. Migrate only together
# with the pinned Gradio version.
iface = gr.Interface(
    predict,
    inputs=gr.Image(label="Upload molecular image", show_label=False).style(height=256),
    outputs=[
        gr.Image(label="Prediction").style(height=256),
        gr.Textbox(label="SMILES").style(show_copy_button=True),
        gr.Textbox(label="Molfile").style(show_copy_button=True),
    ],
    # NOTE(review): per the Gradio docs, "auto" flags submissions automatically
    # with no user action, i.e. uploads are persisted. Confirm this is intended
    # for a public demo.
    allow_flagging="auto",
    title="MolScribe",
    description="Convert a molecular image into SMILES and Molfile. (It typically takes 2-3 seconds to predict an "
                "image, but may take longer if the server is busy. To view the prediction better, copy-paste the "
                "Molfile to ChemDraw.) <br> "
                "Paper: [_MolScribe: Robust Molecular Structure Recognition with Image-To-Graph Generation_](https://arxiv.org/abs/2205.14311) <br>"
                "Code: https://github.com/thomas0809/MolScribe <br>"
                # NOTE(review): "[email protected]" looks like an e-mail
                # obfuscation artifact, not a real address; restore the
                # intended mailto target.
                "Authors: [Yujie Qian](mailto:[email protected]), Jiang Guo, Zhengkai Tu, Zhening Li, Connor W. Coley, Regina Barzilay. _MIT CSAIL_.",
    # The glob is resolved relative to the current working directory.
    examples=sorted(glob.glob('examples/*.png')),
    examples_per_page=20,
)

# Only start the (blocking) server when run as a script, so importing this
# module, e.g. from tests or tooling, does not launch the app.
if __name__ == "__main__":
    iface.launch()