# test / app.py
# SharafeevRavil's picture
# Update app.py
# f96fa2c verified
import gradio as gr
import torch
from transformers import pipeline
from huggingface_hub import InferenceClient
from PIL import Image, ImageDraw
from gradio_client import Client, handle_file
import numpy as np
import cv2
import os
import tempfile
import io
import base64
import requests
from collections import OrderedDict
import uuid
# Model initialization
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# OneFormer segmentation: the processor and the model are loaded from the same
# checkpoint, and the model is moved to `device`.
oneFormer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
oneFormer_model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device)
# Disabled alternatives kept for reference:
# classification = pipeline("image-classification", model="google/vit-base-patch16-224")
# upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler")
# inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting")
# Функции для обработки изображений
def segment_image(image):
    """Run OneFormer panoptic segmentation and cut the image into labelled segments.

    Args:
        image: PIL image to segment, or ``None`` when nothing was uploaded.

    Returns:
        A list of ``(rgba_array, label)`` tuples, one per panoptic segment.
        Each array has shape ``(height, width, 4)`` and dtype ``uint8``: pixels
        that belong to the segment keep their colour with alpha 255, all other
        pixels stay zero (fully transparent). Labels are the model's class name
        followed by a per-class counter (``wall_0``, ``wall_1``, ...).
        An empty list is returned when *image* is ``None``.
    """
    if image is None:
        return []

    # The per-segment copy below writes into exactly three colour channels,
    # so normalise any other mode (RGBA, L, P, ...) to RGB first.
    rgb_image = image.convert("RGB")
    pixels = np.array(rgb_image)
    height, width = pixels.shape[:2]

    inputs = oneFormer_processor(rgb_image, task_inputs=["panoptic"], return_tensors="pt")
    # The model was moved to `device` at load time; its inputs must live on the
    # same device or the forward pass fails when CUDA is in use.
    inputs = {
        name: value.to(device) if hasattr(value, "to") else value
        for name, value in inputs.items()
    }
    with torch.no_grad():
        outputs = oneFormer_model(**inputs)

    # Post-process the raw predictions; target size is (height, width).
    predicted_panoptic_map = oneFormer_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[rgb_image.size[::-1]]
    )[0]

    # Extract the per-pixel segment ids and the per-segment metadata.
    segmentation_map = predicted_panoptic_map["segmentation"].cpu().numpy()
    segments_info = predicted_panoptic_map["segments_info"]
    id2label = oneFormer_model.config.id2label

    cropped_masks_with_labels = []
    label_counts = {}
    for segment in segments_info:
        in_segment = segmentation_map == segment["id"]

        cropped_image = np.zeros((height, width, 4), dtype=np.uint8)
        cropped_image[in_segment, :3] = pixels[in_segment]
        cropped_image[in_segment, 3] = 255

        # Number repeated classes from zero: wall_0, wall_1, ...
        base_label = id2label[segment["label_id"]]
        occurrence = label_counts.get(base_label, 0)
        label_counts[base_label] = occurrence + 1

        cropped_masks_with_labels.append((cropped_image, f"{base_label}_{occurrence}"))
    return cropped_masks_with_labels
