# flxcontrol / app.py
# Author: fantos — last commit 1e6f880 ("Update app.py", verified)
import sys
sys.path.append('./')
import gradio as gr
import spaces
import os
import sys
import subprocess
import numpy as np
from PIL import Image
import cv2
import torch
import random
from transformers import pipeline
# Install the bundled controlnet_aux package (editable) before it is imported.
# pip is run through the current interpreter so it targets this environment,
# and the argument list avoids a shell. Like the original os.system call, the
# exit status is ignored: a failed install is not fatal at this point.
subprocess.run([sys.executable, "-m", "pip", "install", "-e", "./controlnet_aux"], check=False)
from controlnet_aux import OpenposeDetector, CannyDetector
from depth_anything_v2.dpt import DepthAnythingV2
from huggingface_hub import hf_hub_download
from huggingface_hub import login
# Token for gated Hugging Face repositories, read from the environment.
hf_token = os.environ.get("HF_TOKEN_GATED")
# Only log in when a token is configured: login(token=None) falls back to an
# interactive prompt, which nobody can answer in a headless Space.
if hf_token:
    login(token=hf_token)
# Upper bound for user-selectable seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
# Translator setup: Korean -> English model used to translate Korean prompts.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
def translate_to_english(text):
    """Return *text* in English, translating it only when it contains Hangul.

    Text counts as Korean when at least one character lies in the Hangul
    syllables block (U+AC00..U+D7A3). Such text goes through the module-level
    ``translator`` pipeline (max_length=512), and the first result's
    ``translation_text`` is returned. All other text is returned unchanged.
    """
    for ch in text:
        if '\uAC00' <= ch <= '\uD7A3':
            break
    else:
        # No Hangul syllable found: nothing to translate.
        return text
    result = translator(text, max_length=512)
    return result[0]['translation_text']
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for the next generation.

    Returns *seed* unchanged when *randomize_seed* is falsy. Otherwise returns
    a fresh integer drawn uniformly from [0, MAX_SEED].
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)
# Depth-Anything-V2 depth estimator used by extract_depth().
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# DepthAnythingV2 constructor arguments for each encoder size.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
# The checkpoint downloaded below is the ViT-L one; keep `encoder` in sync with it.
encoder = 'vitl'
model = DepthAnythingV2(**model_configs[encoder])
filepath = hf_hub_download(repo_id="depth-anything/Depth-Anything-V2-Large", filename="depth_anything_v2_vitl.pth", repo_type="model")
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered downloaded checkpoint cannot run arbitrary code while loading.
state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()
import torch
from diffusers.utils import load_image
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
# FLUX.1-dev text-to-image pipeline driven by the Union-Pro ControlNet.
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
# Wrapped in a multi-ControlNet container, so control arguments are passed as lists.
controlnet = FluxMultiControlNetModel([controlnet])
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Union-Pro control-mode ids. They are keyed by the English mode names that the
# UI "Mode" radio emits and that infer() compares against. Previously the keys
# were Korean only, so every lookup with a UI value raised KeyError.
mode_mapping = {"Canny": 0, "Tile": 1, "Depth": 2, "Blur": 3, "OpenPose": 4, "Grayscale": 5, "LowQuality": 6}
# Suggested conditioning strength per control mode.
strength_mapping = {"Canny": 0.65, "Tile": 0.45, "Depth": 0.55, "Blur": 0.45, "OpenPose": 0.55, "Grayscale": 0.45, "LowQuality": 0.4}
# The former Korean mode names stay available as aliases of the English ones.
_KOREAN_MODE_ALIASES = {"캐니": "Canny", "타일": "Tile", "깊이": "Depth", "블러": "Blur", "오픈포즈": "OpenPose", "그레이스케일": "Grayscale", "저품질": "LowQuality"}
mode_mapping.update({ko: mode_mapping[en] for ko, en in _KOREAN_MODE_ALIASES.items()})
strength_mapping.update({ko: strength_mapping[en] for ko, en in _KOREAN_MODE_ALIASES.items()})

# Condition-image annotators used by extract_canny() / extract_openpose().
canny = CannyDetector()
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")

