# Spaces: Runtime error
import os
import threading

import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
# Hugging Face Hub checkpoints used by the app.
DIFFUSION_MODEL_ID = "runwayml/stable-diffusion-v1-5"
TRANSLATION_MODEL_ID = (
    "Narrativa/mbart-large-50-finetuned-opus-pt-en-translation"
)

# Runtime configuration taken from the environment.
# Torch device the diffusion pipeline is moved to (defaults to "cpu").
DEVICE_NAME = os.environ.get("DEVICE_NAME", "cpu")
# Hub access token passed to every from_pretrained call; None when unset.
HUGGING_FACE_TOKEN = os.environ.get("HUGGING_FACE_TOKEN")
def load_translation_models(translation_model_id):
    """Load the MBart-50 tokenizer and seq2seq model for ``translation_model_id``.

    Both objects are fetched with ``from_pretrained`` using the module-level
    ``HUGGING_FACE_TOKEN`` (``None`` means no token is sent). The tokenizer's
    source language is set to Portuguese (``pt_XX``).

    Returns:
        A ``(tokenizer, text_model)`` tuple.
    """
    hub_options = {"use_auth_token": HUGGING_FACE_TOKEN}

    tokenizer = MBart50TokenizerFast.from_pretrained(
        translation_model_id, **hub_options
    )
    # Prompts typed into this app are Portuguese.
    tokenizer.src_lang = "pt_XX"

    text_model = MBartForConditionalGeneration.from_pretrained(
        translation_model_id, **hub_options
    )
    return tokenizer, text_model
def pipeline_generate(diffusion_model_id):
    """Build a Stable Diffusion pipeline for ``diffusion_model_id`` on ``DEVICE_NAME``.

    The checkpoint is loaded with ``HUGGING_FACE_TOKEN``, moved to the
    configured torch device and switched to sliced attention.

    Returns:
        The ready-to-call ``StableDiffusionPipeline``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        diffusion_model_id, use_auth_token=HUGGING_FACE_TOKEN
    ).to(DEVICE_NAME)
    # Sliced attention lowers peak memory at some speed cost; recommended
    # when the machine has less than 64 GB of RAM.
    pipe.enable_attention_slicing()
    return pipe
def translate(prompt, tokenizer, text_model):
    """Translate ``prompt`` with the given tokenizer/seq2seq model pair.

    In this app the pair comes from ``load_translation_models`` (Portuguese to
    English). Decoding uses beam search (8 beams, early stopping) and at most
    100 new tokens; special tokens are stripped from the output.

    Returns:
        The decoded translation of ``prompt`` as a string.
    """
    encoded = tokenizer([prompt], return_tensors="pt")
    generation_options = {
        "max_new_tokens": 100,
        "num_beams": 8,
        "early_stopping": True,
    }
    generated = text_model.generate(**encoded, **generation_options)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    # A single prompt was encoded, so only the first sequence matters.
    return decoded[0]
def generate_image(pipe, prompt):
    """Run ``pipe`` on ``prompt`` and return the first generated image.

    The diffusers guide for Apple's ``mps`` backend recommends priming the
    pipeline with a one-off single-step "warm-up" pass before the first real
    generation (a PyTorch 1.13 workaround). That pass is performed at most
    once per pipeline object, and only on ``mps``. On other devices it would
    just add an extra pipeline run to every request.

    Args:
        pipe: A diffusers text-to-image pipeline (callable, exposing ``device``).
        prompt: Text prompt passed to the pipeline.

    Returns:
        ``pipe(prompt).images[0]``.
    """
    needs_warmup = pipe.device.type == "mps" and not getattr(
        pipe, "_mps_warmed_up", False
    )
    if needs_warmup:
        pipe(prompt, num_inference_steps=1)
        # Remember on the object itself so a reused (cached) pipeline is only
        # warmed up once.
        pipe._mps_warmed_up = True
    return pipe(prompt).images[0]
def _cache_resource(func):
    """Cache ``func``'s results process-wide with Streamlit, when supported.

    Streamlit re-executes this script on every interaction, so ordinary
    module-level caches do not survive between runs. ``st.cache_resource``
    (Streamlit >= 1.18), or its predecessor ``st.experimental_singleton``, keeps
    one shared result per argument set. If neither exists, ``func`` is
    returned unwrapped and is simply called every time.
    """
    cache = getattr(st, "cache_resource", None) or getattr(
        st, "experimental_singleton", None
    )
    return func if cache is None else cache(func)


@_cache_resource
def _get_translation_models(translation_model_id):
    """Return the (tokenizer, model) pair, loading it at most once per process."""
    return load_translation_models(translation_model_id)


@_cache_resource
def _get_pipeline(diffusion_model_id):
    """Return the Stable Diffusion pipeline, loading it at most once per process."""
    return pipeline_generate(diffusion_model_id)


@_cache_resource
def _get_inference_lock():
    """Return one lock shared by every session that uses the cached models."""
    return threading.Lock()


def process_prompt(prompt):
    """Translate ``prompt`` to English and generate an image from it.

    The translation models and the diffusion pipeline are loaded on first use
    and then reused, instead of being reloaded on every click.

    Args:
        prompt: Image description typed by the user (Portuguese in this app).

    Returns:
        The image returned by ``generate_image`` for the translated prompt.
    """
    # Cached resources are shared by all Streamlit sessions, each running in
    # its own thread. Nothing here is documented as thread-safe, so inference
    # is serialized rather than letting concurrent runs interleave.
    with _get_inference_lock():
        tokenizer, text_model = _get_translation_models(TRANSLATION_MODEL_ID)
        en_prompt = translate(prompt, tokenizer, text_model)
        pipe = _get_pipeline(DIFFUSION_MODEL_ID)
        return generate_image(pipe, en_prompt)
# --- Streamlit page -------------------------------------------------------
st.write("# Crie imagens com Stable Diffusion")
prompt_input = st.text_input("Escreva uma descrição da imagem")

# The "process" button lives in a placeholder so it can be swapped for a
# disabled copy while the (slow) generation runs.
placeholder = st.empty()
btn = placeholder.button('Processar imagem', disabled=False, key=1)
reload = st.button('Reiniciar', disabled=False)

if btn:
    if not prompt_input.strip():
        # Nothing to translate or draw; don't start a long inference run.
        st.warning("Escreva uma descrição antes de processar a imagem.")
    else:
        placeholder.button('Processar imagem', disabled=True, key=2)
        image = process_prompt(prompt_input)
        st.image(image)
        # Hide the process button once the image is shown; "Reiniciar"
        # brings the page back to its initial state.
        placeholder.empty()

if reload:
    # ``st.experimental_rerun`` was superseded by ``st.rerun`` (Streamlit
    # 1.27) and later removed; use whichever this Streamlit version offers.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()