Initial Commit
- .gitattributes +1 -0
- app.py +102 -0
- checkpoints/model.ckpt +3 -0
- dataset/formula_images_processed/2b891b21ac.png +0 -0
- dataset/formula_images_processed/78228211ca.png +0 -0
- dataset/formula_images_processed/a8ec0c091c.png +0 -0
- dataset/im2latex_formulas.norm.processed.lst +3 -0
- math2latex/data/__init__.py +4 -0
- math2latex/data/tokenizer.py +81 -0
- math2latex/model/__init__.py +1 -0
- math2latex/model/positional_encoding.py +45 -0
- math2latex/model/transformer.py +120 -0
- requirements.txt +13 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.lst filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,102 @@
+import torch
+import torchvision.transforms as transforms
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+import gradio as gr
+
+from math2latex.data import Tokenizer
+from math2latex.model import ResNetTransformer
+
+# Global variables to hold the setup components
+model, tokenizer = None, None
+
+def get_formulas(filename):
+    with open(filename, 'r') as f:
+        formulas = f.readlines()
+    return formulas
+
+def latex2image(latex_expression, image_name, image_size_in=(3, 0.6), fontsize=12, dpi=200):
+
+    # Runtime configuration parameters
+    matplotlib.rcParams["mathtext.fontset"] = "cm"  # use the Computer Modern font
+    # matplotlib.rcParams['text.usetex'] = True  # use a full LaTeX install to render all text
+
+    fig = plt.figure(figsize=image_size_in, dpi=dpi)
+    fig.text(
+        x=0.5,
+        y=0.5,
+        s=latex_expression,
+        horizontalalignment="center",
+        verticalalignment="center",
+        fontsize=fontsize,
+    )
+
+    plt.savefig(image_name)
+    plt.close(fig)
+
+def setup():
+    global model, tokenizer
+    # set up the model
+    checkpoint_path = 'checkpoints/model.ckpt'
+    model = ResNetTransformer()
+    state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']
+    state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}  # strip the 'model.' prefix from the training wrapper
+    model.load_state_dict(state_dict)
+    model.to("cpu")
+    model.eval()
+
+    # set up the tokenizer
+    formulas = get_formulas('dataset/im2latex_formulas.norm.processed.lst')
+    tokenizer = Tokenizer(formulas)
+
+
+def predict_image(image):
+    global model, tokenizer
+
+    if model is None or tokenizer is None:
+        setup()
+
+    transform = transforms.ToTensor()
+
+    image = transform(image)
+    image = image.unsqueeze(0)
+    with torch.no_grad():
+        output = model.predict(image)
+
+    # decode_to_string drops <bos>/<eos>/<pad>/<unk> and joins the tokens with
+    # spaces; plain decode() returns a list, which would break .replace() below
+    latex_code = tokenizer.decode_to_string(output[0].tolist())
+    return latex_code
+
+def predict_and_convert_to_image(image):
+
+    latex_code = predict_image(image)
+
+    image_name = 'temp.png'
+    latex_code_modified = latex_code.replace(" ", "")  # remove spaces from the LaTeX code
+    latex_code_modified = rf"""${latex_code_modified}$"""
+    latex2image(latex_code_modified, image_name)
+
+    # Return both the LaTeX code and the path of the generated image
+    return latex_code, image_name
+
+def main():
+    setup()
+    examples = [
+        ["dataset/formula_images_processed/78228211ca.png"],
+        ["dataset/formula_images_processed/2b891b21ac.png"],
+        ["dataset/formula_images_processed/a8ec0c091c.png"],
+    ]
+    demo = gr.Interface(
+        fn=predict_and_convert_to_image,
+        inputs='image',
+        outputs=['text', 'image'],
+        # examples=examples,
+        title='Image to LaTeX code',
+        description='Convert an image of a mathematical formula to LaTeX code and view the result as an image. Upload an image of a formula to get both the LaTeX code and the corresponding image, or use the examples provided.'
+    )
+    demo.launch()
+
+if __name__ == "__main__":
+    main()
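
A quick smoke test of the inference path above, without launching the UI (a sketch, not part of this commit; it assumes the checkpoint, the formula list, and the example images are in place, and that the remaining math2latex modules referenced by the package __init__ files exist):

    # hypothetical local check of app.py's pipeline
    from PIL import Image

    from app import predict_and_convert_to_image

    img = Image.open("dataset/formula_images_processed/78228211ca.png")
    latex_code, rendered_path = predict_and_convert_to_image(img)
    print(latex_code)     # predicted LaTeX source, tokens separated by spaces
    print(rendered_path)  # 'temp.png', the matplotlib rendering of that code
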
checkpoints/model.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c5a21e6e4249e1f7ade1623a66ec17446ab7f870c10accf378057e3c3a9ed2a4
+size 42532502
dataset/formula_images_processed/2b891b21ac.png
ADDED
(binary image file)
dataset/formula_images_processed/78228211ca.png
ADDED
(binary image file)
dataset/formula_images_processed/a8ec0c091c.png
ADDED
(binary image file)
dataset/im2latex_formulas.norm.processed.lst
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d90bc9a6fc69cfcc7c55690eb3f3e8fe8e821236888270db6a7265aea78ade1
+size 17711869
math2latex/data/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .prepare_data import prepare_data
+from .utils import get_formulas
+from .dataset import MathToLatexDataset, get_dataloader
+from .tokenizer import Tokenizer
math2latex/data/tokenizer.py
ADDED
@@ -0,0 +1,81 @@
+import json
+
+from collections import Counter
+from torchtext.data.utils import get_tokenizer
+from torchtext.vocab import vocab
+
+
+class Tokenizer:
+    def __init__(self, formulas=None, max_len=150):
+        # self.tokenizer = get_tokenizer(None)
+        self.tokenizer = get_tokenizer("basic_english")
+        self.max_len = max_len
+
+        if formulas is not None:
+            self.vocab = self._build_vocab(formulas)
+            self.vocab.set_default_index(self.vocab['<unk>'])
+            self.pad_index = self.vocab['<pad>']
+            self.ignore_indices = {self.vocab['<pad>'], self.vocab['<bos>'], self.vocab['<eos>'], self.vocab['<unk>']}
+        else:
+            self.vocab = None
+
+    def _build_vocab(self, formulas):
+        counter = Counter()
+        for formula in formulas:
+            counter.update(self.tokenizer(formula))
+        return vocab(counter, specials=['<pad>', '<bos>', '<eos>', '<unk>'], min_freq=2)
+
+    def encode(self, formula, with_padding=False):
+        tokens = self.tokenizer(formula)
+        # add <bos> and <eos> to the beginning and end of the tokens
+        tokens = ['<bos>'] + tokens + ['<eos>']
+        if with_padding:
+            tokens = self.pad(tokens, self.max_len)
+        return [self.vocab[token] for token in tokens]
+
+    def decode(self, indices):
+        return self.vocab.lookup_tokens(list(indices))
+
+    def decode_clean(self, indices):
+        # removes the ignore indices from the decoded tokens
+        cleaned_indices = [index for index in indices if int(index) not in self.ignore_indices]
+        # if self.vocab['<eos>'] in cleaned_indices:
+        #     cleaned_indices = cleaned_indices[:cleaned_indices.index(self.vocab['<eos>'])]
+        return self.vocab.lookup_tokens(cleaned_indices)
+
+    def decode_to_string(self, tokens):
+        # returns the decoded tokens as a string
+        decoded = self.decode_clean(tokens)
+        return ' '.join(decoded)
+
+    def pad(self, tokens, max_len):
+        if len(tokens) > max_len:
+            tokens = tokens[:max_len]
+            tokens[-1] = '<eos>'
+            return tokens
+        return tokens + ['<pad>'] * (max_len - len(tokens))
+
+    def save_vocab(self, file_path="dataset/tokenizer_vocab.json"):
+        # Save the list of tokens, which reflects both `itos` and `stoi`
+        vocab_data = {
+            'itos': self.vocab.get_itos()
+        }
+        with open(file_path, 'w') as f:
+            json.dump(vocab_data, f)
+
+    def load_vocab(self, file_path):
+        with open(file_path, 'r') as f:
+            vocab_data = json.load(f)
+        # Reconstruct the vocabulary from the itos list
+        ordered_tokens = vocab_data['itos']
+        # Rebuild a frequency dict from the ordered list; idx + 1 keeps every
+        # frequency non-zero, and vocab() preserves the insertion order
+        counter = Counter({token: idx + 1 for idx, token in enumerate(ordered_tokens)})
+        self.vocab = vocab(counter, specials=['<pad>', '<bos>', '<eos>', '<unk>'])
+        self.vocab.set_default_index(self.vocab['<unk>'])
+        self.pad_index = self.vocab['<pad>']
+        self.ignore_indices = {self.vocab['<pad>'], self.vocab['<bos>'], self.vocab['<eos>'], self.vocab['<unk>']}
+
+    def __len__(self):
+        return len(self.vocab)
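
A minimal round-trip sketch for the Tokenizer (illustrative only; the toy formulas are invented, each one repeated because `_build_vocab` keeps only tokens with `min_freq=2`, and the import assumes the sibling modules named in math2latex/data/__init__.py are present):

    from math2latex.data import Tokenizer

    # toy corpus, duplicated so every token clears the min_freq=2 threshold
    formulas = ["\\frac { a } { b }", "\\frac { a } { b }", "a + b", "a + b"]
    tok = Tokenizer(formulas, max_len=10)

    ids = tok.encode("a + b", with_padding=True)
    print(ids)                        # <bos> a + b <eos> as indices, then <pad> up to max_len
    print(tok.decode_to_string(ids))  # specials stripped back out: 'a + b'
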
math2latex/model/__init__.py
ADDED
@@ -0,0 +1 @@
+from .transformer import ResNetTransformer
math2latex/model/positional_encoding.py
ADDED
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+
+class PositionalEncoding1D(nn.Module):
+    def __init__(self, d_model, max_len=1000, dropout=0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.max_len = max_len
+        self.dropout = nn.Dropout(p=dropout)
+
+        # precompute the sinusoidal table once, shaped (max_len, 1, d_model)
+        # so it broadcasts over the batch dimension of sequence-first input
+        self.encoding = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len).unsqueeze(1).float()
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
+        self.encoding[:, 0::2] = torch.sin(position * div_term)
+        self.encoding[:, 1::2] = torch.cos(position * div_term)
+        self.encoding = self.encoding.unsqueeze(1)
+
+    def forward(self, x):
+        # x: (S, B, d_model)
+        self.encoding = self.encoding.to(x.device)
+        x = x + self.encoding[:x.size(0)].detach()
+        return self.dropout(x)
+
+
+class PositionalEncoding2D(nn.Module):
+    def __init__(self, d_model, max_h=1000, max_w=1000, dropout=0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.max_h = max_h
+        self.max_w = max_w
+        self.dropout = nn.Dropout(p=dropout)
+
+        # create self.encoding considering input x of shape (B, d_model, H, W):
+        # even channels encode the row index (sin), odd channels the column index (cos)
+        self.encoding = torch.zeros(max_h, max_w, d_model)
+        position_h = torch.arange(0, max_h).unsqueeze(1).float()
+        position_w = torch.arange(0, max_w).unsqueeze(1).float()
+        div_term_h = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
+        div_term_w = torch.exp(torch.arange(1, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
+        self.encoding[:, :, 0::2] = torch.sin(position_h * div_term_h).unsqueeze(1)
+        self.encoding[:, :, 1::2] = torch.cos(position_w * div_term_w).unsqueeze(0)
+        self.encoding = self.encoding.permute(2, 0, 1)
+
+    def forward(self, x):
+        # x: (B, d_model, H, W)
+        self.encoding = self.encoding.to(x.device)
+        x = x + self.encoding[:, :x.size(2), :x.size(3)].detach()
+        return self.dropout(x)
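
Both modules add a precomputed sinusoidal table and apply dropout without changing the input shape: PositionalEncoding1D expects sequence-first input (S, B, d_model), PositionalEncoding2D a feature map (B, d_model, H, W). A shape check (sketch only, arbitrary sizes):

    import torch
    from math2latex.model.positional_encoding import PositionalEncoding1D, PositionalEncoding2D

    pe1d = PositionalEncoding1D(d_model=128, max_len=150)
    seq = torch.zeros(20, 2, 128)      # (S, B, d_model)
    assert pe1d(seq).shape == seq.shape

    pe2d = PositionalEncoding2D(d_model=128, max_h=64, max_w=64)
    feat = torch.zeros(2, 128, 8, 32)  # (B, d_model, H, W)
    assert pe2d(feat).shape == feat.shape
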
math2latex/model/transformer.py
ADDED
@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+
+import torchvision
+
+from .positional_encoding import PositionalEncoding1D, PositionalEncoding2D
+
+class ResNetTransformer(nn.Module):
+    def __init__(self,
+                 d_model=128,
+                 num_heads=4,
+                 num_decoder_layers=3,
+                 dim_feedforward=256,
+                 dropout=0.1,
+                 pos_enc_dropout=0.1,
+                 activation='relu',
+                 max_len_output=150,
+                 num_classes=462):
+        super().__init__()
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_decoder_layers = num_decoder_layers
+        self.dim_feedforward = dim_feedforward
+        self.dropout = dropout
+        self.activation = activation
+        self.num_classes = num_classes
+        self.max_len_output = max_len_output
+
+        # Encoder
+        resnet18 = torchvision.models.resnet18(weights=None)
+
+        # Drop layer4, avgpool and the fc head from resnet18; keep the stem
+        # through layer3, which outputs 256 channels
+        self.backbone = nn.Sequential(*list(resnet18.children())[:-3])
+
+        self.conv1x1 = nn.Conv2d(256, d_model, kernel_size=1, stride=1, padding=0)
+        self.encoder_pos_enc = PositionalEncoding2D(d_model,
+                                                    max_h=1000,
+                                                    max_w=1000,
+                                                    dropout=pos_enc_dropout)  # no images are larger than 1000x1000
+
+        # Decoder
+        self.embedding = nn.Embedding(num_classes, d_model)
+        self.decoder_pos_enc = PositionalEncoding1D(d_model,
+                                                    max_len=max_len_output,
+                                                    dropout=pos_enc_dropout)
+        _transformer_decoder_layer = nn.TransformerDecoderLayer(d_model=d_model,
+                                                                nhead=num_heads,
+                                                                dim_feedforward=dim_feedforward,
+                                                                dropout=dropout,
+                                                                activation=activation)
+        self.transformer_decoder = nn.TransformerDecoder(decoder_layer=_transformer_decoder_layer,
+                                                         num_layers=num_decoder_layers)
+
+        self.linear = nn.Linear(d_model, num_classes)
+        # self.softmax = nn.Softmax(dim=2)
+
+        # causal target mask, built once for the maximum output length
+        self.tgt_mask = self.get_tgt_mask(max_len_output)
+        # modules are constructed in training mode, so initialize unconditionally
+        self._init_weights()
+
+    def _init_weights(self):
+        self.embedding.weight.data.uniform_(-0.1, 0.1)
+        self.linear.weight.data.uniform_(-0.1, 0.1)
+        self.linear.bias.data.zero_()
+
+        nn.init.kaiming_normal_(self.conv1x1.weight, mode='fan_out', nonlinearity='relu')
+        if self.conv1x1.bias is not None:
+            _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.conv1x1.weight)
+            bound = 1 / torch.sqrt(torch.tensor(fan_out))
+            nn.init.normal_(self.conv1x1.bias, -bound, bound)
+
+    def get_tgt_mask(self, target_size):
+        tgt_mask = torch.triu(torch.ones(target_size, target_size), diagonal=1)
+        tgt_mask = tgt_mask.masked_fill(tgt_mask == 1, float('-inf'))
+        return tgt_mask
+
+    def encode(self, x):
+        # Repeat the input x if it has only 1 channel
+        if x.shape[1] == 1:
+            x = x.repeat(1, 3, 1, 1)
+        x = self.backbone(x)       # (B, 256, H/16, W/16)
+        x = self.conv1x1(x)        # (B, d_model, H/16, W/16)
+        x = self.encoder_pos_enc(x)
+        x = x.flatten(2)           # (B, d_model, S)
+        x = x.permute(2, 0, 1)     # (S, B, d_model), sequence-first for the decoder
+        return x
+
+    def decode(self, tgt, x):
+        tgt = tgt.permute(1, 0)    # (T, B)
+        tgt = self.embedding(tgt)  # (T, B, d_model)
+        tgt = self.decoder_pos_enc(tgt)
+        # the mask is created on CPU in __init__; move it with the inputs
+        tgt_mask = self.tgt_mask[:tgt.size(0), :tgt.size(0)].to(tgt.device)
+        output = self.transformer_decoder(tgt, x, tgt_mask)
+        output = self.linear(output)  # (T, B, num_classes)
+        return output
+
+    def forward(self, x, tgt):
+        # Encoder
+        x = self.encode(x)
+
+        # Decoder
+        output = self.decode(tgt, x)
+        return output.permute(1, 2, 0)  # (B, num_classes, T)
+
+    def predict(self, x):
+        b = x.size(0)
+        x = self.encode(x)
+
+        # greedy decoding, starting from <bos> (index 1 in the tokenizer's vocab)
+        tgt = torch.zeros((b, self.max_len_output), dtype=torch.long).to(x.device)
+        tgt[:, 0] = 1
+        for t in range(1, self.max_len_output):
+            output = self.decode(tgt[:, :t], x)  # (t, B, num_classes)
+            output = output.argmax(dim=-1)       # (t, B)
+            tgt[:, t] = output[-1]               # prediction at the latest position
+
+        return tgt
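
An end-to-end shape sketch for the model (illustrative only: random inputs, untrained weights). forward is the teacher-forcing path used in training; predict decodes greedily from <bos> for max_len_output steps:

    import torch
    from math2latex.model import ResNetTransformer

    model = ResNetTransformer()
    model.eval()

    images = torch.rand(2, 1, 64, 256)           # grayscale crops; encode() repeats them to 3 channels
    tgt = torch.zeros(2, 150, dtype=torch.long)  # token ids for teacher forcing

    with torch.no_grad():
        logits = model(images, tgt)     # (B, num_classes, T) after the final permute
        print(logits.shape)             # torch.Size([2, 462, 150])
        tokens = model.predict(images)  # greedy decoding, one position at a time
        print(tokens.shape)             # torch.Size([2, 150])
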
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+# tested with python 3.11.10
+
+# pip reads index options only on their own line, not appended to a requirement,
+# so the CPU wheel index for the torch packages is declared here
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.3.1
+torchvision==0.18.1
+torchaudio==2.3.1
+cuda-python==12.1.0
+lightning==2.3.3
+torchmetrics==1.4.2
+tensorboard==2.18.0
+matplotlib==3.9.2
+nltk==3.9.1
+torchtext==0.18.0
+gradio==4.44.1