torch.backends.cuda.matmul.allow_tf32 = True
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
# NOTE(review): enable_model_cpu_offload() handles device placement itself, which
# makes the pipe.to("cuda") above look redundant. Confirm the intended placement
# (e.g. under ZeroGPU) before removing either call.
pipe.enable_model_cpu_offload() # for saving memory
def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to a NumPy array in OpenCV's BGR channel order.

    The input is presumably an RGB image. cv2.cvtColor with COLOR_RGB2BGR
    rejects single-channel arrays, so grayscale/palette images must be
    converted to RGB first.
    """
    # Annotated with Image.Image (the class); `Image` alone is the PIL module.
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    """Convert an OpenCV-style BGR array back to a PIL image in RGB order.

    The array must have 3 (or 4) channels: cv2.cvtColor with COLOR_BGR2RGB
    does not accept single-channel input.
    """
    # Annotated with Image.Image (the class); `Image` alone is the PIL module.
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
def extract_depth(image):
    """Estimate a depth map for *image* and return it as a 3-channel PIL image.

    The raw depth is min-max normalised to 0..255, cast to uint8 and replicated
    to RGB. A constant depth map (zero range) yields an all-black image instead
    of dividing by zero.
    """
    image = np.asarray(image)
    # Channels are reversed (RGB -> BGR) before inference; presumably
    # model.infer_image expects OpenCV-style BGR input — confirm upstream.
    depth = model.infer_image(image[:, :, ::-1])
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth = (depth - depth_min) / depth_range * 255.0
    else:
        # Normalising a flat map would divide by zero and produce NaNs.
        depth = np.zeros_like(depth)
    depth = depth.astype(np.uint8)
    gray_depth = Image.fromarray(depth).convert('RGB')
    return gray_depth
def extract_openpose(img):
    """Return the OpenPose annotation of *img*, including hand and face keypoints."""
    return open_pose(img, hand_and_face=True)
def extract_canny(image):
    """Return the Canny edge map of *image* from the module-level CannyDetector."""
    return canny(image)
def apply_gaussian_blur(image, kernel_size=(21, 21)):
    """Blur a PIL image with an OpenCV Gaussian kernel and return a PIL image.

    *kernel_size* is passed straight to cv2.GaussianBlur. Sigma is given as 0,
    so OpenCV derives it from the kernel size.
    """
    bgr = convert_from_image_to_cv2(image)
    blurred = cv2.GaussianBlur(bgr, kernel_size, 0)
    return convert_from_cv2_to_image(blurred)
def convert_to_grayscale(image):
    """Return a grayscale version of a PIL image, still as a 3-channel RGB image.

    The luminance is computed with OpenCV and then replicated to three identical
    channels, so the result has the same channel layout as the other control
    images.
    """
    bgr = convert_from_image_to_cv2(image)
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    # convert_from_cv2_to_image applies COLOR_BGR2RGB, which rejects a
    # single-channel array. Expand gray back to 3 channels first (previously
    # this path raised a cv2.error in the default "Grayscale" mode).
    gray_bgr = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
    return convert_from_cv2_to_image(gray_bgr)
def add_gaussian_noise(image, mean=0, sigma=10):
    """Return a copy of a PIL image with additive Gaussian pixel noise.

    Noise is drawn from NumPy's global RNG with the given *mean* and *sigma*,
    one sample per array element. The noisy result is clipped to 0..255 and
    cast back to uint8.
    """
    bgr = convert_from_image_to_cv2(image)
    noise = np.random.normal(mean, sigma, bgr.shape)
    noisy = bgr.astype(np.float32) + noise
    noisy_u8 = np.clip(noisy, 0, 255).astype(np.uint8)
    return convert_from_cv2_to_image(noisy_u8)
def tile(input_image, resolution=768):
    """Rescale an image so its shorter side is about *resolution* pixels.

    Both sides are scaled by the same factor and each is rounded to the
    nearest multiple of 64. Upscaling uses Lanczos interpolation; downscaling
    (or a scale of exactly 1) uses area interpolation. Returns a PIL image.
    """
    bgr = convert_from_image_to_cv2(input_image)
    rows, cols, _channels = bgr.shape
    scale = float(resolution) / min(float(rows), float(cols))

    def _snap64(length):
        # Scale one side length and round it to the nearest multiple of 64.
        return int(np.round(length * scale / 64.0)) * 64

    target = (_snap64(cols), _snap64(rows))  # cv2.resize takes (width, height)
    interp = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    return convert_from_cv2_to_image(cv2.resize(bgr, target, interpolation=interp))
def resize_img(input_image, max_side=768, min_side=512, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
    """Resize a PIL image for the pipeline.

    Args:
        input_image: PIL image to resize.
        max_side: the longer side is scaled to this length before both sides
            are floored to a multiple of ``base_pixel_number``.
        min_side: intermediate short-side length applied first; it only
            influences rounding, because the image is then rescaled to
            ``max_side``.
        size: explicit ``(width, height)``. When given, the aspect-ratio logic
            is skipped and the image is resized straight to this size.
        pad_to_max_side: if True, centre the result on a white
            ``max_side`` x ``max_side`` canvas. The canvas has 3 channels, so
            this assumes an RGB image.
        mode: PIL resampling filter.
        base_pixel_number: both computed sides are floored to a multiple of
            this value, but never below it.

    Returns:
        The resized (and optionally padded) PIL image.
    """
    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio * w), round(ratio * h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
        # Snap down to a multiple of base_pixel_number, clamped to at least one
        # unit. For very elongated images the short side would otherwise floor
        # to 0, and PIL's resize rejects zero-sized dimensions.
        w_resize_new = max(base_pixel_number, (round(ratio * w) // base_pixel_number) * base_pixel_number)
        h_resize_new = max(base_pixel_number, (round(ratio * h) // base_pixel_number) * base_pixel_number)
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)
    if pad_to_max_side:
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y + h_resize_new, offset_x:offset_x + w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image
@spaces.GPU()
def infer(cond_in, image_in, prompt, inference_steps, guidance_scale, control_mode, control_strength, seed, progress=gr.Progress(track_tqdm=True)):
    """Generate an image with the FLUX ControlNet pipeline.

    The control image is ``cond_in`` (an already-processed condition image)
    when provided. Otherwise it is derived from ``image_in`` using the
    preprocessor selected by ``control_mode``. Either way it passes through
    resize_img first.

    Args:
        cond_in: file path of a processed control image, or None.
        image_in: file path of a reference image to preprocess, or None.
        prompt: text prompt; Korean text is translated to English first.
        inference_steps: number of inference steps passed to the pipeline.
        guidance_scale: guidance scale passed to the pipeline.
        control_mode: mode name from the UI (e.g. "Canny"); must be a key of
            ``mode_mapping``.
        control_strength: ControlNet conditioning scale.
        seed: seed for torch.manual_seed.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``(generated image, control image used, gr.update(visible=True))``.
        NOTE(review): the submit handler wires only two outputs, so Gradio
        ignores the third value.

    Raises:
        gr.Error: if neither image is supplied, or ``control_mode`` has no
            preprocessor (previously both cases hit an unbound
            ``control_image``).
    """
    control_mode_num = mode_mapping[control_mode]
    prompt = translate_to_english(prompt)

    if cond_in is not None:
        control_image = resize_img(load_image(cond_in))
    elif image_in is not None:
        # Preprocessor that turns the reference image into a condition image.
        preprocessors = {
            "Canny": extract_canny,
            "Depth": extract_depth,
            "OpenPose": extract_openpose,
            "Blur": apply_gaussian_blur,
            "LowQuality": add_gaussian_noise,
            "Grayscale": convert_to_grayscale,
            "Tile": tile,
        }
        preprocess = preprocessors.get(control_mode)
        if preprocess is None:
            raise gr.Error(f"Unsupported control mode: {control_mode}")
        control_image = preprocess(resize_img(load_image(image_in)))
    else:
        raise gr.Error("Please upload a processed control image or a reference image.")

    width, height = control_image.size
    image = pipe(
        prompt,
        control_image=[control_image],
        control_mode=[control_mode_num],
        width=width,
        height=height,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.manual_seed(seed),
    ).images[0]
    torch.cuda.empty_cache()
    return image, control_image, gr.update(visible=True)
# Custom CSS: hide the default Gradio footer.
css = """
footer {
visibility: hidden;
}
"""
# Two-column UI: image inputs and settings on the left, outputs on the right.
# NOTE(review): the nesting of the layout blocks below was reconstructed;
# confirm it against the deployed Space.
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            with gr.Row():
                # Left column: inputs, prompt and generation settings.
                with gr.Column():
                    with gr.Row(equal_height=True):
                        # Both images are handed to infer() as file paths.
                        cond_in = gr.Image(label="Upload Processed Control Image", sources=["upload"], type="filepath")
                        image_in = gr.Image(label="Extract Condition from Reference Image (Optional)", sources=["upload"], type="filepath")
                    prompt = gr.Textbox(label="Prompt", value="Highest Quality")
                    with gr.Accordion("ControlNet"):
                        # These choices must match the mode names infer() compares against.
                        control_mode = gr.Radio(
                            ["Canny", "Depth", "OpenPose", "Grayscale", "Blur", "Tile", "LowQuality"],
                            label="Mode",
                            value="Grayscale",
                            info="Select control mode, applies to all images"
                        )
                        control_strength = gr.Slider(
                            label="Control Strength",
                            minimum=0,
                            maximum=1.0,
                            step=0.05,
                            value=0.50,
                        )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=42,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Column():
                            with gr.Row():
                                inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=24)
                                guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=3.5)
                    submit_btn = gr.Button("Submit")
                # Right column: generated image and the control image actually used.
                with gr.Column():
                    result = gr.Image(label="Result")
                    processed_cond = gr.Image(label="Preprocessed Condition")
    # Submit runs in two steps. First an unqueued step optionally replaces the
    # seed and writes it back into the slider. Then infer() runs with that seed.
    # NOTE(review): infer() returns three values but only two outputs are wired
    # here; Gradio ignores the extra one.
    submit_btn.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False
    ).then(
        fn = infer,
        inputs = [cond_in, image_in, prompt, inference_steps, guidance_scale, control_mode, control_strength, seed],
        outputs = [result, processed_cond],
        show_api=False
    )
# Enable the request queue with the API closed, then start the server.
demo.queue(api_open=False)
demo.launch()