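"""Streamlit app that generates images with Stable Diffusion from prompts
written in Portuguese: each prompt is first translated to English with an
mBART-50 model, then passed to the Stable Diffusion pipeline."""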
import os
import torch
import streamlit as st
from diffusers import StableDiffusionPipeline
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

# Model identifiers and runtime configuration. DEVICE_NAME should name a
# torch device ("cpu", "cuda" or "mps"); HUGGING_FACE_TOKEN authenticates
# Hugging Face Hub downloads (only needed for gated or private models).
DIFFUSION_MODEL_ID = "runwayml/stable-diffusion-v1-5"
TRANSLATION_MODEL_ID = "Narrativa/mbart-large-50-finetuned-opus-pt-en-translation"  # noqa
DEVICE_NAME = os.getenv("DEVICE_NAME", "cpu")
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")


def load_translation_models(translation_model_id):
    tokenizer = MBart50TokenizerFast.from_pretrained(
        translation_model_id,
        use_auth_token=HUGGING_FACE_TOKEN
    )
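    # mBART-50 is multilingual, so the source language must be set
    # explicitly; "pt_XX" is its language code for Portuguese.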
    tokenizer.src_lang = 'pt_XX'
    text_model = MBartForConditionalGeneration.from_pretrained(
        translation_model_id,
        use_auth_token=HUGGING_FACE_TOKEN
    )

    return tokenizer, text_model


def pipeline_generate(diffusion_model_id):
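    # Load the pipeline with half-precision (fp16) weights to reduce
    # memory use. Note that float16 inference generally assumes a GPU
    # device ("cuda" or "mps"); for CPU-only runs, dropping torch_dtype
    # and revision falls back to full-precision float32.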
    pipe = StableDiffusionPipeline.from_pretrained(
        diffusion_model_id, torch_dtype=torch.float16, revision="fp16",
        use_auth_token=HUGGING_FACE_TOKEN
    )
    pipe = pipe.to(DEVICE_NAME)

    # Attention slicing computes attention in chunks, trading a little
    # speed for a lower peak memory footprint; recommended if your
    # computer has < 64 GB of RAM
    pipe.enable_attention_slicing()

    return pipe


def translate(prompt, tokenizer, text_model):
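    # Tokenize the Portuguese prompt, translate it with beam search
    # (8 beams, early stopping), and decode the generated English text.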
    pt_tokens = tokenizer([prompt], return_tensors="pt")
    en_tokens = text_model.generate(
        **pt_tokens, max_new_tokens=100,
        num_beams=8, early_stopping=True
    )
    en_prompt = tokenizer.batch_decode(en_tokens, skip_special_tokens=True)

    return en_prompt[0]


def generate_image(pipe, prompt):
    # Throwaway one-step "warmup" pass: recommended when running on
    # Apple's "mps" device, where the first inference after loading the
    # pipeline can produce incorrect results.
    _ = pipe(prompt, num_inference_steps=1)

    return pipe(prompt).images[0]


def process_prompt(prompt):
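    # Both models are reloaded on every prompt; in a long-running app
    # they could be cached (e.g. with a Streamlit caching decorator) to
    # avoid repeating the download/initialization cost.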
    tokenizer, text_model = load_translation_models(TRANSLATION_MODEL_ID)
    prompt = translate(prompt, tokenizer, text_model)
    pipe = pipeline_generate(DIFFUSION_MODEL_ID)
    image = generate_image(pipe, prompt)
    return image


st.write("# Crie imagens com Stable Diffusion")
prompt_input = st.text_input("Escreva uma descrição da imagem")

# st.empty() creates a placeholder whose contents can be replaced at
# runtime; it is used below to swap the button for a disabled copy while
# an image is being generated.
placeholder = st.empty()
btn = placeholder.button('Processar imagem', disabled=False, key=1)  # "Process image"
reload = st.button('Reiniciar', disabled=False)  # "Restart"

if btn:
    # Swap in a disabled copy of the button (a distinct widget key is
    # required) so a second generation cannot start while this one runs.
    placeholder.button('Processar imagem', disabled=True, key=2)
    image = process_prompt(prompt_input)
    st.image(image)
    # Clear the placeholder once the image is shown; the "Reiniciar"
    # button resets the app for a new prompt.
    placeholder.empty()

if reload:
    # Restarts the script from the top; this API was renamed to
    # st.rerun() in newer Streamlit releases.
    st.experimental_rerun()