|
import subprocess |
|
|
|
def download_file(url, output_filename):
    """Fetch *url* into *output_filename* by running ``wget`` quietly.

    The command is passed to ``subprocess.run`` as an argument list (no
    shell), and ``check=True`` makes a non-zero wget exit status raise
    ``subprocess.CalledProcessError``.
    """
    subprocess.run(
        ['wget', '-O', output_filename, '-q', url],
        check=True,
    )
|
|
|
# MediaPipe segmentation models, downloaded into the working directory at
# import time.
url1 = 'https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite'
filename1 = 'selfie_multiclass_256x256.tflite'

url2 = 'https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_segmenter/float16/latest/selfie_segmenter.tflite'
filename2 = 'selfie_segmenter.tflite'

# Fetch in a fixed order: the multiclass model first, then the selfie model.
for _model_url, _model_file in ((url1, filename1), (url2, filename2)):
    download_file(_model_url, _model_file)
del _model_url, _model_file
|
|
|
import cv2 |
|
import mediapipe as mp |
|
import numpy as np |
|
from mediapipe.tasks import python |
|
from mediapipe.tasks.python import vision |
|
import random |
|
import gradio as gr |
|
import spaces |
|
import torch |
|
from diffusers import FluxInpaintPipeline |
|
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL |
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel |
|
from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast |
|
|
|
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face repository the FLUX inpainting pipeline is loaded from.
bfl_repo = "black-forest-labs/FLUX.1-dev"

# Pixel colours for rendered masks: black background, white mask region.
BG_COLOR = (0, 0, 0)
MASK_COLOR = (255, 255, 255)
|
|
|
def maskHead(input, threshold=0.2,
             model_path='selfie_multiclass_256x256.tflite'):
    """Build a white-on-black mask of the hair and face in an image file.

    Runs the MediaPipe multiclass selfie segmenter on the image at *input*.
    It takes the element-wise maximum of the hair (index 1) and face
    (index 3) confidence masks and marks every pixel whose combined
    confidence exceeds *threshold*.

    Args:
        input: Path of the image file to segment. The name shadows the
            builtin but is kept for backward compatibility.
        threshold: Confidence above which a pixel counts as head. Defaults
            to 0.2, the value that used to be hard-coded.
        model_path: Path of the segmenter ``.tflite`` model.

    Returns:
        A ``uint8`` array of shape ``(H, W, 3)``. Head pixels are
        ``MASK_COLOR`` and all other pixels are ``BG_COLOR``. The output is
        always 3-channel, whatever the channel count of the decoded image.
    """
    base_options = python.BaseOptions(model_asset_path=model_path)
    options = vision.ImageSegmenterOptions(base_options=base_options,
                                           output_category_mask=True)

    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        image = mp.Image.create_from_file(input)
        segmentation_result = segmenter.segment(image)

        hairmask = segmentation_result.confidence_masks[1]
        facemask = segmentation_result.confidence_masks[3]
        # np.maximum allocates a fresh array, so the result does not depend
        # on the MediaPipe buffers after the segmenter is closed.
        combined_mask = np.maximum(hairmask.numpy_view(),
                                   facemask.numpy_view())

    head = combined_mask > threshold
    # Broadcast the (H, W, 1) condition against the 3-element colours. This
    # gives an (H, W, 3) uint8 mask without assuming that the source image
    # itself has exactly 3 channels.
    output_image = np.where(head[..., np.newaxis],
                            np.asarray(MASK_COLOR, dtype=np.uint8),
                            np.asarray(BG_COLOR, dtype=np.uint8))

    return output_image
|
|
|
def random_positioning(input, output_size=(1024, 1024)):
    """Paste a randomly scaled copy of *input* at a random spot on a black canvas.

    The copy is scaled by a factor drawn uniformly from [0.5, 1.0]. Images
    that would not fit on the canvas are first shrunk to fit. Previously
    ``random.randint`` raised ``ValueError`` for them. Images that already
    fit are handled exactly as before, with the same random draws in the
    same order.

    Args:
        input: Image array indexed as (rows, cols, ...). Presumably an
            H x W x 3 uint8 array as produced by cv2; confirm against callers.
        output_size: ``(width, height)`` of the canvas.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``.

    Raises:
        ValueError: If *input* is None or empty.
    """
    if input is None or input.size == 0:
        raise ValueError("Impossible to load image")

    out_w, out_h = output_size
    in_h, in_w = input.shape[:2]

    # Shrink-to-fit factor. It is exactly 1.0 for images that already fit.
    fit = min(1.0, out_w / in_w, out_h / in_h)
    scale_factor = random.uniform(0.5, 1.0)

    # Clamp to at least 1 px (cv2.resize rejects 0) and at most the canvas
    # size (guards against float rounding above the bound).
    new_size = (min(out_w, max(1, int(in_w * fit * scale_factor))),
                min(out_h, max(1, int(in_h * fit * scale_factor))))

    resized_image = cv2.resize(input, new_size, interpolation=cv2.INTER_AREA)

    background = np.zeros((out_h, out_w, 3), dtype=np.uint8)

    x_offset = random.randint(0, out_w - new_size[0])
    y_offset = random.randint(0, out_h - new_size[1])

    background[y_offset:y_offset + new_size[1],
               x_offset:x_offset + new_size[0]] = resized_image

    return background