def merge_segments_by_labels(gallery_images, labels_input):
    """Merge the gallery segments whose captions are listed in *labels_input*.

    Args:
        gallery_images: Gallery value as ``(image_path, label)`` pairs. May be
            ``None`` or empty when nothing has been segmented yet.
        labels_input: Labels to merge, separated by semicolons
            (for example ``"wall_0; tv_0"``). Blank entries are ignored.

    Returns:
        A new list in which the matching segments are removed and replaced by
        one alpha-composited RGBA array, appended at the end and captioned with
        the first requested label. When there is nothing to merge (empty
        gallery, no labels, or no match) *gallery_images* is returned unchanged.
    """
    if not gallery_images:
        return gallery_images

    # Drop blank entries produced by stray or trailing semicolons so that an
    # empty string can never become the merged segment's caption.
    labels_to_merge = [
        part.strip() for part in (labels_input or "").split(";") if part.strip()
    ]
    if not labels_to_merge:
        return gallery_images
    wanted = set(labels_to_merge)

    merged_image = None
    merged_indices = set()
    for index, (image_path, label) in enumerate(gallery_images):
        if label not in wanted:
            continue
        # Load with PIL and keep the alpha channel.
        image = Image.open(image_path).convert("RGBA")
        if merged_image is None:
            merged_image = image.copy()
        else:
            # Combine the images, respecting the alpha channel.
            merged_image = Image.alpha_composite(merged_image, image)
        merged_indices.add(index)

    if merged_image is None:
        return gallery_images

    new_gallery_images = [
        item for index, item in enumerate(gallery_images) if index not in merged_indices
    ]
    # The merged segment is handed back as a numpy array.
    new_gallery_images.append((np.array(merged_image), labels_to_merge[0]))
    return new_gallery_images
def select_segment(segment_output, segment_name):
    """Return the image of the first gallery segment captioned *segment_name*.

    Args:
        segment_output: Gallery value as ``(image_path, label)`` pairs. May be
            ``None`` or empty when nothing has been segmented yet.
        segment_name: Caption to look for (exact match).

    Returns:
        The image of the first matching pair, or ``None`` when the gallery is
        empty or no caption matches.
    """
    if not segment_output:
        return None
    return next(
        (image_path for image_path, label in segment_output if label == segment_name),
        None,
    )
#Image edit
def return_image(imageEditor):
    """Return the ``'composite'`` entry of an image-editor value."""
    composite = imageEditor['composite']
    return composite
def return_image2(image):
    """Hand *image* back unchanged."""
    unchanged = image
    return unchanged
def rembg_client(request: gr.Request):
    """Create a ``Client`` for the ``KenjieDec/RemBG`` Space.

    The incoming request's ``x-ip-token`` header is forwarded as ``X-IP-Token``.
    If the header cannot be read, or creating the client with it fails, a
    client without the header is created instead.

    Args:
        request: The Gradio request of the current event.

    Returns:
        A ``Client`` connected to ``KenjieDec/RemBG``.
    """
    space = "KenjieDec/RemBG"
    try:
        client = Client(space, headers={"X-IP-Token": request.headers['x-ip-token']})
    except Exception:
        # Deliberately broad: the fallback is best-effort for any failure, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print("KenjieDec/RemBG no token")
        return Client(space)
    print("KenjieDec/RemBG Ip token")
    return client
def autocrop_image(imageEditor, border = 0):
    """Crop the editor composite to its ``getbbox()`` box and pad it transparently.

    Args:
        imageEditor: Image-editor value; its ``'composite'`` image is used.
        border: Width of the transparent margin added on every side.

    Returns:
        A new RGBA image containing the cropped composite, offset by *border*.
    """
    source = imageEditor['composite']
    trimmed = source.crop(source.getbbox())

    trimmed_width, trimmed_height = trimmed.size
    canvas_size = (trimmed_width + border * 2, trimmed_height + border * 2)

    canvas = Image.new("RGBA", canvas_size, (0,0,0,0))
    canvas.paste(trimmed, (border, border))
    return canvas
def remove_black_make_transparent(imageEditor):
    """Return the editor composite with every pure-black pixel made transparent.

    Args:
        imageEditor: Image-editor value; its ``'composite'`` image is used.

    Returns:
        A new image built from the RGBA pixel data, where pixels whose three
        colour channels are all zero have their alpha set to zero.
    """
    composite = imageEditor['composite']
    rgba = composite if composite.mode == "RGBA" else composite.convert("RGBA")

    pixels = np.array(rgba)
    # A pixel counts as black only when R, G and B are all zero.
    is_black = (pixels[..., :3] == 0).all(axis=-1)
    pixels[is_black, 3] = 0
    return Image.fromarray(pixels)
