File size: 3,007 Bytes
bf9d0ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import re

import torch
import torchvision.transforms as transforms

import matplotlib
import matplotlib.pyplot as plt

import gradio as gr

from math2latex.data import Tokenizer
from math2latex.model import ResNetTransformer

# Shared inference state. Both names stay None until setup() assigns them;
# predict_image() calls setup() when either is still None.
model = None
tokenizer = None

def get_formulas(filename, encoding="utf-8"):
    """Return every line of ``filename`` as a list of strings.

    Lines keep their trailing newline characters, exactly as ``readlines``
    returns them.

    Args:
        filename: Path of the text file to read.
        encoding: Text encoding of the file. It is explicit so the result does
            not depend on the platform's default locale encoding.

    Returns:
        A list with one string per line of the file.
    """
    with open(filename, 'r', encoding=encoding) as f:
        return f.readlines()

def latex2image(latex_expression, image_name, image_size_in=(3, 0.6), fontsize=12, dpi=200):
    """Draw ``latex_expression`` centred on a blank figure and save it to ``image_name``.

    Args:
        latex_expression: Text to draw. Matplotlib's mathtext renders the
            parts wrapped in ``$...$`` as math.
        image_name: Output path passed to ``Figure.savefig``.
        image_size_in: Figure size ``(width, height)`` in inches.
        fontsize: Font size of the drawn text.
        dpi: Figure resolution in dots per inch.
    """
    # Render mathtext in Computer Modern. This writes the global rcParams, so
    # it also applies to any other figure drawn later in this process.
    matplotlib.rcParams["mathtext.fontset"] = "cm"

    fig = plt.figure(figsize=image_size_in, dpi=dpi)
    try:
        fig.text(
            x=0.5,
            y=0.5,
            s=latex_expression,
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsize,
        )
        # Save *this* figure. plt.savefig would save pyplot's current figure,
        # and another thread (e.g. a concurrent request) could have changed it.
        fig.savefig(image_name)
    finally:
        # Close even when drawing fails (e.g. a mathtext parse error raised
        # during savefig), so figures do not pile up in pyplot's registry.
        plt.close(fig)

def setup(checkpoint_path='checkpoints/model.ckpt',
          formulas_path='dataset/im2latex_formulas.norm.processed.lst'):
    """Load the model weights and build the tokenizer into the module globals.

    Args:
        checkpoint_path: Checkpoint file to load. Its ``'state_dict'`` entry
            holds the model weights, whose keys may carry a leading
            ``'model.'`` prefix that is stripped before loading.
        formulas_path: Formula list read with ``get_formulas`` and passed to
            ``Tokenizer``.

    The globals ``model`` and ``tokenizer`` are only reassigned after both
    objects have been built successfully.
    """
    global model, tokenizer

    new_model = ResNetTransformer()
    # NOTE(review): torch.load unpickles arbitrary objects; only point this at
    # trusted checkpoint files.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    prefix = 'model.'
    # Strip only the leading prefix. str.replace would also rewrite any
    # 'model.' appearing later inside a parameter name and corrupt that key.
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }
    new_model.load_state_dict(state_dict)
    new_model.to("cpu")
    new_model.eval()  # inference mode for layers such as dropout / batch-norm

    new_tokenizer = Tokenizer(get_formulas(formulas_path))

    model, tokenizer = new_model, new_tokenizer
    

def predict_image(image):
    """Predict the LaTeX token string for a single formula image.

    Calls ``setup()`` first if the global model or tokenizer has not been
    created yet.

    Args:
        image: Input accepted by ``torchvision.transforms.ToTensor``
            (a PIL image or a NumPy array).

    Returns:
        ``tokenizer.decode`` applied to the first sequence returned by
        ``model.predict`` for the single-image batch.
    """
    # Read-only use of the globals; setup() rebinds them, and the lookups
    # below see the new values because nothing here assigns to these names.
    if tokenizer is None or model is None:
        setup()

    to_tensor = transforms.ToTensor()
    batch = to_tensor(image)[None]  # prepend a batch dimension of size 1

    with torch.no_grad():
        sequences = model.predict(batch)

    first_sequence = sequences[0].tolist()
    return tokenizer.decode(first_sequence)

def _compact_latex(latex_code):
    """Remove spaces from LaTeX code without merging a control word into a following letter."""
    # Spaces are dropped as before, except one space is kept after a control
    # word (e.g. "\alpha") when a letter follows: joining them would form a
    # different, usually unknown, command ("\alpha x" -> "\alphax").
    return re.sub(
        r"(\\[A-Za-z]+) +(?=[A-Za-z])| +",
        lambda m: (m.group(1) + " ") if m.group(1) else "",
        latex_code,
    )


def predict_and_convert_to_image(image):
    """Predict LaTeX for ``image`` and render the prediction as a PNG preview.

    Args:
        image: Input image, passed unchanged to ``predict_image``.

    Returns:
        A ``(latex_code, image_path)`` tuple: the predicted string as returned
        by ``predict_image`` and the path of the rendered preview. The preview
        is always written to ``'temp.png'`` in the working directory, so each
        call overwrites the previous one.
    """
    latex_code = predict_image(image)

    image_name = 'temp.png'
    # Wrap in $...$ so matplotlib's mathtext renders the code as math.
    latex2image(f"${_compact_latex(latex_code)}$", image_name)

    return latex_code, image_name

def main():
    """Load the model and launch the Gradio image-to-LaTeX demo."""
    setup()
    example_paths = [
        "dataset/formula_images_processed/78228211ca.png",
        "dataset/formula_images_processed/2b891b21ac.png",
        "dataset/formula_images_processed/a8ec0c091c.png",
    ]
    # The description promises examples, but the dataset images may not be
    # present, so only offer the ones that exist. Passing None (Gradio's
    # default) shows no examples.
    examples = [[path] for path in example_paths if os.path.isfile(path)] or None
    demo = gr.Interface(
        fn=predict_and_convert_to_image,
        inputs='image',
        outputs=['text', 'image'],
        examples=examples,
        title='Image to LaTeX code',
        description='Convert an image of a mathematical formula to LaTeX code and view the result as an image. Upload an image of a formula to get both the LaTeX code and the corresponding image or use the examples provided.'
    )
    demo.launch()

# Launch the demo only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()