csaguiar committed on
Commit
10ac2fa
1 Parent(s): b2687dc

updating classes used for translation

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -1,17 +1,20 @@
 
1
  import streamlit as st
2
  from diffusers import StableDiffusionPipeline
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
  DIFFUSION_MODEL_ID = "runwayml/stable-diffusion-v1-5"
6
  TRANSLATION_MODEL_ID = "Narrativa/mbart-large-50-finetuned-opus-pt-en-translation" # noqa
 
7
 
8
 
9
  def load_translation_models(translation_model_id):
10
- tokenizer = AutoTokenizer.from_pretrained(
11
  translation_model_id,
12
  use_auth_token=True
13
  )
14
- text_model = AutoModelForSeq2SeqLM.from_pretrained(
 
15
  translation_model_id,
16
  use_auth_token=True
17
  )
@@ -24,7 +27,7 @@ def pipeline_generate(diffusion_model_id):
24
  diffusion_model_id,
25
  use_auth_token=True
26
  )
27
- pipe = pipe.to("mps")
28
 
29
  # Recommended if your computer has < 64 GB of RAM
30
  pipe.enable_attention_slicing()
@@ -39,7 +42,6 @@ def translate(prompt, tokenizer, text_model):
39
  num_beams=8, early_stopping=True
40
  )
41
  en_prompt = tokenizer.batch_decode(en_tokens, skip_special_tokens=True)
42
- print(f"translation: [PT] {prompt} -> [EN] {en_prompt[0]}")
43
 
44
  return en_prompt[0]
45
 
 
1
+ import os
2
  import streamlit as st
3
  from diffusers import StableDiffusionPipeline
4
+ from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
5
 
6
  DIFFUSION_MODEL_ID = "runwayml/stable-diffusion-v1-5"
7
  TRANSLATION_MODEL_ID = "Narrativa/mbart-large-50-finetuned-opus-pt-en-translation" # noqa
8
+ DEVICE_NAME = os.getenv("DEVICE_NAME", "cuda")
9
 
10
 
11
  def load_translation_models(translation_model_id):
12
+ tokenizer = MBart50TokenizerFast.from_pretrained(
13
  translation_model_id,
14
  use_auth_token=True
15
  )
16
+ tokenizer.src_lang = 'pt_XX'
17
+ text_model = MBartForConditionalGeneration.from_pretrained(
18
  translation_model_id,
19
  use_auth_token=True
20
  )
 
27
  diffusion_model_id,
28
  use_auth_token=True
29
  )
30
+ pipe = pipe.to(DEVICE_NAME)
31
 
32
  # Recommended if your computer has < 64 GB of RAM
33
  pipe.enable_attention_slicing()
 
42
  num_beams=8, early_stopping=True
43
  )
44
  en_prompt = tokenizer.batch_decode(en_tokens, skip_special_tokens=True)
 
45
 
46
  return en_prompt[0]
47