def rembg(imageEditor, rembg_model, request: gr.Request):
    """Remove the background of the editor composite via the RemBG Space.

    Args:
        imageEditor: Image-editor value; its ``'composite'`` image is sent.
        rembg_model: Model name passed to the Space's ``/inference`` endpoint.
        request: The Gradio request of the current event, used to build the client.

    Returns:
        Whatever the ``/inference`` endpoint returns.
    """
    # Only the path is needed; the file is created with delete=False, so it
    # has to be removed explicitly once the request is finished.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        temp_file_path = temp_file.name
    try:
        imageEditor['composite'].save(temp_file_path)
        client = rembg_client(request)
        result = client.predict(
            file=handle_file(temp_file_path),
            mask="Default",
            model=rembg_model,
            x=0,
            y=0,
            api_name="/inference"
        )
    finally:
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
    print(result)
    return result
def add_transparent_border(imageEditor, border_size=200):
    """Surround the editor composite with a transparent margin.

    Args:
        imageEditor: Image-editor value; its ``'composite'`` image is used.
        border_size: Margin added on every side.

    Returns:
        A new RGBA image, larger by ``2 * border_size`` in each dimension, with
        the composite pasted at ``(border_size, border_size)``.
    """
    source = imageEditor['composite']
    source_width, source_height = source.size
    canvas_size = (source_width + 2 * border_size, source_height + 2 * border_size)

    canvas = Image.new("RGBA", canvas_size, (0, 0, 0, 0))
    canvas.paste(source, (border_size, border_size))
    return canvas
def upscale(imageEditor, scale, request: gr.Request):
    """Upscale the editor composite by *scale* using ``upscale_image``.

    Args:
        imageEditor: Image-editor value; its ``'composite'`` image is sent.
        scale: Rescaling factor forwarded to ``upscale_image``.
        request: The Gradio request of the current event (not used here).

    Returns:
        The image returned by ``upscale_image`` with ``version="v1.4"``.
    """
    composite = imageEditor['composite']
    return upscale_image(composite, version="v1.4", rescaling_factor=scale)
def upscale_image(image_pil, version="v1.4", rescaling_factor=None, timeout=600):
    """Send an image to the remote upscaling endpoint and return the result.

    Args:
        image_pil: PIL image to upscale; it is sent as a base64 PNG data URL.
        version: Version string forwarded to the endpoint.
        rescaling_factor: Rescaling factor forwarded to the endpoint.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled endpoint cannot block the handler forever.

    Returns:
        The PIL image decoded from the first entry of the response's ``data``.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.Timeout: If the endpoint does not answer within *timeout*.
    """
    buffered = io.BytesIO()
    image_pil.save(buffered, format="PNG")  # Save as PNG
    img_str = base64.b64encode(buffered.getvalue()).decode()

    data = {"data": [f"data:image/png;base64,{img_str}", version, rescaling_factor]}

    response = requests.post(
        "https://nightfury-image-face-upscale-restoration-gfpgan.hf.space/api/predict",
        json=data,
        timeout=timeout,
    )
    response.raise_for_status()

    base64_data = response.json()["data"][0]
    # Strip a "data:image/png;base64," style prefix when present; a payload
    # without the prefix is decoded as-is instead of raising IndexError.
    _, separator, payload = base64_data.partition(",")
    if separator:
        base64_data = payload

    image_bytes = base64.b64decode(base64_data)
    return Image.open(io.BytesIO(image_bytes))
# def inpainting(source_img, request: gr.Request):
# input_image = source_img["background"].convert("RGB")
# mask_image = source_img["layers"][0].convert("RGB")
# return inpainting_image(imageEditor['composite'])
def inpainting_client(request: gr.Request):
    """Create a ``Client`` for the FLUX.1-dev inpainting Space.

    The incoming request's ``x-ip-token`` header is forwarded as ``X-IP-Token``.
    If the header cannot be read, or creating the client with it fails, a
    client without the header is created instead.

    Args:
        request: The Gradio request of the current event.

    Returns:
        A ``Client`` connected to ``ameerazam08/FLUX.1-dev-Inpainting-Model-Beta-GPU``.
    """
    space = "ameerazam08/FLUX.1-dev-Inpainting-Model-Beta-GPU"
    try:
        client = Client(space, headers={"X-IP-Token": request.headers['x-ip-token']})
    except Exception:
        # Deliberately broad: the fallback is best-effort for any failure, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print("ameerazam08/FLUX.1-dev-Inpainting-Model-Beta-GPU no token")
        return Client(space)
    print("ameerazam08/FLUX.1-dev-Inpainting-Model-Beta-GPU Ip token")
    return client
