math2latex / app.py
msdkhairi's picture
Initial Commit
bf9d0ba
raw
history blame
3.01 kB
import torch
import torchvision.transforms as transforms
import matplotlib
import matplotlib.pyplot as plt
import gradio as gr
from math2latex.data import Tokenizer
from math2latex.model import ResNetTransformer
# Module-level holders for the loaded model and tokenizer.
# Both start as None and are assigned by setup().
model, tokenizer = None, None
def get_formulas(filename):
    """Return every line of ``filename`` as a list of strings.

    The file is opened in text mode and each entry keeps its line ending,
    exactly as it was read.
    """
    with open(filename, 'r') as source:
        # Iterating a text file yields the same lines readlines() would.
        return list(source)
def latex2image(latex_expression, image_name, image_size_in=(3, 0.6), fontsize=12, dpi=200):
    """Render a text/math expression to an image file with matplotlib.

    The expression is drawn centred on an otherwise empty figure, and the
    figure is written to ``image_name``.

    Args:
        latex_expression: String handed to ``Figure.text``. Wrap it in
            ``$...$`` to have matplotlib typeset it as math.
        image_name: Destination passed to ``Figure.savefig``.
        image_size_in: ``(width, height)`` of the figure, passed as
            ``figsize`` (inches, per matplotlib).
        fontsize: Font size of the rendered text.
        dpi: Resolution of the figure.

    Raises:
        Whatever matplotlib raises while rendering or saving (for example
        when the expression cannot be parsed). The figure is closed first.
    """
    # Use the Computer Modern font set for math. This writes to matplotlib's
    # global rcParams, so it stays in effect after the call returns.
    matplotlib.rcParams["mathtext.fontset"] = "cm"
    # NOTE: `text.usetex` is deliberately not enabled here, so rendering does
    # not go through matplotlib's usetex (external LaTeX) path.

    fig = plt.figure(figsize=image_size_in, dpi=dpi)
    try:
        fig.text(
            x=0.5,
            y=0.5,
            s=latex_expression,
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsize,
        )
        # Save this specific figure rather than pyplot's implicit
        # "current figure", which is shared global state.
        fig.savefig(image_name)
    finally:
        # Release the figure even when rendering/saving fails; otherwise
        # pyplot keeps it alive and repeated failures leak memory.
        plt.close(fig)
def setup():
    """Load the model weights and build the tokenizer into module globals.

    Reads the checkpoint at ``checkpoints/model.ckpt``, loads its
    ``state_dict`` into a fresh ``ResNetTransformer`` on the CPU in eval
    mode, and builds a ``Tokenizer`` from the formulas file. The globals
    ``model`` and ``tokenizer`` are assigned only after both objects were
    created successfully, so a failed call does not leave a half-initialised
    state behind.

    Raises:
        Any error raised while reading the checkpoint or formulas file, or
        while loading the state dict into the model.
    """
    global model, tokenizer

    checkpoint_path = 'checkpoints/model.ckpt'
    key_prefix = 'model.'

    network = ResNetTransformer()

    # SECURITY NOTE: torch.load deserialises with pickle (unless weights-only
    # loading is in effect), so only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Checkpoint keys carry a leading "model." (presumably from a training
    # wrapper -- confirm). Strip it only at the start of the key: a blanket
    # str.replace would also mangle keys that contain "model." further in,
    # e.g. "model.submodel.weight".
    state_dict = {
        (key[len(key_prefix):] if key.startswith(key_prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }

    network.load_state_dict(state_dict)
    network.to("cpu")
    network.eval()

    # The tokenizer is constructed from the lines of the formulas file.
    formulas = get_formulas('dataset/im2latex_formulas.norm.processed.lst')
    vocabulary = Tokenizer(formulas)

    # Publish both objects together, only after everything succeeded.
    model, tokenizer = network, vocabulary
def predict_image(image):
    """Run the model on a single image and return the decoded prediction.

    Args:
        image: Input accepted by ``torchvision.transforms.ToTensor``
            (presumably a PIL image or ndarray from Gradio -- confirm).

    Returns:
        Whatever ``tokenizer.decode`` returns for the first predicted
        sequence.
    """
    # Lazy initialisation: setup() rebinds the module-level `model` and
    # `tokenizer`, which are looked up again when used below.
    if model is None or tokenizer is None:
        setup()

    # Convert to a tensor and prepend a dimension of size 1 (a batch of one).
    batch = transforms.ToTensor()(image).unsqueeze(0)

    with torch.no_grad():
        prediction = model.predict(batch)
        return tokenizer.decode(prediction[0].tolist())
def predict_and_convert_to_image(image):
    """Predict the LaTeX for ``image`` and render that LaTeX to an image.

    Args:
        image: Input forwarded unchanged to ``predict_image``.

    Returns:
        A ``(latex_code, image_name)`` tuple: the string produced by
        ``predict_image`` and the path of the rendered image file.
    """
    latex_code = predict_image(image)

    # NOTE: fixed path in the working directory; every call overwrites it.
    image_name = 'temp.png'

    # Wrap the prediction in $...$ so it is typeset as math. Interior spaces
    # are kept on purpose: whitespace is insignificant in math mode, whereas
    # removing it fuses a control word with the letters that follow it
    # (e.g. "\pi r" would become the unknown command "\pir").
    math_expression = f"${latex_code.strip()}$"
    latex2image(math_expression, image_name)

    # Return both the LaTeX code and the path of the generated image.
    return latex_code, image_name
def main():
    """Load the model and launch the Gradio demo.

    The interface takes one image and shows two outputs: the predicted LaTeX
    string and the image rendered from it.
    """
    # Load eagerly at start-up; predict_image would otherwise do it lazily
    # on the first request.
    setup()

    # NOTE: no `examples=` are wired into the interface, so the description
    # must not advertise any. Sample images are expected under
    # dataset/formula_images_processed/ if examples are added later.
    demo = gr.Interface(
        fn=predict_and_convert_to_image,
        inputs='image',
        outputs=['text', 'image'],
        title='Image to LaTeX code',
        description='Convert an image of a mathematical formula to LaTeX code and view the result as an image. Upload an image of a formula to get both the LaTeX code and the corresponding image.',
    )
    demo.launch()
# Start the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()