|
|
|
|
|
def remove_background(image_path, mask):
    """Black out the pixels that *mask* marks as bright.

    Despite the name, the pixels that are *kept* are those where *mask* is
    dark. After inversion and a binary threshold at 127, a pixel survives
    only where the original mask value is <= 127. Every other pixel becomes
    0. With a white-on-black head mask this therefore removes the head and
    keeps the rest of the image.

    Args:
        image_path: Path of the image to read with ``cv2.imread``.
        mask: Mask array, either single-channel or matching the image's
            channels. It is assumed to be uint8 and the same height and width
            as the image; this is not checked.

    Returns:
        A ``uint8`` array shaped like the loaded image, in cv2's BGR order.

    Raises:
        ValueError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise ValueError("Impossible to load image")

    inverted_mask = cv2.bitwise_not(mask)
    _, binary_mask = cv2.threshold(inverted_mask, 127, 255, cv2.THRESH_BINARY)

    keep = binary_mask == 255
    result = np.zeros_like(image, dtype=np.uint8)
    result[keep] = image[keep]

    return result
|
|
|
# FLUX inpainting pipeline, loaded in bfloat16 and moved to DEVICE once at
# import time.
pipe = FluxInpaintPipeline.from_pretrained(
    bfl_repo,
    torch_dtype=torch.bfloat16,
).to(DEVICE)

# Upper bound for randomly drawn generation seeds (the int32 maximum).
MAX_SEED = np.iinfo(np.int32).max

# Phrase appended to every user prompt. Presumably the LoRA trigger; confirm.
TRIGGER = "a photo of TOK"
|
|
|
# Whether this process has already loaded the realism LoRA into ``pipe``.
# Loading it on every request would repeat the work (and may stack adapters).
_REALISM_LORA_LOADED = False


def _ensure_realism_lora():
    """Load the XLabs realism LoRA into ``pipe`` the first time it is needed."""
    global _REALISM_LORA_LOADED
    if _REALISM_LORA_LOADED:
        return
    pipe.load_lora_weights("XLabs-AI/flux-RealismLora", weight_name='lora.safetensors')
    _REALISM_LORA_LOADED = True


def _head_mask_for(rgb_img):
    """Run maskHead() on an in-memory RGB image via a private temporary JPEG.

    maskHead() only accepts a file path. Each call therefore writes its own
    temporary file, so concurrent requests cannot overwrite one another's
    input. The file is removed afterwards.
    """
    import os
    import tempfile

    fd, tmp_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        # cv2.imwrite expects BGR channel order, but this array is RGB.
        # Writing it unconverted would hand the segmenter a red/blue-swapped
        # photo.
        bgr_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR)
        if not cv2.imwrite(tmp_path, bgr_img):
            raise RuntimeError(f"Failed to write temporary image {tmp_path}")
        return maskHead(tmp_path)
    finally:
        os.remove(tmp_path)


@spaces.GPU(duration=200)
def execute(image, prompt):
    """Generate two inpainted variants of an uploaded photo.

    The photo is placed twice at random scale and position on a 1024x1024
    canvas. For each placement the head (hair and face) is masked and FLUX
    inpaints it from ``f"{prompt} {TRIGGER}"``, with the realism LoRA loaded.

    Args:
        image: File path of the uploaded image (from ``gr.Image(type="filepath")``).
        prompt: User text prompt.

    Returns:
        A list with the two generated images (``pipe(...).images[0]`` each).
        Returns None, after showing a ``gr.Info`` notice, when the prompt or
        image is missing or the image cannot be read.
    """
    if not prompt:
        gr.Info("Please enter a text prompt.")
        return None

    if not image:
        gr.Info("Please upload a image.")
        return None

    img = cv2.imread(image)
    # cv2.imread returns None instead of raising on unreadable files.
    if img is None:
        gr.Info("Could not read the uploaded image.")
        return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    imgs = [random_positioning(img) for _ in range(2)]

    _ensure_realism_lora()

    seed_slicer = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed_slicer)

    response = []
    for current_img in imgs:
        mask = _head_mask_for(current_img)
        # NOTE(review): diffusers' image processor appears to expect NumPy
        # inputs as floats in [0, 1]. Confirm that a uint8 HxWx3 image and a
        # 3-channel mask are interpreted as intended, or pass PIL images.
        result = pipe(
            prompt=f"{prompt} {TRIGGER}",
            image=current_img,
            mask_image=mask,
            width=1024,
            height=1024,
            strength=0.85,
            generator=generator,
            num_inference_steps=28,
            max_sequence_length=256,
            joint_attention_kwargs={"scale": 0.9},
        ).images[0]
        response.append(result)

    return response
|
|
|
# Gradio UI. Inputs: an uploaded image (passed to execute as a file path)
# and a text prompt. Output: a gallery of the returned images.
_image_input = gr.Image(type="filepath")
_prompt_input = gr.Textbox(label="Prompt")

iface = gr.Interface(
    fn=execute,
    inputs=[_image_input, _prompt_input],
    outputs="gallery",
)

iface.launch(share=True, debug=True)
|
|
|
|