msdkhairi committed
Commit bf9d0ba
1 Parent(s): da1d27e

Initial Commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.lst filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,102 @@
+ import torch
+ import torchvision.transforms as transforms
+
+ import matplotlib
+ import matplotlib.pyplot as plt
+
+ import gradio as gr
+
+ from math2latex.data import Tokenizer
+ from math2latex.model import ResNetTransformer
+
+ # Global variables to hold the setup components
+ model, tokenizer = None, None
+
+ def get_formulas(filename):
+     with open(filename, 'r') as f:
+         formulas = f.readlines()
+     return formulas
+
+ def latex2image(latex_expression, image_name, image_size_in=(3, 0.6), fontsize=12, dpi=200):
+
+     # Runtime Configuration Parameters
+     matplotlib.rcParams["mathtext.fontset"] = "cm"  # Font changed to Computer Modern
+     # matplotlib.rcParams['text.usetex'] = True  # Use LaTeX to write all text
+
+     fig = plt.figure(figsize=image_size_in, dpi=dpi)
+     fig.text(
+         x=0.5,
+         y=0.5,
+         s=latex_expression,
+         horizontalalignment="center",
+         verticalalignment="center",
+         fontsize=fontsize,
+     )
+
+     plt.savefig(image_name)
+     plt.close(fig)
+
+ def setup():
+     global model, tokenizer
+     # setup the model
+     checkpoint_path = 'checkpoints/model.ckpt'
+     model = ResNetTransformer()
+     state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']
+     state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
+     model.load_state_dict(state_dict)
+     model.to("cpu")
+     model.eval()
+
+     # setup the tokenizer
+     formulas = get_formulas('dataset/im2latex_formulas.norm.processed.lst')
+     tokenizer = Tokenizer(formulas)
+
+
+ def predict_image(image):
+     global model, tokenizer
+
+     if model is None or tokenizer is None:
+         setup()
+
+     transform = transforms.ToTensor()
+
+     image = transform(image)
+     image = image.unsqueeze(0)
+     with torch.no_grad():
+         output = model.predict(image)
+
+     latex_code = tokenizer.decode_to_string(output[0].tolist())  # drop special tokens and join into a string
+     return latex_code
+
+ def predict_and_convert_to_image(image):
+
+     latex_code = predict_image(image)
+
+     image_name = 'temp.png'
+     latex_code_modified = latex_code.replace(" ", "")  # Remove spaces from the LaTeX code
+     latex_code_modified = rf"""${latex_code_modified}$"""
+     latex2image(latex_code_modified, image_name)
+
+     # Return both the LaTeX code and the path of the generated image
+     return latex_code, image_name
+
+ def main():
+     setup()
+     examples = [
+         ["dataset/formula_images_processed/78228211ca.png"],
+         ["dataset/formula_images_processed/2b891b21ac.png"],
+         ["dataset/formula_images_processed/a8ec0c091c.png"],
+     ]
+     demo = gr.Interface(
+         fn=predict_and_convert_to_image,
+         inputs='image',
+         outputs=['text', 'image'],
+         # examples=examples,
+         title='Image to LaTeX code',
+         description='Convert an image of a mathematical formula to LaTeX code and view the result as an image. Upload an image of a formula to get both the LaTeX code and the corresponding image, or use the examples provided.'
+     )
+     demo.launch()
+
+ if __name__ == "__main__":
+     main()
+
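A quick way to sanity-check the committed app outside the Gradio UI is a short driver script along the following lines. This is a sketch, not part of the commit: it assumes the checkpoint, the formulas .lst file, and the example images added above are materialized on disk (they are Git LFS objects), that the rest of the math2latex package imports cleanly, and that Pillow is available (it ships as a gradio dependency).

from PIL import Image

from app import predict_and_convert_to_image

# A single-channel input is fine: the model repeats grayscale images to 3 channels internally.
formula_image = Image.open("dataset/formula_images_processed/78228211ca.png").convert("L")

latex_code, rendered_path = predict_and_convert_to_image(formula_image)
print(latex_code)      # predicted LaTeX source
print(rendered_path)   # 'temp.png', the matplotlib rendering of the prediction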
checkpoints/model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5a21e6e4249e1f7ade1623a66ec17446ab7f870c10accf378057e3c3a9ed2a4
+ size 42532502
dataset/formula_images_processed/2b891b21ac.png ADDED
dataset/formula_images_processed/78228211ca.png ADDED
dataset/formula_images_processed/a8ec0c091c.png ADDED
dataset/im2latex_formulas.norm.processed.lst ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0d90bc9a6fc69cfcc7c55690eb3f3e8fe8e821236888270db6a7265aea78ade1
+ size 17711869
math2latex/data/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .prepare_data import prepare_data
+ from .utils import get_formulas
+ from .dataset import MathToLatexDataset, get_dataloader
+ from .tokenizer import Tokenizer
math2latex/data/tokenizer.py ADDED
@@ -0,0 +1,81 @@
+ import json
+
+ from collections import Counter
+ from torchtext.data.utils import get_tokenizer
+ from torchtext.vocab import vocab
+
+
+ class Tokenizer:
+     def __init__(self, formulas=None, max_len=150):
+         # self.tokenizer = get_tokenizer(None)
+         self.tokenizer = get_tokenizer("basic_english")
+         self.max_len = max_len
+
+         if formulas is not None:
+             self.vocab = self._build_vocab(formulas)
+             self.vocab.set_default_index(self.vocab['<unk>'])
+             self.pad_index = self.vocab['<pad>']
+             self.ignore_indices = {self.vocab['<pad>'], self.vocab['<bos>'], self.vocab['<eos>'], self.vocab['<unk>']}
+         else:
+             self.vocab = None
+
+     def _build_vocab(self, formulas):
+         counter = Counter()
+         for formula in formulas:
+             counter.update(self.tokenizer(formula))
+         return vocab(counter, specials=['<pad>', '<bos>', '<eos>', '<unk>'], min_freq=2)
+
+     def encode(self, formula, with_padding=False):
+         tokens = self.tokenizer(formula)
+         tokens = ['<bos>'] + tokens + ['<eos>']  # add <bos> and <eos> to the beginning and end of the tokens
+         if with_padding:
+             tokens = self.pad(tokens, self.max_len)
+         # map tokens to vocabulary indices
+         return [self.vocab[token] for token in tokens]
+
+     def decode(self, indices):
+         return self.vocab.lookup_tokens(list(indices))
+
+     def decode_clean(self, indices):
+         # removes the ignore indices from the decoded tokens
+         cleaned_indices = [index for index in indices if int(index) not in self.ignore_indices]
+         # if self.vocab['<eos>'] in cleaned_indices:
+         #     cleaned_indices = cleaned_indices[:cleaned_indices.index(self.vocab['<eos>'])]
+         return self.vocab.lookup_tokens(cleaned_indices)
+
+     def decode_to_string(self, tokens):
+         # returns the decoded tokens as a string
+         decoded = self.decode_clean(tokens)
+         return ' '.join(decoded)
+
+
+     def pad(self, tokens, max_len):
+         if len(tokens) > max_len:
+             tokens = tokens[:max_len]
+             tokens[-1] = '<eos>'
+             return tokens
+         return tokens + ['<pad>'] * (max_len - len(tokens))
+
+     def save_vocab(self, file_path="dataset/tokenizer_vocab.json"):
+         # Save the list of tokens which reflects both `itos` and `stoi`
+         vocab_data = {
+             'itos': self.vocab.get_itos()
+         }
+         with open(file_path, 'w') as f:
+             json.dump(vocab_data, f)
+
+     def load_vocab(self, file_path):
+         with open(file_path, 'r') as f:
+             vocab_data = json.load(f)
+         # Reconstruct the vocabulary from the itos list
+         ordered_tokens = vocab_data['itos']
+         # Reconstruct the counter from the ordered list
+         counter = Counter({token: idx + 1 for idx, token in enumerate(ordered_tokens)})  # idx+1 to ensure non-zero freq
+         self.vocab = vocab(counter, specials=['<pad>', '<bos>', '<eos>', '<unk>'])
+         self.vocab.set_default_index(self.vocab['<unk>'])
+         self.pad_index = self.vocab['<pad>']
+         self.ignore_indices = {self.vocab['<pad>'], self.vocab['<bos>'], self.vocab['<eos>'], self.vocab['<unk>']}
+
+
+     def __len__(self):
+         return len(self.vocab)
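The tokenizer can be exercised on its own with a couple of toy formulas. A minimal sketch, assuming the sibling modules imported by math2latex/data/__init__.py (prepare_data, utils, dataset) are also present in the working tree; note that _build_vocab uses min_freq=2, so a token must occur at least twice to get its own id, which is why the toy formula is repeated below.

from math2latex.data import Tokenizer

formulas = [r"\frac { a } { b }", r"\frac { a } { b }"]  # repeated so every token clears min_freq=2
tokenizer = Tokenizer(formulas, max_len=20)

ids = tokenizer.encode(formulas[0], with_padding=True)   # <bos> ... <eos>, padded out to max_len
print(len(ids))                          # 20
print(tokenizer.decode(ids))             # raw tokens, including the special symbols
print(tokenizer.decode_to_string(ids))   # special tokens stripped: '\frac { a } { b }'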
math2latex/model/__init__.py ADDED
@@ -0,0 +1 @@
+ from .transformer import ResNetTransformer
math2latex/model/positional_encoding.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ import torch.nn as nn
+
+ class PositionalEncoding1D(nn.Module):
+     def __init__(self, d_model, max_len=1000, dropout=0.1):
+         super().__init__()
+         self.d_model = d_model
+         self.max_len = max_len
+         self.dropout = nn.Dropout(p=dropout)
+
+         self.encoding = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len).unsqueeze(1).float()
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
+         self.encoding[:, 0::2] = torch.sin(position * div_term)
+         self.encoding[:, 1::2] = torch.cos(position * div_term)
+         self.encoding = self.encoding.unsqueeze(1)
+
+     def forward(self, x):
+         self.encoding = self.encoding.to(x.device)
+         x = x + self.encoding[:x.size(0)].detach()
+         return self.dropout(x)
+
+
+ class PositionalEncoding2D(nn.Module):
+     def __init__(self, d_model, max_h=1000, max_w=1000, dropout=0.1):
+         super().__init__()
+         self.d_model = d_model
+         self.max_h = max_h
+         self.max_w = max_w
+         self.dropout = nn.Dropout(p=dropout)
+
+         # create self.encoding considering input x as the shape (B, d_model, H, W)
+         self.encoding = torch.zeros(max_h, max_w, d_model)
+         position_h = torch.arange(0, max_h).unsqueeze(1).float()
+         position_w = torch.arange(0, max_w).unsqueeze(1).float()
+         div_term_h = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
+         div_term_w = torch.exp(torch.arange(1, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
+         self.encoding[:, :, 0::2] = torch.sin(position_h * div_term_h).unsqueeze(1)
+         self.encoding[:, :, 1::2] = torch.cos(position_w * div_term_w).unsqueeze(0)
+         self.encoding = self.encoding.permute(2, 0, 1)
+
+     def forward(self, x):
+         self.encoding = self.encoding.to(x.device)
+         x = x + self.encoding[:, :x.size(2), :x.size(3)].detach()
+         return self.dropout(x)
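A small shape check makes the two input conventions explicit: the 1D module expects (sequence length, batch, d_model) while the 2D module expects (batch, d_model, H, W). A sketch with illustrative sizes (the small max_h/max_w just keep the precomputed table tiny; the model itself uses larger defaults):

import torch

from math2latex.model.positional_encoding import PositionalEncoding1D, PositionalEncoding2D

pe1d = PositionalEncoding1D(d_model=128, max_len=150)
tokens = torch.zeros(10, 4, 128)           # (sequence length, batch, d_model)
print(pe1d(tokens).shape)                  # torch.Size([10, 4, 128])

pe2d = PositionalEncoding2D(d_model=128, max_h=64, max_w=256)
feature_map = torch.zeros(4, 128, 8, 32)   # (batch, d_model, H, W)
print(pe2d(feature_map).shape)             # torch.Size([4, 128, 8, 32])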
math2latex/model/transformer.py ADDED
@@ -0,0 +1,120 @@
+ import torch
+ import torch.nn as nn
+
+ import torchvision
+
+ from .positional_encoding import PositionalEncoding1D, PositionalEncoding2D
+
+ class ResNetTransformer(nn.Module):
+     def __init__(self,
+                  d_model=128,
+                  num_heads=4,
+                  num_decoder_layers=3,
+                  dim_feedforward=256,
+                  dropout=0.1,
+                  pos_enc_dropout=0.1,
+                  activation='relu',
+                  max_len_output=150,
+                  num_classes=462):
+         super().__init__()
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.num_decoder_layers = num_decoder_layers
+         self.dim_feedforward = dim_feedforward
+         self.dropout = dropout
+         self.activation = activation
+         self.num_classes = num_classes
+         self.max_len_output = max_len_output
+
+         # Encoder
+         resnet18 = torchvision.models.resnet18(weights=None)
+
+         # Remove the classification head and layer 4 from resnet18 and keep the first 3 layers
+         self.backbone = nn.Sequential(*list(resnet18.children())[:-3])
+
+         self.conv1x1 = nn.Conv2d(256, d_model, kernel_size=1, stride=1, padding=0)
+         self.encoder_pos_enc = PositionalEncoding2D(d_model,
+                                                     max_h=1000,
+                                                     max_w=1000,
+                                                     dropout=pos_enc_dropout)  # no images are larger than 1000x1000
+
+         # Decoder
+         self.embedding = nn.Embedding(num_classes, d_model)
+         self.decoder_pos_enc = PositionalEncoding1D(d_model,
+                                                     max_len=max_len_output,
+                                                     dropout=pos_enc_dropout)
+         _transformer_decoder_layer = nn.TransformerDecoderLayer(d_model=d_model,
+                                                                 nhead=num_heads,
+                                                                 dim_feedforward=dim_feedforward,
+                                                                 dropout=dropout,
+                                                                 activation=activation)
+         self.transformer_decoder = nn.TransformerDecoder(decoder_layer=_transformer_decoder_layer,
+                                                          num_layers=num_decoder_layers)
+
+         self.linear = nn.Linear(d_model, num_classes)
+         # self.softmax = nn.Softmax(dim=2)
+
+         # get target mask for training
+         self.tgt_mask = self.get_tgt_mask(max_len_output)
+         if self.training:
+             self._init_weights()
+
+
+     def _init_weights(self):
+         self.embedding.weight.data.uniform_(-0.1, 0.1)
+         self.linear.weight.data.uniform_(-0.1, 0.1)
+         self.linear.bias.data.zero_()
+
+         nn.init.kaiming_normal_(self.conv1x1.weight, mode='fan_out', nonlinearity='relu')
+         if self.conv1x1.bias is not None:
+             _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.conv1x1.weight)
+             bound = 1 / torch.sqrt(torch.tensor(fan_out))
+             nn.init.normal_(self.conv1x1.bias, -bound, bound)
+
+     def get_tgt_mask(self, target_size):
+         tgt_mask = torch.triu(torch.ones(target_size, target_size), diagonal=1)
+         tgt_mask = tgt_mask.masked_fill(tgt_mask == 1, float('-inf'))
+         return tgt_mask
+
+     def encode(self, x):
+         # Repeat the input x if it has only 1 channel
+         if x.shape[1] == 1:
+             x = x.repeat(1, 3, 1, 1)
+         x = self.backbone(x)
+         x = self.conv1x1(x)
+         x = self.encoder_pos_enc(x)
+         x = x.flatten(2)
+         x = x.permute(2, 0, 1)
+         return x
+
+     def decode(self, tgt, x):
+         tgt = tgt.permute(1, 0)
+         tgt = self.embedding(tgt)
+         tgt = self.decoder_pos_enc(tgt)
+         tgt_mask = self.tgt_mask[:tgt.size(0), :tgt.size(0)]
+         output = self.transformer_decoder(tgt, x, tgt_mask)
+         output = self.linear(output)
+         return output
+
+     def forward(self, x, tgt):
+         # Encoder
+         x = self.encode(x)
+
+         # Decoder
+         output = self.decode(tgt, x)
+         return output.permute(1, 2, 0)
+
+
+     def predict(self, x):
+         b = x.size(0)
+         x = self.encode(x)
+
+         tgt = torch.zeros((b, self.max_len_output), dtype=torch.long).to(x.device)
+         tgt[:, 0] = 1
+         for t in range(1, self.max_len_output):
+             output = self.decode(tgt[:, :t], x)
+             output = output.argmax(dim=-1)
+             tgt[:, t] = output[-1]  # last predicted token for each sequence in the batch
+
+         return tgt
+
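As a shape sanity check of the model wiring (the ResNet-18 backbone through layer3 downsamples by 16x, so a 64x256 crop becomes a 4x16 feature map, i.e. 64 memory positions for the decoder), the following sketch with random tensors is not part of the repo and uses illustrative sizes only:

import torch

from math2latex.model import ResNetTransformer

model = ResNetTransformer()
model.eval()

images = torch.rand(2, 1, 64, 256)          # grayscale crops; encode() repeats them to 3 channels
targets = torch.randint(0, 462, (2, 150))   # (batch, max_len_output) token indices

with torch.no_grad():
    logits = model(images, targets)         # forward() returns (batch, num_classes, sequence length)
    print(logits.shape)                     # torch.Size([2, 462, 150])

    tokens = model.predict(images)          # greedy decoding, starting from token id 1 (<bos>)
    print(tokens.shape)                     # torch.Size([2, 150])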
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ # tested with python 3.11.10
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch==2.3.1
+ torchvision==0.18.1
+ torchaudio==2.3.1
+ cuda-python==12.1.0
+ lightning==2.3.3
+ torchmetrics==1.4.2
+ tensorboard==2.18.0
+ matplotlib==3.9.2
+ nltk==3.9.1
+ torchtext==0.18.0
+ gradio==4.44.1