def inpainting_run(input_image_editor,
                   prompt,
                   negative_prompt,
                   controlnet_conditioning_scale,
                   guidance_scale,
                   seed,
                   num_inference_steps,
                   true_guidance_scale,
                   request: gr.Request
                   ):
    """Run inpainting on the editor's background using its first layer as the mask.

    Args:
        input_image_editor: Image-editor value; ``'background'`` is the image
            and the first entry of ``'layers'`` is the mask.
        prompt: Text prompt forwarded to the Space.
        negative_prompt: Negative prompt forwarded to the Space.
        controlnet_conditioning_scale: Forwarded to the Space unchanged.
        guidance_scale: Forwarded to the Space unchanged.
        seed: Forwarded to the Space unchanged.
        num_inference_steps: Forwarded to the Space unchanged.
        true_guidance_scale: Forwarded to the Space unchanged.
        request: The Gradio request of the current event, used to build the client.

    Returns:
        Whatever the Space's ``/process`` endpoint returns.

    Raises:
        gr.Error: If the editor has no background image or no mask layer.
    """
    print("inpainting_run")
    background = input_image_editor.get("background")
    layers = input_image_editor.get("layers")
    if background is None or not layers:
        # Surface a readable message instead of an IndexError/AttributeError.
        raise gr.Error("Загрузите изображение и нарисуйте маску на первом слое.")

    print(len(layers))
    print(layers)
    print(layers[0])

    temp_file_path = None   # image
    temp_file_path2 = None  # mask
    try:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            temp_file_path = temp_file.name
        background.save(temp_file_path)
        print("background", temp_file_path)

        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file2:
            temp_file_path2 = temp_file2.name
        layers[0].save(temp_file_path2)
        print("маска", temp_file_path2)

        client = inpainting_client(request)
        result = client.predict(
            input_image_editor={"background":handle_file(temp_file_path),"layers":[handle_file(temp_file_path2)],"composite":None},
            prompt=prompt,
            negative_prompt=negative_prompt,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            guidance_scale=guidance_scale,
            seed=seed,
            num_inference_steps=num_inference_steps,
            true_guidance_scale=true_guidance_scale,
            api_name="/process"
        )
    finally:
        # Both files were created with delete=False, so remove them here.
        for path in (temp_file_path, temp_file_path2):
            if path is not None and os.path.exists(path):
                os.remove(path)
    print(result)
    return result
#3d models
def hunyuan_client(request: gr.Request):
    """Create a ``Client`` for the ``tencent/Hunyuan3D-2`` Space.

    The incoming request's ``x-ip-token`` header is forwarded as ``X-IP-Token``.
    If the header cannot be read, or creating the client with it fails, a
    client without the header is created instead.

    Args:
        request: The Gradio request of the current event.

    Returns:
        A ``Client`` connected to ``tencent/Hunyuan3D-2``.
    """
    space = "tencent/Hunyuan3D-2"
    try:
        client = Client(space, headers={"X-IP-Token": request.headers['x-ip-token']})
    except Exception:
        # Deliberately broad: the fallback is best-effort for any failure, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print("tencent/Hunyuan3D-2 no token")
        return Client(space)
    print("tencent/Hunyuan3D-2 Ip token")
    return client
