File size: 3,947 Bytes
bc088da
 
 
a3f4f47
bc088da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9c1f71
 
 
 
a3f4f47
 
 
bc088da
 
df75ec5
bc088da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
import mediapipe as mp
import uuid
import os

from PIL import Image
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation
from croper import Croper

# Module-wide MediaPipe image segmenter, built once at import time and shared
# by segment_image() and get_restore_mask_image().  The model file path is
# resolved relative to the current working directory.
segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
base_options = python.BaseOptions(model_asset_path=segment_model)
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    # Request a per-pixel category mask; the mask helpers below compare its
    # integer labels.
    output_category_mask=True,
)
segmenter = vision.ImageSegmenter.create_from_options(options)

def restore_result(croper, category, generated_image):
    """Composite ``generated_image`` back onto the croper's input image and save it.

    The generated image is resized to ``croper.square_length`` squared and
    cropped with the croper's square coordinates. It is then pasted onto a copy
    of ``croper.input_image`` at the croper's origin offset. A mask from
    ``get_restore_mask_image`` controls which pixels are pasted. The result is
    written as a PNG into ``output/`` under a random UUID file name.

    Args:
        croper: The ``Croper`` returned by ``segment_image`` for this input.
        category: Segmentation target ("hair", "clothes" or "face") used to
            build the paste mask.
        generated_image: PIL image to composite (presumably the model output
            for ``croper.resized_square_image`` — confirm with caller).

    Returns:
        Tuple ``(restored_image, path)``: the composited PIL image and the file
        path it was saved to.
    """
    square_length = croper.square_length
    generated_image = generated_image.resize((square_length, square_length))

    crop_box = (
        croper.square_start_x,
        croper.square_start_y,
        croper.square_end_x,
        croper.square_end_y,
    )
    cropped_generated_image = generated_image.crop(crop_box)
    cropped_square_mask_image = get_restore_mask_image(croper, category, cropped_generated_image)

    restored_image = croper.input_image.copy()
    restored_image.paste(
        cropped_generated_image,
        (croper.origin_start_x, croper.origin_start_y),
        cropped_square_mask_image,
    )

    # Always PNG: lossless and able to keep an alpha channel from the input.
    extension = "png"
    # exist_ok avoids the check-then-create race (FileExistsError) when several
    # requests save results concurrently.
    os.makedirs("output", exist_ok=True)

    path = f"output/{uuid.uuid4()}.{extension}"
    # No ``quality`` argument: it only applies to lossy writers (JPEG/WebP) and
    # Pillow's PNG writer ignores it.
    restored_image.save(path)

    return restored_image, path

def segment_image(input_image, category, input_size, mask_expansion, mask_dilation):
    """Segment ``input_image`` and crop a square region around the target category.

    Args:
        input_image: PIL image to segment. It is handed to ``Croper`` unchanged.
        category: "hair", "clothes" or "face"; any other value falls back to
            the face mask.
        input_size: Mask size forwarded to ``Croper`` (cast to ``int``).
        mask_expansion: Expansion forwarded to ``Croper`` (cast to ``int``).
        mask_dilation: Number of ``binary_dilation`` iterations applied to the
            target mask (cast to ``int``; values <= 0 disable dilation).

    Returns:
        Tuple ``(origin_area_image, croper)``. ``origin_area_image`` is
        ``croper.resized_square_image``; ``croper`` is later passed to
        ``restore_result``.
    """
    mask_size = int(input_size)
    mask_expansion = int(mask_expansion)
    # Cast like the other numeric arguments; scipy's binary_dilation documents
    # ``iterations`` as an int, and the mask helpers compare it with ``> 0``.
    mask_dilation = int(mask_dilation)

    # MediaPipe's SRGB image format expects 3-channel data.  Segment an RGB copy
    # for RGBA/grayscale inputs, while Croper keeps the original image.
    rgb_image = input_image
    if getattr(input_image, "mode", "RGB") != "RGB":
        rgb_image = input_image.convert("RGB")

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(rgb_image))
    segmentation_result = segmenter.segment(image)
    category_mask_np = segmentation_result.category_mask.numpy_view()

    mask_getters = {
        "hair": get_hair_mask,
        "clothes": get_clothes_mask,
        "face": get_face_mask,
    }
    get_mask = mask_getters.get(category, get_face_mask)
    target_mask = get_mask(category_mask_np, mask_dilation)

    croper = Croper(input_image, target_mask, mask_size, mask_expansion)
    croper.corp_mask_image()
    origin_area_image = croper.resized_square_image

    return origin_area_image, croper

