File size: 3,947 Bytes
bc088da a3f4f47 bc088da a9c1f71 a3f4f47 bc088da df75ec5 bc088da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import numpy as np
import mediapipe as mp
import uuid
import os
from PIL import Image
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation
from croper import Croper
# Location of the multiclass selfie segmentation model (TFLite) on disk.
segment_model = "checkpoints/selfie_multiclass_256x256.tflite"

base_options = python.BaseOptions(model_asset_path=segment_model)

# Ask for the per-pixel category mask; the mask helpers in this file select
# regions by comparing that mask against integer class ids.
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)

# One segmenter instance created at import time and reused by every call.
segmenter = vision.ImageSegmenter.create_from_options(options)
def restore_result(croper, category, generated_image):
    """Paste the generated region back onto the original image and save it.

    The generated image is resized to the croper's square size, cropped to the
    croper's ``square_*`` box, and pasted onto a copy of ``croper.input_image``
    at ``(origin_start_x, origin_start_y)`` through the mask returned by
    ``get_restore_mask_image``. The result is written as a PNG with a random
    UUID file name under ``output/``.

    Args:
        croper: Croper instance providing the input image and crop geometry.
        category: Category name forwarded to ``get_restore_mask_image``.
        generated_image: PIL image generated for the cropped square.

    Returns:
        A ``(restored_image, path)`` tuple: the composited image and the path
        of the saved file.
    """
    square_length = croper.square_length
    generated_image = generated_image.resize((square_length, square_length))

    crop_box = (
        croper.square_start_x,
        croper.square_start_y,
        croper.square_end_x,
        croper.square_end_y,
    )
    cropped_generated_image = generated_image.crop(crop_box)
    cropped_square_mask_image = get_restore_mask_image(croper, category, cropped_generated_image)

    # Work on a copy so the croper's stored input image is left untouched.
    restored_image = croper.input_image.copy()
    restored_image.paste(
        cropped_generated_image,
        (croper.origin_start_x, croper.origin_start_y),
        cropped_square_mask_image,
    )

    extension = 'png'
    # exist_ok=True avoids the race between checking for the directory and
    # creating it when several requests run at the same time.
    os.makedirs("output", exist_ok=True)
    path = f"output/{uuid.uuid4()}.{extension}"
    restored_image.save(path, quality=100)
    return restored_image, path
def segment_image(input_image, category, input_size, mask_expansion, mask_dilation):
    """Segment the input image and crop the square area around one category.

    Runs the module-level segmenter on ``input_image``, builds a boolean mask
    for ``category`` and hands it to ``Croper`` to compute the square crop.

    Args:
        input_image: Source image; converted with ``np.asarray`` for MediaPipe.
        category: ``"hair"``, ``"clothes"`` or ``"face"``. Any other value is
            treated as ``"face"``.
        input_size: Target mask size; coerced to ``int``.
        mask_expansion: Expansion passed to ``Croper``; coerced to ``int``.
        mask_dilation: Number of dilation iterations applied to the category
            mask; coerced to ``int``.

    Returns:
        A ``(origin_area_image, croper)`` tuple, where ``origin_area_image`` is
        ``croper.resized_square_image``.
    """
    mask_size = int(input_size)
    mask_expansion = int(mask_expansion)
    # Coerce like the other numeric arguments: scipy's binary_dilation needs
    # an integer iteration count, and UI widgets may deliver floats.
    mask_dilation = int(mask_dilation)

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, mask_dilation)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np, mask_dilation)
    else:
        # "face" and every unrecognised category share the face mask.
        target_mask = get_face_mask(category_mask_np, mask_dilation)

    croper = Croper(input_image, target_mask, mask_size, mask_expansion)
    croper.corp_mask_image()
    origin_area_image = croper.resized_square_image
    return origin_area_image, croper
def get_face_mask(category_mask_np, dilation=1):
    """Return the boolean mask of face-skin pixels (class id 3).

    Args:
        category_mask_np: Array of per-pixel class ids.
        dilation: Number of binary-dilation iterations; values that are not
            greater than zero leave the mask untouched.

    Returns:
        Boolean array that is True where the class id equals 3, grown by
        ``dilation`` iterations when ``dilation > 0``.
    """
    mask = category_mask_np == 3
    return binary_dilation(mask, iterations=dilation) if dilation > 0 else mask
def get_clothes_mask(category_mask_np, dilation=1, base_dilation=4):
    """Return the boolean mask of body-skin (id 2) and clothes (id 4) pixels.

    The combined mask is always grown by ``base_dilation`` iterations first and
    then by ``dilation`` further iterations.

    Args:
        category_mask_np: Array of per-pixel class ids.
        dilation: Extra binary-dilation iterations requested by the caller;
            skipped when not greater than zero.
        base_dilation: Iterations always applied before ``dilation``. Defaults
            to 4, the value that was previously hard-coded; skipped when not
            greater than zero.

    Returns:
        Boolean array covering class ids 2 and 4 after both dilation passes.
    """
    body_skin_mask = category_mask_np == 2
    clothes_mask = category_mask_np == 4
    combined_mask = np.logical_or(body_skin_mask, clothes_mask)

    # Each pass is guarded: scipy treats iterations < 1 as "repeat until the
    # result stops changing", which would flood the whole mask.
    if base_dilation > 0:
        combined_mask = binary_dilation(combined_mask, iterations=base_dilation)
    if dilation > 0:
        combined_mask = binary_dilation(combined_mask, iterations=dilation)
    return combined_mask
def get_hair_mask(category_mask_np, dilation=1):
    """Return the boolean mask of hair pixels (class id 1).

    Args:
        category_mask_np: Array of per-pixel class ids.
        dilation: Number of binary-dilation iterations; the mask is returned
            as-is unless this is greater than zero.

    Returns:
        Boolean array that is True where the class id equals 1, optionally
        grown by ``dilation`` iterations.
    """
    selected = category_mask_np == 1
    if not dilation > 0:
        return selected
    return binary_dilation(selected, iterations=dilation)
def get_restore_mask_image(croper, category, generated_image):
    """Build the paste mask used when restoring a generated region.

    The generated image is segmented again, the mask for ``category`` is taken
    without extra caller dilation, and it is OR-ed with ``croper.corp_mask``.

    Args:
        croper: Croper instance; its ``corp_mask`` is merged into the result.
        category: ``"hair"``, ``"clothes"`` or ``"face"``. Any other value is
            treated as ``"face"``, matching ``segment_image``.
        generated_image: Image to segment; converted with ``np.asarray``.

    Returns:
        A PIL image built from a uint8 array holding 255 where either mask is
        set and 0 elsewhere.
    """
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(generated_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, 0)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np, 0)
    else:
        # Fall back to the face mask for unknown categories so target_mask is
        # always bound, the same fallback segment_image applies.
        target_mask = get_face_mask(category_mask_np, 0)

    combined_mask = np.logical_or(target_mask, croper.corp_mask)
    mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
    return mask_image