import gradio as gr
import torch

from model import Transformer, TransformerConfig
from tokenizer import CustomTokenizer

# Load the source (English) and target (French) tokenizers.
path_to_src_tokenizer = "trained_tokenizers/vocab_en.json"
path_to_tgt_tokenizer = "trained_tokenizers/vocab_fr.json"
src_tokenizer = CustomTokenizer(path_to_vocab=path_to_src_tokenizer)
tgt_tokenizer = CustomTokenizer(path_to_vocab=path_to_tgt_tokenizer)
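
# Minimal sanity check of the tokenizer round-trip (assumes the vocab files
# above loaded correctly); uncomment to inspect the ids and the decoded text:
# ids = src_tokenizer.encode("Hello world")
# print(ids, src_tokenizer.decode(input=ids, skip_special_tokens=True))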

# Load the Transformer and restore trained weights from the checkpoint.
config = TransformerConfig(max_seq_length=512)
model = Transformer(config=config)
path_to_checkpoints = "checkpoints/model.safetensors"
model.load_weights_from_checkpoints(path_to_checkpoints=path_to_checkpoints)
model.eval()  # switch to evaluation mode (disables dropout and similar layers)


def translate(input_text, skip_special_tokens=True):
    # Encode the English input and add a batch dimension.
    src_ids = torch.tensor(src_tokenizer.encode(input_text)).unsqueeze(0)
    # Autoregressively generate target ids, then decode them to French text.
    with torch.no_grad():  # no gradient tracking needed at inference time
        output_ids = model.inference(src_ids=src_ids, tgt_start_id=tgt_tokenizer.bos_token_id,
                                     tgt_end_id=tgt_tokenizer.eos_token_id, max_seq_length=512)
    output_tokens = tgt_tokenizer.decode(input=output_ids, skip_special_tokens=skip_special_tokens)
    return output_tokens
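
# Quick end-to-end check before launching the UI (assumes the checkpoint above
# is present); uncomment to try a translation from the command line:
# print(translate("Hello, how are you?"))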

# Build a minimal Gradio UI: an input box, an output box, and a button
# wired to the translate() function above.
with gr.Blocks() as demo:
    gr.Markdown("## English → French Translation")
    with gr.Row():
        texte_input = gr.Textbox(label="English text", lines=4)
        texte_output = gr.Textbox(label="Translated text (French)", lines=4, interactive=False)
    bouton = gr.Button("Translate")
    bouton.click(translate, inputs=texte_input, outputs=texte_output)

demo.launch()