# ponix-generator / app.py
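"""PONIX Generator: a Gradio Space that generates images of PONIX, the POSTECH
mascot, using FLUX.1-dev with a custom LoRA and learned token embeddings,
streaming intermediate previews while the image denoises."""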
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
from llm_wrapper import run_gemini
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import subprocess
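# Clear any stale ZeroGPU offload cache left over from previous sessions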
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
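# Two VAEs: the tiny TAEF1 decoder produces fast intermediate previews, while
# the full FLUX VAE ("good_vae") decodes the final high-quality image.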
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
# Load the PONIX LoRA and its textual-inversion embeddings (<s0>, <s1>, <s2>)
pipe.load_lora_weights('cwhuh/ponix-generator-v0.2.0', weight_name='pytorch_lora_weights.safetensors')
embedding_path = hf_hub_download(repo_id='cwhuh/ponix-generator-v0.2.0', filename='ponix-generator-v0.2.0_emb.safetensors', repo_type="model")
state_dict = load_file(embedding_path)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>", "<s2>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
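# Release GPU memory cached by the allocator during model loading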
torch.cuda.empty_cache()
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
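# Bind the live-preview call from live_preview_helpers to this pipeline
# instance so it can yield a preview image after each denoising step.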
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
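# Reserve a ZeroGPU slot for up to 50 seconds per generation request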
@spaces.GPU(duration=50)
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, use_prompt_refinement=True, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    print(f"User Prompt: {prompt}")
    # Optionally refine the raw user prompt with Gemini before generation
    if use_prompt_refinement:
        refined_prompt = run_gemini(
            target_prompt=prompt,
            prompt_in_path="prompt.json",
        )
    else:
        refined_prompt = prompt
    print(f"Refined Prompt: {refined_prompt}")

    # Stream intermediate previews; the final image is decoded with good_vae
    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=refined_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    ):
        yield img, seed
examples = [
    "PONIX from the Department of Mechanical Engineering (rocket)",
    "PONIX playing the violin",
    "PONIX researching physics",
    "PONIX from the Department of Computer Science and Engineering",
]
css="""
#col-container {
margin: 0 auto;
max-width: 580px;
}
.footer {
text-align: center;
margin-top: 20px;
font-size: 0.8em;
color: #666;
}
/* URL 링크 μŠ€νƒ€μΌ */
a {
color: #666 !important;
text-decoration: underline;
}
a:hover {
color: rgb(200, 1, 80) !important;
}
/* κΈ°λ³Έ ν…Œλ§ˆ 색상을 ν¬μŠ€ν… λ ˆλ“œλ‘œ λ³€κ²½ */
:root {
--primary-50: rgb(255, 240, 244);
--primary-100: rgb(255, 200, 220);
--primary-200: rgb(255, 150, 180);
--primary-300: rgb(255, 100, 140);
--primary-400: rgb(255, 50, 100);
--primary-500: rgb(200, 1, 80);
--primary-600: rgb(180, 1, 70);
--primary-700: rgb(160, 1, 60);
--primary-800: rgb(140, 1, 50);
--primary-900: rgb(120, 1, 40);
--primary-950: rgb(100, 1, 30);
}
"""
with gr.Blocks(css=css, theme="soft") as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# πŸ” [POSTECH] PONIX Generator
**[[Github](https://github.com/posplexity/ponix-generator)]** **[[Feedback](https://docs.google.com/forms/d/1BccziUtYGF0ToTjZ8PmxZExJJgzpErCuWmrm6ui0COc/edit)]**
[based on FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)
""")
        with gr.Group():
            gr.Markdown("""
### πŸ” Usage Guide
- Write a short description, in Korean, of the image you want to generate.
- The image is generated progressively from noise (takes about 40-50 seconds).
- For inquiries, please email: [email protected]
""")
        with gr.Group():
            prompt = gr.Text(
                label="Prompt",
                max_lines=1,
                placeholder="Describe the PONIX image you want, in Korean",
                container=True,
            )
            run_button = gr.Button("πŸš€ Generate", variant="primary")

        result = gr.Image(label="Generated Image")
        with gr.Accordion("πŸ› οΈ Advanced Settings", open=False):
            with gr.Group():
                use_prompt_refinement = gr.Checkbox(
                    label="Automatic prompt refinement",
                    value=True,
                    info="The AI automatically refines your prompt.",
                )
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=1,
                        maximum=15,
                        step=0.1,
                        value=3.5,
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=28,
                    )
gr.Markdown("### μ˜ˆμ‹œ ν”„λ‘¬ν”„νŠΈ")
gr.Examples(
examples = examples,
fn = infer,
inputs = [prompt],
outputs = [result, seed],
cache_examples="lazy"
)
gr.HTML("""
<div class="footer">
PONIX Generator by ν—ˆμ±„μ› | POSTECH
</div>
""")
gr.on(
triggers=[run_button.click, prompt.submit],
fn = infer,
inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
outputs = [result, seed]
)
demo.launch()