def vFusion_client(request: gr.Request):
    """Create a ``Client`` for the ``facebook/VFusion3D`` Space.

    The incoming request's ``x-ip-token`` header is forwarded as ``X-IP-Token``.
    If the header cannot be read, or creating the client with it fails, a
    client without the header is created instead.

    Args:
        request: The Gradio request of the current event.

    Returns:
        A ``Client`` connected to ``facebook/VFusion3D``.
    """
    space = "facebook/VFusion3D"
    try:
        client = Client(space, headers={"X-IP-Token": request.headers['x-ip-token']})
    except Exception:
        # Deliberately broad: the fallback is best-effort for any failure, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print("facebook/VFusion3D no token")
        return Client(space)
    print("facebook/VFusion3D Ip token")
    return client
def generate_3d_model(image_pil, rembg_Hunyuan, request: gr.Request):
    """Generate an untextured 3D shape from an image via the Hunyuan3D-2 Space.

    Args:
        image_pil: PIL image to send.
        rembg_Hunyuan: Value passed as ``check_box_rembg`` to the Space.
        request: The Gradio request of the current event, used to build the client.

    Returns:
        The first element of the ``/shape_generation`` endpoint's result.
    """
    # The file is created with delete=False, so it has to be removed
    # explicitly once the request is finished.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        temp_file_path = temp_file.name
    try:
        image_pil.save(temp_file_path)
        client = hunyuan_client(request)
        result = client.predict(
            caption="",
            image=handle_file(temp_file_path),
            steps=50,
            guidance_scale=5.5,
            seed=1234,
            octree_resolution="256",
            check_box_rembg=rembg_Hunyuan,
            api_name="/shape_generation"
        )
    finally:
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
    print(result)
    return result[0]
def generate_3d_model_texture(image_pil, rembg_Hunyuan, request: gr.Request):
    """Generate a 3D model from an image via the Hunyuan3D-2 ``/generation_all`` endpoint.

    Args:
        image_pil: PIL image to send.
        rembg_Hunyuan: Value passed as ``check_box_rembg`` to the Space.
        request: The Gradio request of the current event, used to build the client.

    Returns:
        The second element of the ``/generation_all`` endpoint's result.
    """
    # The file is created with delete=False, so it has to be removed
    # explicitly once the request is finished.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        temp_file_path = temp_file.name
    try:
        image_pil.save(temp_file_path)
        client = hunyuan_client(request)
        result = client.predict(
            caption="",
            image=handle_file(temp_file_path),
            steps=50,
            guidance_scale=5.5,
            seed=1234,
            octree_resolution="256",
            check_box_rembg=rembg_Hunyuan,
            api_name="/generation_all"
        )
    finally:
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
    print(result)
    return result[1]
def generate_3d_model2(image_pil, request: gr.Request):
    """Generate a 3D model from an image via the VFusion3D Space.

    Args:
        image_pil: PIL image to send.
        request: The Gradio request of the current event, used to build the client.

    Returns:
        The first element of the ``/step_1_generate_obj`` endpoint's result.
    """
    # The file is created with delete=False, so it has to be removed
    # explicitly once the request is finished.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        temp_file_path = temp_file.name
    try:
        image_pil.save(temp_file_path)
        client = vFusion_client(request)
        result = client.predict(
            image=handle_file(temp_file_path),
            api_name="/step_1_generate_obj"
        )
    finally:
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
    print(result)
    return result[0]
