rvian commited on
Commit
d9956d6
·
1 Parent(s): 4e3e82b
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -22,11 +22,11 @@ def carregar_modelo_e_tokenizador_mbart(modelo):
22
 
23
  # TODO:batch?
24
  def traduzir_en_pt(text):
25
- inputs = tokenizer(text, return_tensors='pt')
26
  input_ids = inputs.input_ids
27
  attention_mask = inputs.attention_mask
28
- output = model.generate(input_ids, attention_mask=attention_mask, forced_bos_token_id=tokenizer.lang_code_to_id['pt_XX'])
29
- return tokenizer.decode(output[0], skip_special_tokens=True)
30
 
31
  ## streamlit ##
32
  def carregar_dataset():
@@ -77,9 +77,9 @@ dataset = carregar_dataset()
77
  if dataset is not None:
78
  mostrar_dataset()
79
  if st.button("Carregar modelo"):
80
- model, tokenizer = carregar_modelo()
81
 
82
 
83
- if st.button("Traduzir dataset") and model is not None:
84
  traduzir_dataset(dataset)
85
  resultado()
 
22
 
23
  # TODO:batch?
24
  def traduzir_en_pt(text):
25
+ inputs = tokenizador(text, return_tensors='pt')
26
  input_ids = inputs.input_ids
27
  attention_mask = inputs.attention_mask
28
+ output = modelo.generate(input_ids, attention_mask=attention_mask, forced_bos_token_id=tokenizador.lang_code_to_id['pt_XX'])
29
+ return tokenizador.decode(output[0], skip_special_tokens=True)
30
 
31
  ## streamlit ##
32
  def carregar_dataset():
 
77
  if dataset is not None:
78
  mostrar_dataset()
79
  if st.button("Carregar modelo"):
80
+ modelo, tokenizador = carregar_modelo()
81
 
82
 
83
+ if st.button("Traduzir dataset") and modelo is not None:
84
  traduzir_dataset(dataset)
85
  resultado()