def get_face_mask(category_mask_np, dilation=1):
    """Return a boolean mask of the face-skin pixels (category label 3).

    A positive ``dilation`` grows the mask by that many ``binary_dilation``
    iterations. Zero or negative values return the raw mask, because scipy
    would treat ``iterations < 1`` as "dilate until nothing changes".
    """
    face_label = 3
    mask = category_mask_np == face_label
    if dilation <= 0:
        return mask
    return binary_dilation(mask, iterations=dilation)

def get_clothes_mask(category_mask_np, dilation=1, *, base_dilation=4):
    """Return a boolean mask of body-skin (label 2) and clothes (label 4) pixels.

    The combined region is first grown by ``base_dilation`` iterations, which
    defaults to the previously hard-coded 4. It is then grown by ``dilation``
    further iterations. Each step only runs when its count is positive, because
    scipy's ``binary_dilation`` treats ``iterations < 1`` as "repeat until the
    result stops changing".

    Args:
        category_mask_np: Integer category-label array from the segmenter.
        dilation: Extra dilation iterations requested by the caller.
        base_dilation: Fixed dilation always applied to the clothes region
            (keyword-only; default 4 preserves the original behavior).

    Returns:
        Boolean ``numpy`` array with the same shape as ``category_mask_np``.
    """
    body_skin_mask = category_mask_np == 2
    clothes_mask = category_mask_np == 4
    combined_mask = np.logical_or(body_skin_mask, clothes_mask)
    if base_dilation > 0:
        combined_mask = binary_dilation(combined_mask, iterations=base_dilation)
    if dilation > 0:
        combined_mask = binary_dilation(combined_mask, iterations=dilation)
    return combined_mask

def get_hair_mask(category_mask_np, dilation=1):
    """Return a boolean mask of hair pixels (category label 1).

    A positive ``dilation`` applies that many ``binary_dilation`` iterations.
    Otherwise the thresholded mask is returned untouched.
    """
    hair_pixels = category_mask_np == 1
    return binary_dilation(hair_pixels, iterations=dilation) if dilation > 0 else hair_pixels

def get_restore_mask_image(croper, category, generated_image):
    """Build the paste mask used to composite ``generated_image`` onto the input.

    ``generated_image`` is re-segmented and its target-category mask is taken
    without dilation. That mask is OR-ed with ``croper.corp_mask``, so the
    union of both regions is pasted.

    Args:
        croper: ``Croper`` whose ``corp_mask`` is combined with the new mask.
        category: "hair", "clothes" or "face". Any other value falls back to
            the face mask, matching ``segment_image`` (previously it raised
            ``UnboundLocalError``).
        generated_image: PIL image to segment.

    Returns:
        Mode "L" PIL image: 255 where either mask is set, 0 elsewhere.
    """
    # MediaPipe's SRGB image format expects 3-channel data.
    rgb_image = generated_image
    if getattr(generated_image, "mode", "RGB") != "RGB":
        rgb_image = generated_image.convert("RGB")

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(rgb_image))
    segmentation_result = segmenter.segment(image)
    category_mask_np = segmentation_result.category_mask.numpy_view()

    mask_getters = {
        "hair": get_hair_mask,
        "clothes": get_clothes_mask,
        "face": get_face_mask,
    }
    get_mask = mask_getters.get(category, get_face_mask)
    target_mask = get_mask(category_mask_np, 0)

    # NOTE(review): assumes croper.corp_mask has the same (H, W) shape as the
    # cropped generated image — confirm in Croper.
    combined_mask = np.logical_or(target_mask, croper.corp_mask)
    mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
    return mask_image