import torch
from tokenizer import CustomTokenizer
from model import Transformer, TransformerConfig
import gradio as gr

# load the source (English) and target (French) tokenizers
path_to_src_tokenizer = "trained_tokenizers/vocab_en.json"
path_to_tgt_tokenizer = "trained_tokenizers/vocab_fr.json"
src_tokenizer = CustomTokenizer(path_to_vocab=path_to_src_tokenizer)
tgt_tokenizer = CustomTokenizer(path_to_vocab=path_to_tgt_tokenizer)
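
# The special-token ids used below (bos_token_id / eos_token_id) come from the
# trained vocab; printing them is a quick sanity check that the tokenizer loaded:
# print(tgt_tokenizer.bos_token_id, tgt_tokenizer.eos_token_id)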

# load the model and restore trained weights from the checkpoint
config = TransformerConfig(max_seq_length=512)
model = Transformer(config=config)
path_to_checkpoints = "checkpoints/model.safetensors"
model.load_weights_from_checkpoints(path_to_checkpoints=path_to_checkpoints)
model.eval()  # inference mode: disables dropout

def translate(input_text, skip_special_tokens=True):
    # encode the source sentence and add a batch dimension
    src_ids = torch.tensor(src_tokenizer.encode(input_text)).unsqueeze(0)
    # autoregressive decoding from BOS until EOS (or max_seq_length)
    with torch.no_grad():
        output_ids = model.inference(src_ids=src_ids, tgt_start_id=tgt_tokenizer.bos_token_id,
                                     tgt_end_id=tgt_tokenizer.eos_token_id, max_seq_length=512)
    output_text = tgt_tokenizer.decode(input=output_ids, skip_special_tokens=skip_special_tokens)
    return output_text
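
# Quick sanity check of translate(); the example input and French output here are
# illustrative assumptions, not verified outputs of this checkpoint:
# print(translate("Hello, how are you?"))  # e.g. "Bonjour, comment allez-vous ?"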

# Gradio UI: English input box, read-only French output box, and a translate button
with gr.Blocks() as demo:
    gr.Markdown("## English → French Translation")
    with gr.Row():
        texte_input = gr.Textbox(label="English text", lines=4)
        texte_output = gr.Textbox(label="Translated text (French)", lines=4, interactive=False)
    bouton = gr.Button("Translate")
    bouton.click(translate, inputs=texte_input, outputs=texte_output)

demo.launch()
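
# Run with `python app.py`; Gradio serves the demo at http://127.0.0.1:7860 by default.
# demo.launch(share=True) would additionally expose a temporary public link.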