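# app.py — Gradio Space for Qwen2VL-Flux image generation.
# Downloads the required checkpoints, loads the Flux pipeline with ControlNet,
# depth and line features, and exposes variation / img2img / inpaint / controlnet
# modes through a Gradio UI.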
import os
import time
import glob
import shutil
import random
from io import BytesIO
from typing import Tuple

import requests
import numpy as np
import torch
import gradio as gr
import spaces
import PIL.Image
from PIL import Image
from gradio_imageslider import ImageSlider
from huggingface_hub import login, snapshot_download, hf_hub_download
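# Global configuration: maximum random seed, working resolution, compute device,
# and an optional Hugging Face token for model downloads.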
MAX_SEED = np.iinfo(np.int32).max
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
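# Fetch the model weights: the Qwen2vl-Flux snapshot plus the Anyline (MTEED),
# Depth-Anything-V2 and SAM2 checkpoints used by the auxiliary features.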
cp_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints')
snapshot_download("Djrango/Qwen2vl-Flux", local_dir=cp_dir)
hf_hub_download(repo_id="TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline", local_dir=f"{cp_dir}/anyline")
# hf_hub_download keeps the "Anyline" subfolder, so move the weight file up one level.
shutil.move(f"{cp_dir}/anyline/Anyline/MTEED.pth", f"{cp_dir}/anyline")
snapshot_download("depth-anything/Depth-Anything-V2-Large", local_dir=f"{cp_dir}/depth-anything-v2")
snapshot_download("facebook/sam2-hiera-large", local_dir=f"{cp_dir}/segment-anything-2")
# Copy the SAM2 config YAMLs into a local "sam2_configs" directory so SAM2 can locate them.
# See https://github.com/facebookresearch/sam2/issues/26
os.makedirs("sam2_configs", exist_ok=True)
for p in glob.glob(f"{cp_dir}/segment-anything-2/*.yaml"):
    shutil.copy(p, "sam2_configs")
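# FluxModel comes from this repo's modelmod module; it appears to wrap the Flux pipeline
# together with the optional ControlNet, depth and line annotators (SAM stays disabled here).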
from modelmod import FluxModel
model = FluxModel(device=DEVICE, is_turbo=False, required_features=['controlnet', 'depth', 'line'], is_quantization=True)  # append 'sam' to required_features to enable segmentation
QWEN2VLFLUX_MODES = ["variation", "img2img", "inpaint", "controlnet", "controlnet-inpaint"]
QWEN2VLFLUX_ASPECT_RATIO = ["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"]
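# Small context manager used to log how long each inference call takes.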
class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        print(f"Activity: {self.activity_name}, End time: {self.end_time_formatted}")
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
    width, height = original_resolution_wh
    # if width <= maximum_dimension and height <= maximum_dimension:
    #     width = width - (width % 32)
    #     height = height - (height % 32)
    #     return width, height
    if width > height:
        scaling_factor = maximum_dimension / width
    else:
        scaling_factor = maximum_dimension / height
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    new_width = new_width - (new_width % 32)
    new_height = new_height - (new_height % 32)
    return new_width, new_height
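# Example: resize_image_dimensions((1920, 1080)) -> (1024, 576).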
def fetch_from_url(url: str, name: str):
    try:
        print(f"Fetching {name} from URL: {url}")
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        image = PIL.Image.open(BytesIO(response.content))
        print(f"Fetched {name} successfully")
        return image
    except Exception as e:
        print(e)
        return None
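# Main generation handler: gathers the edited/remote images, validates the inputs for the
# selected mode, then calls model.generate() and returns (input, result) for the image slider.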
def process(
    mode: str,
    input_image_editor: dict,
    ref_image: Image.Image,
    image_url: str,
    mask_url: str,
    ref_url: str,
    input_text: str,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    aspect_ratio: str,
    attn_mode: bool,
    center_x: float,
    center_y: float,
    radius: float,
    line_mode: bool,
    line_strength: float,
    depth_mode: bool,
    depth_strength: float,
    progress=gr.Progress(track_tqdm=True)
):
    #if not input_text:
    #    gr.Info("Please enter a text prompt.")
    #    return None
    kwargs = {}
    image = input_image_editor['background']
    mask = input_image_editor['layers'][0]
    if image_url: image = fetch_from_url(image_url, "image")
    if mask_url: mask = fetch_from_url(mask_url, "mask")
    if ref_url: ref_image = fetch_from_url(ref_url, "reference image")
    if not image:
        gr.Info("Please upload an image.")
        return None
    if ref_image: kwargs["input_image_b"] = ref_image
    if mode in ("inpaint", "controlnet-inpaint"):
        if not mask:
            gr.Info("Please draw a mask on the image.")
            return None
        kwargs["mask_image"] = mask
    if attn_mode:
        kwargs["center_x"] = center_x
        kwargs["center_y"] = center_y
        kwargs["radius"] = radius
    with calculateDuration("run inference"):
        result = model.generate(
            input_image_a=image,
            prompt=input_text,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            aspect_ratio=aspect_ratio,
            mode=mode,
            denoise_strength=strength,
            line_mode=line_mode,
            line_strength=line_strength,
            depth_mode=depth_mode,
            depth_strength=depth_strength,
            imageCount=1,
            **kwargs
        )[0]
    #return result
    return [image, result]
CSS = """ | |
.title { text-align: center; } | |
""" | |
with gr.Blocks(fill_width=True, css=CSS) as demo:
    gr.Markdown("# Qwen2VL-Flux", elem_classes="title")
    with gr.Row():
        with gr.Column():
            gen_mode = gr.Radio(label="Generation mode", choices=QWEN2VLFLUX_MODES, value="variation")
            with gr.Row():
                input_image_editor = gr.ImageEditor(label='Image', type='pil', sources=["upload", "webcam", "clipboard"], image_mode='RGB',
                                                    layers=False, brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
                ref_image = gr.Image(label='Reference image', type='pil', sources=["upload", "webcam", "clipboard"], image_mode='RGB')
            with gr.Accordion("Image from URL", open=False):
                image_url = gr.Textbox(label="Image URL", show_label=True, max_lines=1, placeholder="Enter your image URL (optional)")
                mask_url = gr.Textbox(label="Mask image URL", show_label=True, max_lines=1, placeholder="Enter your mask image URL (optional)")
                ref_url = gr.Textbox(label="Reference image URL", show_label=True, max_lines=1, placeholder="Enter your reference image URL (optional)")
            with gr.Accordion("Prompt Settings", open=True):
                input_text = gr.Textbox(label="Prompt", show_label=True, max_lines=1, placeholder="Enter your prompt")
                submit_button = gr.Button(value='Submit', variant='primary')
            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    denoise_strength = gr.Slider(label="Denoise strength", minimum=0, maximum=1, step=0.01, value=0.75)
                    aspect_ratio = gr.Radio(label="Output image ratio", choices=QWEN2VLFLUX_ASPECT_RATIO, value="1:1")
                    num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=28)
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.5, value=3.5)
            with gr.Accordion("Attention Control", open=True):
                with gr.Row():
                    attn_mode = gr.Checkbox(label="Attention Control", value=False)
                    center_x = gr.Slider(label="X coordinate of attention center", minimum=0, maximum=1, step=0.01, value=0.5)
                    center_y = gr.Slider(label="Y coordinate of attention center", minimum=0, maximum=1, step=0.01, value=0.5)
                    radius = gr.Slider(label="Radius of attention circle", minimum=0, maximum=1, step=0.01, value=0.5)
            with gr.Accordion("ControlNet Settings", open=True):
                with gr.Row():
                    line_mode = gr.Checkbox(label="Line mode", value=True)
                    line_strength = gr.Slider(label="Line strength", minimum=0, maximum=1, step=0.01, value=0.4)
                    depth_mode = gr.Checkbox(label="Depth mode", value=True)
                    depth_strength = gr.Slider(label="Depth strength", minimum=0, maximum=1, step=0.01, value=0.2)
        with gr.Column():
            #output_image = gr.Image(label="Generated image", type="pil", format="png", show_download_button=True, show_share_button=False)
            output_image = ImageSlider(label="Generated image", type="pil")
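    # Wire both the Submit button and pressing Enter in the prompt box to the process() handler.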
    gr.on(triggers=[submit_button.click, input_text.submit], fn=process,
          inputs=[gen_mode, input_image_editor, ref_image, image_url, mask_url, ref_url,
                  input_text, denoise_strength, num_inference_steps, guidance_scale, aspect_ratio,
                  attn_mode, center_x, center_y, radius, line_mode, line_strength, depth_mode, depth_strength],
          outputs=[output_image], queue=True)

demo.queue().launch(debug=True, show_error=True)
#demo.queue().launch(debug=True, show_error=True, ssr_mode=False) # Gradio 5