### some configs
# Default negative prompt: a comma-separated list of things to steer away from.
negative_prompt_str = ", ".join((
    "text",
    "bad anatomy",
    "bad proportions",
    "blurry",
    "cropped",
    "deformed",
    "disfigured",
    "duplicate",
    "error",
    "extra limbs",
    "gross proportions",
    "jpeg artifacts",
    "long neck",
    "low quality",
    "lowres",
    "malformed",
    "morbid",
    "mutated",
    "mutilated",
    "out of frame",
    "ugly",
    "worst quality",
))
# Quality-oriented positive prompt fragment.
positive_prompt_str = ", ".join((
    "Full HD",
    "4K",
    "high quality",
    "high resolution",
))
########## GRADIO ##########
# Gradio UI: four tabs (scan, edit, inpainting, 3D creation) followed by the
# event wiring that connects the buttons to the functions defined above.
# NOTE(review): the source had lost its indentation, so the nesting of rows and
# columns below is reconstructed from the order of the components — confirm the
# intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Анализ и редактирование помещений")
    # Tab 1: upload an image, segment it, merge segments and pick one.
    with gr.Tab("Сканирование"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                image_input = gr.Image(type="pil", label="Исходное изображение", height = 400)
                segment_button = gr.Button("Сегментировать")
            with gr.Column(scale=5):
                segments_output = gr.Gallery(label="Сегменты изображения")
                merge_segments_input = gr.Textbox(label="Сегменты для объединения (через точку с запятой, например: \"wall_0; tv_0\")")
                merge_segments_button = gr.Button("Соединить сегменты")
                # The gallery is both input and output: merging rewrites it in place.
                merge_segments_button.click(merge_segments_by_labels, inputs=[segments_output, merge_segments_input], outputs=segments_output)
                with gr.Row(equal_height=True):
                    segment_text_input = gr.Textbox(label="Имя сегмента для дальнейшего редактирования")
                    select_segment_button = gr.Button("Использовать сегмент")
    # Tab 2: edit the selected segment (crop, upscale, background removal, padding).
    with gr.Tab("Редактирование"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                segment_input = gr.ImageEditor(type="pil", label="Сегмент для редактирования")
            with gr.Column(scale=5):
                crop_button = gr.Button("Обрезать сегмент")
                with gr.Row(equal_height=True):
                    upscale_slider = gr.Slider(minimum=1, maximum=5, value=2, step=0.1, label="во сколько раз")
                    upscale_button = gr.Button("Upscale")
                with gr.Row(equal_height=True):
                    rembg_model_selector = gr.Dropdown(
                        [
                            "u2net",
                            "u2netp",
                            "u2net_human_seg",
                            "u2net_cloth_seg",
                            "silueta",
                            "isnet-general-use",
                            "isnet-anime",
                            "birefnet-general",
                            "birefnet-general-lite",
                            "birefnet-portrait",
                            "birefnet-dis",
                            "birefnet-hrsod",
                            "birefnet-cod",
                            "birefnet-massive"
                        ],
                        value="birefnet-general-lite",
                        label="Rembg model"
                    )
                    rembg_button = gr.Button("Rembg")
                remove_background_button = gr.Button("Убрать черный задний фон")
                with gr.Row(equal_height=True):
                    add_transparent_border_slider = gr.Slider(minimum=10, maximum=500, value=200, step=10, label="в пикселях")
                    add_transparent_border_button = gr.Button("Добавить прозрачные края")
                use_inpainting_button = gr.Button("Использовать сегмент для Inpainting")
                use_button = gr.Button("Использовать сегмент для 3D")
    # Tab 3: draw a mask and run inpainting.
    with gr.Tab("Inpainting"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                # inpainting_input = gr.ImageEditor(type="pil", label="Сегмент для Inpainting")
                gr.Markdown("У gradio.ImageEditor какой-то странный баг. Если у вас застряла мышка при попытке нарисовать маску - перейдите на 1-й слой. Для маски будет выбран 1-й слой из списка. Для маски используется белый цвет.")
                # The brush is fixed to white.
                imageMask = gr.ImageEditor(
                    label='Сегмент для Inpainting',
                    type='pil',
                    # sources=["upload", "webcam"],
                    # image_mode='RGB',
                    # layers=False,
                    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed")
                )
                prompt = gr.Textbox(lines=2, label="Введите промпт для Inpainting", placeholder="Enter prompt here...")
                inpainting_button = gr.Button("Inpainting")
                with gr.Accordion('Advanced options', open=False):
                    negative_prompt = gr.Textbox(lines=2, value=negative_prompt_str, label="Negative prompt", placeholder="Enter negative_prompt here...")
                    controlnet_conditioning_scale = gr.Slider(minimum=0, step=0.01, maximum=1, value=0.9, label="controlnet_conditioning_scale")
                    # NOTE(review): the label "Image to generate" does not describe a guidance scale — confirm the intended label.
                    guidance_scale = gr.Slider(minimum=1, step=0.5, maximum=10, value=3.5, label="Image to generate")
                    seed = gr.Slider(minimum=0, step=1, maximum=10000000, value=124, label="Seed Value")
                    num_inference_steps = gr.Slider(minimum=1, step=1, maximum=30, value=24, label="num_inference_steps")
                    # NOTE(review): the default 3.5 is not a multiple of step=1 from minimum=1 — confirm.
                    true_guidance_scale = gr.Slider(minimum=1, step=1, maximum=10, value=3.5, label="true_guidance_scale")
            with gr.Column(scale=5):
                after_inpainting = gr.Image(type="pil", label="Изображение после Inpainting")
                use_inpainting_button2 = gr.Button("Вернуться к редактированию")
                use_button2 = gr.Button("Использовать сегмент для 3D")
    # Tab 4: turn the segment into a 3D model.
    with gr.Tab("Создание 3D"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                segment_3d_input = gr.Image(type="pil", image_mode="RGBA", label="Сегмент для 3D", height = 600)
                rembg_Hunyuan = gr.Checkbox(label="Hunyuan3D-2 rembg Enabled", info="Включить rembg для Hunyuan3D-2?")
                hunyuan_button = gr.Button("Hunyuan3D-2 (no texture) [ZeroGPU = 100s]")
                hunyuan_button_texture = gr.Button("Hunyuan3D-2 (with texture) [ZeroGPU = 150s]")
                vFusion_button = gr.Button("VFusion3D [если у вас совсем все грустно по ZeroGPU]")
            with gr.Column(scale=5):
                trellis_output = gr.Model3D(label="3D Model")
    # Tab 1 events: segment the upload, then load the chosen segment into the editor.
    segment_button.click(segment_image, inputs=image_input, outputs=segments_output)
    select_segment_button.click(select_segment, inputs=[segments_output, segment_text_input], outputs=segment_input)
    # Tab 2 events: each edit reads the editor and writes its result back into it.
    crop_button.click(autocrop_image, inputs=segment_input, outputs=segment_input)
    upscale_button.click(upscale, inputs=[segment_input, upscale_slider], outputs=segment_input)
    rembg_button.click(rembg, inputs=[segment_input, rembg_model_selector], outputs=segment_input)
    remove_background_button.click(remove_black_make_transparent, inputs=segment_input, outputs=segment_input)
    add_transparent_border_button.click(add_transparent_border, inputs=[segment_input, add_transparent_border_slider], outputs=segment_input)
    # Hand the edited composite over to the inpainting tab or the 3D tab.
    use_inpainting_button.click(return_image, inputs=segment_input, outputs=imageMask)
    use_button.click(return_image, inputs=segment_input, outputs=segment_3d_input)
    # Tab 3 events
    # inpainting_button.click(inpainting, inputs=inpainting_input, outputs=inpainting_input)
    # The input order must match the positional parameters of inpainting_run.
    inpainting_button.click(
        fn=inpainting_run,
        inputs=[
            imageMask,
            prompt,
            negative_prompt,
            controlnet_conditioning_scale,
            guidance_scale,
            seed,
            num_inference_steps,
            true_guidance_scale
        ],
        outputs=after_inpainting
    )
    # Send the inpainted image back to the editor or on to the 3D tab.
    use_inpainting_button2.click(return_image2, inputs=after_inpainting, outputs=segment_input)
    use_button2.click(return_image2, inputs=after_inpainting, outputs=segment_3d_input)
    # 3D buttons: all three generators write into the same Model3D viewer.
    hunyuan_button.click(generate_3d_model, inputs=[segment_3d_input, rembg_Hunyuan], outputs=trellis_output)
    hunyuan_button_texture.click(generate_3d_model_texture, inputs=[segment_3d_input, rembg_Hunyuan], outputs=trellis_output)
    vFusion_button.click(generate_3d_model2, inputs=segment_3d_input, outputs=trellis_output)
# Start the app with debug mode and error display enabled.
demo.launch(debug=True, show_error=True)